diff --git a/spacy/cli/train.py b/spacy/cli/train.py index 58c30baf2..896868419 100644 --- a/spacy/cli/train.py +++ b/spacy/cli/train.py @@ -16,7 +16,7 @@ from .. import util def train(language, output_dir, train_data, dev_data, n_iter, tagger, parser, ner): output_path = Path(output_dir) train_path = Path(train_data) - dev_path = Path(dev_data) + dev_path = Path(dev_data) if dev_data else None check_dirs(output_path, data_path, dev_path) lang = util.get_lang_class(language) @@ -26,12 +26,13 @@ def train(language, output_dir, train_data, dev_data, n_iter, tagger, parser, ne parser_cfg['features'] = lang.Defaults.parser_features entity_cfg['features'] = lang.Defaults.entity_features gold_train = list(read_gold_json(train_path)) - gold_dev = list(read_gold_json(dev_path)) + gold_dev = list(read_gold_json(dev_path)) if dev_path else None train_model(lang, gold_train, gold_dev, output_path, tagger_cfg, parser_cfg, entity_cfg, n_iter) - scorer = evaluate(lang, list(read_gold_json(dev_loc)), output_path) - print_results(scorer) + if gold_dev: + scorer = evaluate(lang, gold_dev, output_path) + print_results(scorer) def train_config(config): @@ -54,7 +55,7 @@ def train_model(Language, train_data, dev_data, output_path, tagger_cfg, parser_ for itn, epoch in enumerate(trainer.epochs(n_iter, augment_data=None)): for doc, gold in epoch: trainer.update(doc, gold) - dev_scores = trainer.evaluate(dev_data) + dev_scores = trainer.evaluate(dev_data) if dev_data else [] print_progress(itn, trainer.nlp.parser.model.nr_weight, trainer.nlp.parser.model.nr_active_feat, **dev_scores.scores) @@ -82,8 +83,10 @@ def evaluate(Language, gold_tuples, output_path): def check_dirs(input_path, train_path, dev_path): if not output_path.exists(): util.sys_exit(output_path.as_posix(), title="Output directory not found") - if not train_path.exists() and train_path.is_file(): + if not train_path.exists() or not train_path.is_file(): util.sys_exit(train_path.as_posix(), title="Training data not found") + if dev_path and not dev_path.exists(): + util.sys_exit(dev_path.as_posix(), title="Development data not found") def print_progress(itn, nr_weight, nr_active_feat, **scores):