Move vocabulary building to Vocab class
Leaving nothing vocabulary-related in the Field class
parent cde004b3e4
commit ccd1fb6c80
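In short: `Field.build_vocab` and `Field.append_vocab` are deleted, and callers now go through a new `Vocab.build_from_data` static method, attaching the result to the field themselves. A minimal before/after sketch of the call site, using only names taken from the hunks below:

    # before: the Field owns vocabulary construction
    field.build_vocab(Example.vocab_fields, *vocab_sets,
                      max_size=args.max_effective_vocab, vectors=vectors)

    # after: the Vocab builds itself from the data and is attached to the field
    field.vocab = Vocab.build_from_data(Example.vocab_fields, *vocab_sets,
                                        unk_token=field.unk_token, pad_token=field.pad_token,
                                        init_token=field.init_token, eos_token=field.eos_token,
                                        max_size=args.max_effective_vocab, vectors=vectors)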
@@ -38,6 +38,7 @@ import sys
 import logging
 from pprint import pformat
 
+from .text.vocab import Vocab
 from .util import set_seed, preprocess_examples, load_config_json, make_data_loader
 from .metrics import compute_metrics
 from .utils.embeddings import load_embeddings

@@ -70,12 +71,14 @@ def get_all_splits(args, new_field):
 
 
 def prepare_data(args, FIELD):
-    new_vocab = ReversibleField(batch_first=True, init_token='<init>', eos_token='<eos>', lower=args.lower, include_lengths=True)
-    splits = get_all_splits(args, new_vocab)
-    new_vocab.build_vocab(Example.vocab_fields, *splits)
+    new_field = ReversibleField(batch_first=True, lower=args.lower, include_lengths=True)
+    splits = get_all_splits(args, new_field)
+    new_vocab = Vocab.build_from_data(Example.vocab_fields, *splits,
+                                      init_token=FIELD.init_token, eos_token=FIELD.eos_token,
+                                      pad_token=FIELD.pad_token, unk_token=FIELD.unk_token)
     logger.info(f'Vocabulary has {len(FIELD.vocab)} tokens from training')
     args.max_generative_vocab = min(len(FIELD.vocab), args.max_generative_vocab)
-    FIELD.append_vocab(new_vocab)
+    FIELD.vocab.extend(new_vocab)
     logger.info(f'Vocabulary has expanded to {len(FIELD.vocab)} tokens')
     vectors = load_embeddings(args)
     FIELD.vocab.load_vectors(vectors, True)

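Note that `FIELD.append_vocab(new_vocab)` becomes `FIELD.vocab.extend(new_vocab)`: the argument is now a `Vocab` rather than a whole field. `Vocab.extend` is not shown in this diff; the sketch below is an assumption of the behavior it would need, mirroring the deleted `Field.append_vocab`, not the committed implementation:

    def extend(self, other):
        # Add every token from the other Vocab that this one lacks,
        # assigning new ids at the end of the existing index.
        for w in other.itos:
            if w not in self.stoi:
                self.stoi[w] = len(self.itos)
                self.itos.append(w)
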
@@ -1,5 +1,3 @@
 __version__ = '0.2.1'
 
-__all__ = ['data',
-           'datasets',
-           'utils']
+__all__ = ['data', 'utils']

@@ -1,13 +1,8 @@
 # coding: utf8
-from copy import deepcopy
-from collections import Counter, OrderedDict
-import six
-import torch
-from tqdm import tqdm
 
 from .dataset import Dataset
 from .utils import get_tokenizer
-from ..vocab import Vocab, SubwordVocab
+from ..vocab import Vocab
 
 
 class Field(object):

@@ -103,37 +98,6 @@ class Field(object):
         self.pad_first = pad_first
 
-    def build_vocab(self, field_names, *args, **kwargs):
-        """Construct the Vocab object for this field from one or more datasets.
-
-        Arguments:
-            Positional arguments: Dataset objects or other iterable data
-                sources from which to construct the Vocab object that
-                represents the set of possible values for this field. If
-                a Dataset object is provided, all columns corresponding
-                to this field are used; individual columns can also be
-                provided directly.
-            Remaining keyword arguments: Passed to the constructor of Vocab.
-        """
-        counter = Counter()
-        sources = []
-        for arg in args:
-            sources += [getattr(ex, name) for name in field_names for ex in arg]
-        for data in sources:
-            for x in data:
-                if not self.sequential:
-                    x = [x]
-                counter.update(x)
-        specials = [self.unk_token, self.pad_token, self.init_token, self.eos_token]
-        specials = list(OrderedDict.fromkeys(tok for tok in specials if tok is not None))
-        self.vocab = self.vocab_cls(counter, specials=specials, **kwargs)
-
-    def append_vocab(self, other_field):
-        for w, count in other_field.vocab.stoi.items():
-            if w not in self.vocab.stoi:
-                self.vocab.stoi[w] = len(self.vocab.itos)
-                self.vocab.itos.append(w)
-
 
 
 class ReversibleField(Field):
 

@@ -1,6 +1,6 @@
 from __future__ import unicode_literals
 import array
-from collections import defaultdict
+from collections import defaultdict, Counter
 import io
 import logging
 import os

@@ -21,6 +21,7 @@ logger = logging.getLogger(__name__)
 
 MAX_WORD_LENGTH = 100
 
+
 class Vocab(object):
     """Defines a vocabulary object that will be used to numericalize a field.
 

@@ -31,7 +32,7 @@ class Vocab(object):
         numerical identifiers.
         itos: A list of token strings indexed by their numerical identifiers.
     """
-    def __init__(self, counter, max_size=None, min_freq=1, specials=['<pad>'],
+    def __init__(self, counter, max_size=None, min_freq=1, specials=('<pad>',),
                  vectors=None, cat_vectors=True):
         """Create a Vocab object from a collections.Counter.
 

@@ -175,47 +176,29 @@ class Vocab(object):
             else:
                 self.vectors[i] = unk_init(self.vectors[i])
 
-
-class SubwordVocab(Vocab):
-
-    def __init__(self, counter, max_size=None, specials=['<pad>'],
-                 vectors=None, unk_init=torch.Tensor.zero_, expand_vocab=False, cat_vectors=True):
-        """Create a revtok subword vocabulary from a collections.Counter.
+    @staticmethod
+    def build_from_data(field_names, *args, unk_token=None, pad_token=None, init_token=None, eos_token=None, **kwargs):
+        """Construct the Vocab object for this field from one or more datasets.
 
         Arguments:
-            counter: collections.Counter object holding the frequencies of
-                each word found in the data.
-            max_size: The maximum size of the subword vocabulary, or None for no
-                maximum. Default: None.
-            specials: The list of special tokens (e.g., padding or eos) that
-                will be prepended to the vocabulary in addition to an <unk>
-                token.
+            Positional arguments: Dataset objects or other iterable data
+                sources from which to construct the Vocab object that
+                represents the set of possible values for this field. If
+                a Dataset object is provided, all columns corresponding
+                to this field are used; individual columns can also be
+                provided directly.
+            Remaining keyword arguments: Passed to the constructor of Vocab.
         """
-        try:
-            import revtok
-        except ImportError:
-            print("Please install revtok.")
-            raise
-
-        self.stoi = defaultdict(_default_unk_index)
-        self.stoi.update({tok: i for i, tok in enumerate(specials)})
-        self.itos = specials
-
-        self.segment = revtok.SubwordSegmenter(counter, max_size)
-
-        max_size = None if max_size is None else max_size + len(self.itos)
-
-        # sort by frequency/entropy, then alphabetically
-        toks = sorted(self.segment.vocab.items(),
-                      key=lambda tup: (len(tup[0]) != 1, -tup[1], tup[0]))
-
-        for tok, _ in toks:
-            self.itos.append(tok)
-            self.stoi[tok] = len(self.itos) - 1
-
-        self.vectors = None
-        if vectors is not None:
-            self.load_vectors(vectors, cat=cat_vectors)
+        counter = Counter()
+        sources = []
+        for arg in args:
+            sources += [getattr(ex, name) for name in field_names for ex in arg]
+        for data in sources:
+            for x in data:
+                counter.update(x)
+        specials = [unk_token, pad_token, init_token, eos_token]
+        specials = [tok for tok in specials if tok is not None]
+        return Vocab(counter, specials=specials, **kwargs)
 
 
 def string_hash(x):

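For illustration, a toy invocation of the new static method. The `SimpleNamespace` stand-ins mimic examples whose named attributes hold field data, which is what `getattr(ex, name)` reads; the field names and token values here are assumptions, not from the commit:

    from types import SimpleNamespace

    data = [SimpleNamespace(context=['the', 'cat'], question=['where', 'is', 'it']),
            SimpleNamespace(context=['a', 'dog'], question=['who', 'is', 'there'])]

    vocab = Vocab.build_from_data(['context', 'question'], data,
                                  unk_token='<unk>', pad_token='<pad>',
                                  init_token='<init>', eos_token='<eos>')

    # The specials are expected to head the index, with the entries
    # counted from the two columns following them.
    print(vocab.itos[:4])  # ['<unk>', '<pad>', '<init>', '<eos>']
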
@@ -41,6 +41,7 @@ import numpy as np
 import torch
 from tensorboardX import SummaryWriter
 
+from .text.vocab import Vocab
 from . import arguments
 from .validate import validate
 from .multiprocess import Multiprocess

@@ -78,7 +79,7 @@ def log(rank='main'):
 def prepare_data(args, field, logger):
     if field is None:
         logger.info(f'Constructing field')
-        field = ReversibleField(batch_first=True, init_token='<init>', eos_token='<eos>', lower=args.lower, include_lengths=True)
+        field = ReversibleField(batch_first=True, init_token='<init>', eos_token='<eos>', include_lengths=True)
 
     train_sets, val_sets, aux_sets, vocab_sets = [], [], [], []
     for task in args.train_tasks:

@@ -124,7 +125,13 @@ def prepare_data(args, field, logger):
     vectors = load_embeddings(args, logger)
     vocab_sets = (train_sets + val_sets) if len(vocab_sets) == 0 else vocab_sets
     logger.info(f'Building vocabulary')
-    field.build_vocab(Example.vocab_fields, *vocab_sets, max_size=args.max_effective_vocab, vectors=vectors)
+    field.vocab = Vocab.build_from_data(Example.vocab_fields, *vocab_sets,
+                                        unk_token=field.unk_token,
+                                        init_token=field.init_token,
+                                        eos_token=field.eos_token,
+                                        pad_token=field.pad_token,
+                                        max_size=args.max_effective_vocab,
+                                        vectors=vectors)
 
     field.decoder_vocab = DecoderVocabulary(field.vocab.itos[:args.max_generative_vocab], field.vocab)
 
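Since `build_from_data` consumes only the four special-token keywords, `max_size` and `vectors` ride through `**kwargs` into the `Vocab` constructor. Given the signature changed above, the call is roughly equivalent to the direct construction below, where `counter` stands for the counts accumulated inside `build_from_data`:

    specials = [tok for tok in (field.unk_token, field.pad_token,
                                field.init_token, field.eos_token) if tok is not None]
    field.vocab = Vocab(counter, specials=specials,
                        max_size=args.max_effective_vocab, vectors=vectors)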