From 86405e4ad1ea443c231a9a5a22b23959d8ddcd42 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 18 Feb 2018 10:59:11 +0100 Subject: [PATCH] Fix CLI for multitask objectives --- spacy/cli/train.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/spacy/cli/train.py b/spacy/cli/train.py index 957a19ba3..be5be0f0b 100644 --- a/spacy/cli/train.py +++ b/spacy/cli/train.py @@ -30,8 +30,8 @@ from ..compat import json_dumps no_tagger=("Don't train tagger", "flag", "T", bool), no_parser=("Don't train parser", "flag", "P", bool), no_entities=("Don't train NER", "flag", "N", bool), - parser_multitasks=("Side objectives for parser CNN, e.g. dep dep,tag", "option", "pt", ","), - entity_multitasks=("Side objectives for ner CNN, e.g. dep dep,tag", "option", "et", ","), + parser_multitasks=("Side objectives for parser CNN, e.g. dep dep,tag", "option", "pt", str), + entity_multitasks=("Side objectives for ner CNN, e.g. dep dep,tag", "option", "et", str), gold_preproc=("Use gold preprocessing", "flag", "G", bool), version=("Model version", "option", "V", str), meta_path=("Optional path to meta.json. All relevant properties will be " @@ -105,10 +105,12 @@ def train(lang, output_dir, train_data, dev_data, n_iter=30, n_sents=0, lex.is_oov = False for name in pipeline: nlp.add_pipe(nlp.create_pipe(name), name=name) - for objective in parser_multitasks.split(','): - nlp.parser.add_multitask_objective(objective) - for objective in entity_multitasks.split(','): - nlp.entity.add_multitask_objective(objective) + if parser_multitasks: + for objective in parser_multitasks.split(','): + nlp.parser.add_multitask_objective(objective) + if entity_multitasks: + for objective in entity_multitasks.split(','): + nlp.entity.add_multitask_objective(objective) optimizer = nlp.begin_training(lambda: corpus.train_tuples, device=use_gpu) nlp._optimizer = None