From 59cf533879b77852db9af41b5714217e8f670a29 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Thu, 13 Sep 2018 14:24:08 +0200 Subject: [PATCH] Improve ud-train script. Make config optional --- spacy/cli/ud_train.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/spacy/cli/ud_train.py b/spacy/cli/ud_train.py index fd463f4c8..f8714fa33 100644 --- a/spacy/cli/ud_train.py +++ b/spacy/cli/ud_train.py @@ -13,7 +13,7 @@ import spacy import spacy.util from ..tokens import Token, Doc from ..gold import GoldParse -from ..util import compounding, minibatch_by_words +from ..util import compounding, minibatch, minibatch_by_words from ..syntax.nonproj import projectivize from ..matcher import Matcher from .. import displacy @@ -302,8 +302,8 @@ def initialize_pipeline(nlp, docs, golds, config, device): class Config(object): def __init__(self, vectors=None, max_doc_length=10, multitask_tag=True, multitask_sent=True, multitask_dep=True, multitask_vectors=False, - nr_epoch=30, batch_size=1000, dropout=0.2, - conv_depth=4, subword_features=True): + nr_epoch=30, min_batch_size=1, max_batch_size=16, batch_by_words=False, + dropout=0.2, conv_depth=4, subword_features=True): for key, value in locals().items(): setattr(self, key, value) @@ -346,20 +346,23 @@ class TreebankPaths(object): corpus=("UD corpus to train and evaluate on, e.g. en, es_ancora, etc", "positional", None, str), parses_dir=("Directory to write the development parses", "positional", None, Path), - config=("Path to json formatted config file", "positional"), + config=("Path to json formatted config file", "option", "C", Path), limit=("Size limit", "option", "n", int), use_gpu=("Use GPU", "option", "g", int), use_oracle_segments=("Use oracle segments", "flag", "G", int), vectors_dir=("Path to directory with pre-trained vectors, named e.g. en/", "option", "v", Path), ) -def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1, vectors_dir=None, +def main(ud_dir, parses_dir, config=None, corpus, limit=0, use_gpu=-1, vectors_dir=None, use_oracle_segments=False): spacy.util.fix_random_seed() lang.zh.Chinese.Defaults.use_jieba = False lang.ja.Japanese.Defaults.use_janome = False - - config = Config.load(config) + + if config is not None: + config = Config.load(config) + else: + config = Config() paths = TreebankPaths(ud_dir, corpus) if not (parses_dir / corpus).exists(): (parses_dir / corpus).mkdir() @@ -372,7 +375,7 @@ def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1, vectors_dir=No optimizer = initialize_pipeline(nlp, docs, golds, config, use_gpu) - batch_sizes = compounding(config.batch_size//10, config.batch_size, 1.001) + batch_sizes = compounding(config.min_batch_size, config.max_batch_size, 1.001) beam_prob = compounding(0.2, 0.8, 1.001) for i in range(config.nr_epoch): docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(), @@ -381,7 +384,10 @@ def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1, vectors_dir=No raw_text=not use_oracle_segments) Xs = list(zip(docs, golds)) random.shuffle(Xs) - batches = minibatch_by_words(Xs, size=batch_sizes) + if config.batch_by_words: + batches = minibatch_by_words(Xs, size=batch_sizes) + else: + batches = minibatch(Xs, size=batch_sizes) losses = {} n_train_words = sum(len(doc) for doc in docs) with tqdm.tqdm(total=n_train_words, leave=False) as pbar: