diff --git a/spacy/cli/ud_train.py b/spacy/cli/ud_train.py index b827d4a4f..853cff9b3 100644 --- a/spacy/cli/ud_train.py +++ b/spacy/cli/ud_train.py @@ -254,7 +254,7 @@ def load_nlp(corpus, config): nlp.vocab.from_disk(Path(config.vectors) / 'vocab') return nlp -def initialize_pipeline(nlp, docs, golds, config): +def initialize_pipeline(nlp, docs, golds, config, device): nlp.add_pipe(nlp.create_pipe('parser')) if config.multitask_tag: nlp.parser.add_multitask_objective('tag') @@ -265,7 +265,7 @@ def initialize_pipeline(nlp, docs, golds, config): for tag in gold.tags: if tag is not None: nlp.tagger.add_label(tag) - return nlp.begin_training(lambda: golds_to_gold_tuples(docs, golds)) + return nlp.begin_training(lambda: golds_to_gold_tuples(docs, golds), device=device) ######################## @@ -318,15 +318,14 @@ class TreebankPaths(object): "positional", None, str), parses_dir=("Directory to write the development parses", "positional", None, Path), config=("Path to json formatted config file", "positional"), - limit=("Size limit", "option", "n", int) + limit=("Size limit", "option", "n", int), + use_gpu=("Use GPU", "option", "g", int) ) -def main(ud_dir, parses_dir, config, corpus, limit=0): +def main(ud_dir, parses_dir, config, corpus, limit=0, use_gpu=-1): + spacy.util.fix_random_seed() lang.zh.Chinese.Defaults.use_jieba = False lang.ja.Japanese.Defaults.use_janome = False - random.seed(0) - numpy.random.seed(0) - config = Config.load(config) paths = TreebankPaths(ud_dir, corpus) if not (parses_dir / corpus).exists(): @@ -337,9 +336,9 @@ def main(ud_dir, parses_dir, config, corpus, limit=0): docs, golds = read_data(nlp, paths.train.conllu.open(), paths.train.text.open(), max_doc_length=config.max_doc_length, limit=limit) - optimizer = initialize_pipeline(nlp, docs, golds, config) + optimizer = initialize_pipeline(nlp, docs, golds, config, use_gpu) - batch_sizes = compounding(config.batch_size //10, config.batch_size, 1.001) + batch_sizes = compounding(config.batch_size//10, config.batch_size, 1.001) for i in range(config.nr_epoch): docs = [nlp.make_doc(doc.text) for doc in docs] Xs = list(zip(docs, golds))