diff --git a/decanlp/predict.py b/decanlp/predict.py
index 40b7d414..8369f899 100644
--- a/decanlp/predict.py
+++ b/decanlp/predict.py
@@ -38,6 +38,7 @@ import sys
 import logging
 from pprint import pformat
 
+from .text.vocab import Vocab
 from .util import set_seed, preprocess_examples, load_config_json, make_data_loader
 from .metrics import compute_metrics
 from .utils.embeddings import load_embeddings
@@ -70,12 +71,14 @@ def get_all_splits(args, new_field):
 
 
 def prepare_data(args, FIELD):
-    new_vocab = ReversibleField(batch_first=True, init_token='<init>', eos_token='<eos>', lower=args.lower, include_lengths=True)
-    splits = get_all_splits(args, new_vocab)
-    new_vocab.build_vocab(Example.vocab_fields, *splits)
+    new_field = ReversibleField(batch_first=True, lower=args.lower, include_lengths=True)
+    splits = get_all_splits(args, new_field)
+    new_vocab = Vocab.build_from_data(Example.vocab_fields, *splits,
+                                      init_token=FIELD.init_token, eos_token=FIELD.eos_token,
+                                      pad_token=FIELD.pad_token, unk_token=FIELD.unk_token)
     logger.info(f'Vocabulary has {len(FIELD.vocab)} tokens from training')
     args.max_generative_vocab = min(len(FIELD.vocab), args.max_generative_vocab)
-    FIELD.append_vocab(new_vocab)
+    FIELD.vocab.extend(new_vocab)
     logger.info(f'Vocabulary has expanded to {len(FIELD.vocab)} tokens')
     vectors = load_embeddings(args)
     FIELD.vocab.load_vectors(vectors, True)
diff --git a/decanlp/text/__init__.py b/decanlp/text/__init__.py
index fd7557f0..a0073f14 100644
--- a/decanlp/text/__init__.py
+++ b/decanlp/text/__init__.py
@@ -1,5 +1,3 @@
 __version__ = '0.2.1'
 
-__all__ = ['data',
-           'datasets',
-           'utils']
+__all__ = ['data', 'utils']
diff --git a/decanlp/text/data/field.py b/decanlp/text/data/field.py
index ad194852..9e92ee1a 100644
--- a/decanlp/text/data/field.py
+++ b/decanlp/text/data/field.py
@@ -1,13 +1,8 @@
 # coding: utf8
-from copy import deepcopy
-from collections import Counter, OrderedDict
-import six
 import torch
-from tqdm import tqdm
 
-from .dataset import Dataset
 from .utils import get_tokenizer
-from ..vocab import Vocab, SubwordVocab
+from ..vocab import Vocab
 
 
 class Field(object):
@@ -103,37 +98,6 @@ class Field(object):
 
         self.pad_first = pad_first
 
-    def build_vocab(self, field_names, *args, **kwargs):
-        """Construct the Vocab object for this field from one or more datasets.
-
-        Arguments:
-            Positional arguments: Dataset objects or other iterable data
-                sources from which to construct the Vocab object that
-                represents the set of possible values for this field. If
-                a Dataset object is provided, all columns corresponding
-                to this field are used; individual columns can also be
-                provided directly.
-            Remaining keyword arguments: Passed to the constructor of Vocab.
- """ - counter = Counter() - sources = [] - for arg in args: - sources += [getattr(ex, name) for name in field_names for ex in arg] - for data in sources: - for x in data: - if not self.sequential: - x = [x] - counter.update(x) - specials = [self.unk_token, self.pad_token, self.init_token, self.eos_token] - specials = list(OrderedDict.fromkeys(tok for tok in specials if tok is not None)) - self.vocab = self.vocab_cls(counter, specials=specials, **kwargs) - - def append_vocab(self, other_field): - for w, count in other_field.vocab.stoi.items(): - if w not in self.vocab.stoi: - self.vocab.stoi[w] = len(self.vocab.itos) - self.vocab.itos.append(w) - class ReversibleField(Field): diff --git a/decanlp/text/vocab.py b/decanlp/text/vocab.py index e8ce9e5e..881f60d3 100644 --- a/decanlp/text/vocab.py +++ b/decanlp/text/vocab.py @@ -1,6 +1,6 @@ from __future__ import unicode_literals import array -from collections import defaultdict +from collections import defaultdict, Counter import io import logging import os @@ -21,6 +21,7 @@ logger = logging.getLogger(__name__) MAX_WORD_LENGTH = 100 + class Vocab(object): """Defines a vocabulary object that will be used to numericalize a field. @@ -31,7 +32,7 @@ class Vocab(object): numerical identifiers. itos: A list of token strings indexed by their numerical identifiers. """ - def __init__(self, counter, max_size=None, min_freq=1, specials=[''], + def __init__(self, counter, max_size=None, min_freq=1, specials=('',), vectors=None, cat_vectors=True): """Create a Vocab object from a collections.Counter. @@ -175,47 +176,29 @@ class Vocab(object): else: self.vectors[i] = unk_init(self.vectors[i]) - -class SubwordVocab(Vocab): - - def __init__(self, counter, max_size=None, specials=[''], - vectors=None, unk_init=torch.Tensor.zero_, expand_vocab=False, cat_vectors=True): - """Create a revtok subword vocabulary from a collections.Counter. + @staticmethod + def build_from_data(field_names, *args, unk_token=None, pad_token=None, init_token=None, eos_token=None, **kwargs): + """Construct the Vocab object for this field from one or more datasets. Arguments: - counter: collections.Counter object holding the frequencies of - each word found in the data. - max_size: The maximum size of the subword vocabulary, or None for no - maximum. Default: None. - specials: The list of special tokens (e.g., padding or eos) that - will be prepended to the vocabulary in addition to an - token. + Positional arguments: Dataset objects or other iterable data + sources from which to construct the Vocab object that + represents the set of possible values for this field. If + a Dataset object is provided, all columns corresponding + to this field are used; individual columns can also be + provided directly. + Remaining keyword arguments: Passed to the constructor of Vocab. 
""" - try: - import revtok - except ImportError: - print("Please install revtok.") - raise - - self.stoi = defaultdict(_default_unk_index) - self.stoi.update({tok: i for i, tok in enumerate(specials)}) - self.itos = specials - - self.segment = revtok.SubwordSegmenter(counter, max_size) - - max_size = None if max_size is None else max_size + len(self.itos) - - # sort by frequency/entropy, then alphabetically - toks = sorted(self.segment.vocab.items(), - key=lambda tup: (len(tup[0]) != 1, -tup[1], tup[0])) - - for tok, _ in toks: - self.itos.append(tok) - self.stoi[tok] = len(self.itos) - 1 - - self.vectors = None - if vectors is not None: - self.load_vectors(vectors, cat=cat_vectors) + counter = Counter() + sources = [] + for arg in args: + sources += [getattr(ex, name) for name in field_names for ex in arg] + for data in sources: + for x in data: + counter.update(x) + specials = [unk_token, pad_token, init_token, eos_token] + specials = [tok for tok in specials if tok is not None] + return Vocab(counter, specials=specials, **kwargs) def string_hash(x): diff --git a/decanlp/train.py b/decanlp/train.py index 39377ba6..86e56b5d 100644 --- a/decanlp/train.py +++ b/decanlp/train.py @@ -41,6 +41,7 @@ import numpy as np import torch from tensorboardX import SummaryWriter +from .text.vocab import Vocab from . import arguments from .validate import validate from .multiprocess import Multiprocess @@ -78,7 +79,7 @@ def log(rank='main'): def prepare_data(args, field, logger): if field is None: logger.info(f'Constructing field') - field = ReversibleField(batch_first=True, init_token='', eos_token='', lower=args.lower, include_lengths=True) + field = ReversibleField(batch_first=True, init_token='', eos_token='', include_lengths=True) train_sets, val_sets, aux_sets, vocab_sets = [], [], [], [] for task in args.train_tasks: @@ -124,7 +125,13 @@ def prepare_data(args, field, logger): vectors = load_embeddings(args, logger) vocab_sets = (train_sets + val_sets) if len(vocab_sets) == 0 else vocab_sets logger.info(f'Building vocabulary') - field.build_vocab(Example.vocab_fields, *vocab_sets, max_size=args.max_effective_vocab, vectors=vectors) + field.vocab = Vocab.build_from_data(Example.vocab_fields, *vocab_sets, + unk_token=field.unk_token, + init_token=field.init_token, + eos_token=field.eos_token, + pad_token=field.pad_token, + max_size=args.max_effective_vocab, + vectors=vectors) field.decoder_vocab = DecoderVocabulary(field.vocab.itos[:args.max_generative_vocab], field.vocab)