diff --git a/spacy/cli/train.py b/spacy/cli/train.py
index 61278e2a3..af028dae5 100644
--- a/spacy/cli/train.py
+++ b/spacy/cli/train.py
@@ -18,6 +18,7 @@ from ..gold import GoldCorpus, minibatch
 from ..util import prints
 from .. import util
 from .. import displacy
+from ..compat import json_dumps
 
 
 @plac.annotations(
@@ -44,7 +45,7 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
     train_path = util.ensure_path(train_data)
     dev_path = util.ensure_path(dev_data)
     if not output_path.exists():
-        prints(output_path, title="Output directory not found", exits=1)
+        output_path.mkdir()
     if not train_path.exists():
         prints(train_path, title="Training data not found", exits=1)
     if dev_path and not dev_path.exists():
@@ -74,7 +75,7 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
     else:
         nlp = lang_class(pipeline=pipeline)
     corpus = GoldCorpus(train_path, dev_path, limit=n_sents)
-    n_train_docs = corpus.count_train()
+    n_train_words = corpus.count_train()
 
     optimizer = nlp.begin_training(lambda: corpus.train_tuples, device=use_gpu)
 
@@ -83,7 +84,7 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
         for i in range(n_iter):
             if resume:
                 i += 20
-            with tqdm.tqdm(total=corpus.count_train(), leave=False) as pbar:
+            with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
                 train_docs = corpus.train_docs(nlp, projectivize=True,
                                                gold_preproc=False, max_length=0)
                 losses = {}
@@ -91,7 +92,7 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
                     docs, golds = zip(*batch)
                     nlp.update(docs, golds, sgd=optimizer,
                                drop=next(dropout_rates), losses=losses)
-                    pbar.update(len(docs))
+                    pbar.update(sum(len(doc) for doc in docs))
 
             with nlp.use_params(optimizer.averages):
                 util.set_env_log(False)
@@ -105,6 +106,9 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
                             corpus.dev_docs(
                                 nlp_loaded,
                                 gold_preproc=False))
+                acc_loc = (output_path / ('model%d' % i) / 'accuracy.json')
+                with acc_loc.open('w') as file_:
+                    file_.write(json_dumps(scorer.scores))
             util.set_env_log(True)
             print_progress(i, losses, scorer.scores)
     finally:
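
Note on the new ..compat import: json.dumps on Python 2 can return a byte
string, which acc_loc.open('w') (text mode) refuses to write. A unicode-safe
shim along these lines is presumably what json_dumps provides; this is a
sketch of the idea, not spaCy's actual implementation:

    # Sketch of a Python 2/3-safe json_dumps (an assumption about what
    # spacy.compat provides, not a copy of it): always return a unicode
    # string so file_.write() succeeds in text mode on both versions.
    import sys
    import json

    def json_dumps(data, indent=2):
        result = json.dumps(data, indent=indent)
        if sys.version_info[0] == 2 and isinstance(result, bytes):
            result = result.decode('utf8')
        return result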
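
With this patch every iteration writes its evaluation scores to
output_dir/model<i>/accuracy.json, so per-epoch accuracy can be compared after
the run. Below is a hypothetical post-training helper, not part of the patch,
that gathers those files; the score keys it reads (uas, las, ents_f) are
typical Scorer fields and are assumptions here, as is the example path:

    # Hypothetical helper, not part of the patch: gather the per-iteration
    # accuracy.json files that training now writes under output_dir/model<i>/.
    import json
    from pathlib import Path

    def collect_scores(output_dir):
        output_path = Path(output_dir)
        # Sort model0, model1, ... numerically, skipping any non-numeric
        # directory names that happen to match the glob.
        acc_files = sorted((p for p in output_path.glob('model*/accuracy.json')
                            if p.parent.name[len('model'):].isdigit()),
                           key=lambda p: int(p.parent.name[len('model'):]))
        for acc_loc in acc_files:
            with acc_loc.open('r', encoding='utf8') as file_:
                scores = json.load(file_)
            # uas, las and ents_f are typical Scorer keys; adjust to whatever
            # the pipeline actually reports.
            print(acc_loc.parent.name,
                  scores.get('uas'), scores.get('las'), scores.get('ents_f'))

    collect_scores('/tmp/model')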