Move vocabulary building to Vocab class

Leaving no vocabulary-building code in the Field class
Giovanni Campagna 2019-12-21 07:59:38 -08:00
parent cde004b3e4
commit ccd1fb6c80
5 changed files with 41 additions and 86 deletions
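In short, the two call sites touched by this commit stop calling build_vocab on the field and instead build a Vocab directly and attach it. A minimal before/after sketch of that call pattern, assuming a field and a training set already exist (the names field and train_set are illustrative, not taken from this diff):

# Before this commit: the Field built and stored its own vocabulary.
field.build_vocab(Example.vocab_fields, train_set, max_size=50000)

# After this commit: the Vocab builds itself from the data and the caller
# attaches it to the field explicitly.
field.vocab = Vocab.build_from_data(Example.vocab_fields, train_set,
                                    unk_token=field.unk_token,
                                    pad_token=field.pad_token,
                                    init_token=field.init_token,
                                    eos_token=field.eos_token,
                                    max_size=50000)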

View File

@@ -38,6 +38,7 @@ import sys
import logging
from pprint import pformat
from .text.vocab import Vocab
from .util import set_seed, preprocess_examples, load_config_json, make_data_loader
from .metrics import compute_metrics
from .utils.embeddings import load_embeddings
@@ -70,12 +71,14 @@ def get_all_splits(args, new_field):
def prepare_data(args, FIELD):
new_vocab = ReversibleField(batch_first=True, init_token='<init>', eos_token='<eos>', lower=args.lower, include_lengths=True)
splits = get_all_splits(args, new_vocab)
new_vocab.build_vocab(Example.vocab_fields, *splits)
new_field = ReversibleField(batch_first=True, lower=args.lower, include_lengths=True)
splits = get_all_splits(args, new_field)
new_vocab = Vocab.build_from_data(Example.vocab_fields, *splits,
init_token=FIELD.init_token, eos_token=FIELD.eos_token,
pad_token=FIELD.pad_token, unk_token=FIELD.unk_token)
logger.info(f'Vocabulary has {len(FIELD.vocab)} tokens from training')
args.max_generative_vocab = min(len(FIELD.vocab), args.max_generative_vocab)
FIELD.append_vocab(new_vocab)
FIELD.vocab.extend(new_vocab)
logger.info(f'Vocabulary has expanded to {len(FIELD.vocab)} tokens')
vectors = load_embeddings(args)
FIELD.vocab.load_vectors(vectors, True)

View File

@@ -1,5 +1,3 @@
__version__ = '0.2.1'
__all__ = ['data',
'datasets',
'utils']
__all__ = ['data', 'utils']

View File

@@ -1,13 +1,8 @@
# coding: utf8
from copy import deepcopy
from collections import Counter, OrderedDict
import six
import torch
from tqdm import tqdm
from .dataset import Dataset
from .utils import get_tokenizer
from ..vocab import Vocab, SubwordVocab
from ..vocab import Vocab
class Field(object):
@@ -103,37 +98,6 @@ class Field(object):
self.pad_first = pad_first
def build_vocab(self, field_names, *args, **kwargs):
"""Construct the Vocab object for this field from one or more datasets.
Arguments:
Positional arguments: Dataset objects or other iterable data
sources from which to construct the Vocab object that
represents the set of possible values for this field. If
a Dataset object is provided, all columns corresponding
to this field are used; individual columns can also be
provided directly.
Remaining keyword arguments: Passed to the constructor of Vocab.
"""
counter = Counter()
sources = []
for arg in args:
sources += [getattr(ex, name) for name in field_names for ex in arg]
for data in sources:
for x in data:
if not self.sequential:
x = [x]
counter.update(x)
specials = [self.unk_token, self.pad_token, self.init_token, self.eos_token]
specials = list(OrderedDict.fromkeys(tok for tok in specials if tok is not None))
self.vocab = self.vocab_cls(counter, specials=specials, **kwargs)
def append_vocab(self, other_field):
for w, count in other_field.vocab.stoi.items():
if w not in self.vocab.stoi:
self.vocab.stoi[w] = len(self.vocab.itos)
self.vocab.itos.append(w)
class ReversibleField(Field):
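The removed append_vocab above merged another field's vocabulary into this one token by token. Its only caller (prepare_data in the first file) now calls FIELD.vocab.extend(new_vocab) instead. The extend method is not part of this diff; assuming it mirrors the removed helper, a rough sketch of an equivalent implementation is:

def extend(self, other_vocab):
    # Append every token of the other vocabulary that is not already
    # present, assigning it the next free index.
    for w in other_vocab.itos:
        if w not in self.stoi:
            self.stoi[w] = len(self.itos)
            self.itos.append(w)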

View File

@@ -1,6 +1,6 @@
from __future__ import unicode_literals
import array
from collections import defaultdict
from collections import defaultdict, Counter
import io
import logging
import os
@@ -21,6 +21,7 @@ logger = logging.getLogger(__name__)
MAX_WORD_LENGTH = 100
class Vocab(object):
"""Defines a vocabulary object that will be used to numericalize a field.
@@ -31,7 +32,7 @@ class Vocab(object):
numerical identifiers.
itos: A list of token strings indexed by their numerical identifiers.
"""
def __init__(self, counter, max_size=None, min_freq=1, specials=['<pad>'],
def __init__(self, counter, max_size=None, min_freq=1, specials=('<pad>',),
vectors=None, cat_vectors=True):
"""Create a Vocab object from a collections.Counter.
@@ -175,47 +176,29 @@ class Vocab(object):
else:
self.vectors[i] = unk_init(self.vectors[i])
class SubwordVocab(Vocab):
def __init__(self, counter, max_size=None, specials=['<pad>'],
vectors=None, unk_init=torch.Tensor.zero_, expand_vocab=False, cat_vectors=True):
"""Create a revtok subword vocabulary from a collections.Counter.
@staticmethod
def build_from_data(field_names, *args, unk_token=None, pad_token=None, init_token=None, eos_token=None, **kwargs):
"""Construct the Vocab object for this field from one or more datasets.
Arguments:
counter: collections.Counter object holding the frequencies of
each word found in the data.
max_size: The maximum size of the subword vocabulary, or None for no
maximum. Default: None.
specials: The list of special tokens (e.g., padding or eos) that
will be prepended to the vocabulary in addition to an <unk>
token.
field_names: Names of the example attributes from which tokens are read.
Positional arguments: Dataset objects or other iterable data
sources from which to construct the Vocab object that
represents the set of possible values for the given fields. If
a Dataset object is provided, all columns corresponding
to those fields are used; individual columns can also be
provided directly.
unk_token, pad_token, init_token, eos_token: Special tokens to
prepend to the vocabulary when they are not None.
Remaining keyword arguments: Passed to the constructor of Vocab.
"""
try:
import revtok
except ImportError:
print("Please install revtok.")
raise
self.stoi = defaultdict(_default_unk_index)
self.stoi.update({tok: i for i, tok in enumerate(specials)})
self.itos = specials
self.segment = revtok.SubwordSegmenter(counter, max_size)
max_size = None if max_size is None else max_size + len(self.itos)
# sort by frequency/entropy, then alphabetically
toks = sorted(self.segment.vocab.items(),
key=lambda tup: (len(tup[0]) != 1, -tup[1], tup[0]))
for tok, _ in toks:
self.itos.append(tok)
self.stoi[tok] = len(self.itos) - 1
self.vectors = None
if vectors is not None:
self.load_vectors(vectors, cat=cat_vectors)
counter = Counter()
sources = []
for arg in args:
sources += [getattr(ex, name) for name in field_names for ex in arg]
for data in sources:
for x in data:
counter.update(x)
specials = [unk_token, pad_token, init_token, eos_token]
specials = [tok for tok in specials if tok is not None]
return Vocab(counter, specials=specials, **kwargs)
def string_hash(x):
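A short usage sketch of the new static constructor with toy data. The real callers pass Example objects and Example.vocab_fields; here SimpleNamespace stands in for them, and the top-level package name in the import is an assumption (the diff itself uses the relative path .text.vocab):

from types import SimpleNamespace
from genienlp.text.vocab import Vocab  # package name assumed

# Toy "examples" with a single field named 'context'.
examples = [SimpleNamespace(context=['show', 'me', 'restaurants']),
            SimpleNamespace(context=['play', 'some', 'music'])]

vocab = Vocab.build_from_data(['context'], examples,
                              unk_token='<unk>', pad_token='<pad>',
                              init_token='<init>', eos_token='<eos>')

# The resulting vocabulary contains the special tokens plus the distinct
# data tokens, exposed through vocab.stoi and vocab.itos.
print(vocab.stoi['restaurants'])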

View File

@@ -41,6 +41,7 @@ import numpy as np
import torch
from tensorboardX import SummaryWriter
from .text.vocab import Vocab
from . import arguments
from .validate import validate
from .multiprocess import Multiprocess
@@ -78,7 +79,7 @@ def log(rank='main'):
def prepare_data(args, field, logger):
if field is None:
logger.info(f'Constructing field')
field = ReversibleField(batch_first=True, init_token='<init>', eos_token='<eos>', lower=args.lower, include_lengths=True)
field = ReversibleField(batch_first=True, init_token='<init>', eos_token='<eos>', include_lengths=True)
train_sets, val_sets, aux_sets, vocab_sets = [], [], [], []
for task in args.train_tasks:
@@ -124,7 +125,13 @@ def prepare_data(args, field, logger):
vectors = load_embeddings(args, logger)
vocab_sets = (train_sets + val_sets) if len(vocab_sets) == 0 else vocab_sets
logger.info(f'Building vocabulary')
field.build_vocab(Example.vocab_fields, *vocab_sets, max_size=args.max_effective_vocab, vectors=vectors)
field.vocab = Vocab.build_from_data(Example.vocab_fields, *vocab_sets,
unk_token=field.unk_token,
init_token=field.init_token,
eos_token=field.eos_token,
pad_token=field.pad_token,
max_size=args.max_effective_vocab,
vectors=vectors)
field.decoder_vocab = DecoderVocabulary(field.vocab.itos[:args.max_generative_vocab], field.vocab)