diff --git a/spacy/cli/train.py b/spacy/cli/train.py index 96233406d..d71523a9c 100644 --- a/spacy/cli/train.py +++ b/spacy/cli/train.py @@ -18,6 +18,7 @@ from ..gold import GoldParse, merge_sents from ..gold import GoldCorpus, minibatch from ..util import prints from .. import util +from .. import about from .. import displacy from ..compat import json_dumps @@ -35,10 +36,11 @@ from ..compat import json_dumps no_parser=("Don't train parser", "flag", "P", bool), no_entities=("Don't train NER", "flag", "N", bool), gold_preproc=("Use gold preprocessing", "flag", "G", bool), + meta_path=("Optional path to meta.json. All relevant properties will be overwritten.", "option", "m", Path) ) def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0, use_gpu=-1, vectors=None, no_tagger=False, no_parser=False, no_entities=False, - gold_preproc=False): + gold_preproc=False, meta_path=None): """ Train a model. Expects data in spaCy's JSON format. """ @@ -47,13 +49,19 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0, output_path = util.ensure_path(output_dir) train_path = util.ensure_path(train_data) dev_path = util.ensure_path(dev_data) + meta_path = util.ensure_path(meta_path) if not output_path.exists(): output_path.mkdir() if not train_path.exists(): prints(train_path, title="Training data not found", exits=1) if dev_path and not dev_path.exists(): prints(dev_path, title="Development data not found", exits=1) - + if meta_path is not None and not meta_path.exists(): + prints(meta_path, title="meta.json not found", exits=1) + meta = util.read_json(meta_path) if meta_path else {} + if not isinstance(meta, dict): + prints("Expected dict but got: {}".format(type(meta)), + title="Not a valid meta.json format", exits=1) pipeline = ['token_vectors', 'tags', 'dependencies', 'entities'] if no_tagger and 'tags' in pipeline: pipeline.remove('tags') @@ -105,9 +113,16 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0, corpus.dev_docs( nlp, gold_preproc=gold_preproc)) - acc_loc =(output_path / ('model%d' % i) / 'accuracy.json') - with acc_loc.open('w') as file_: - file_.write(json_dumps(scorer.scores)) + meta_loc = output_path / ('model%d' % i) / 'meta.json' + meta['accuracy'] = scorer.scores + meta['lang'] = nlp.lang + meta['pipeline'] = pipeline + meta['spacy_version'] = '>=%s' % about.__version__ + meta.setdefault('name', 'model%d' % i) + meta.setdefault('version', '0.0.0') + + with meta_loc.open('w') as file_: + file_.write(json_dumps(meta)) util.set_env_log(True) print_progress(i, losses, scorer.scores) finally: