mirror of https://github.com/explosion/spaCy.git
Improve train CLI
This commit is contained in:
parent
a053b1218e
commit
c52fde40f4
|
@ -18,6 +18,7 @@ from ..gold import GoldCorpus, minibatch
|
|||
from ..util import prints
|
||||
from .. import util
|
||||
from .. import displacy
|
||||
from ..compat import json_dumps
|
||||
|
||||
|
||||
@plac.annotations(
|
||||
|
@ -44,7 +45,7 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
|
|||
train_path = util.ensure_path(train_data)
|
||||
dev_path = util.ensure_path(dev_data)
|
||||
if not output_path.exists():
|
||||
prints(output_path, title="Output directory not found", exits=1)
|
||||
output_path.mkdir()
|
||||
if not train_path.exists():
|
||||
prints(train_path, title="Training data not found", exits=1)
|
||||
if dev_path and not dev_path.exists():
|
||||
|
@ -74,7 +75,7 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
|
|||
else:
|
||||
nlp = lang_class(pipeline=pipeline)
|
||||
corpus = GoldCorpus(train_path, dev_path, limit=n_sents)
|
||||
n_train_docs = corpus.count_train()
|
||||
n_train_words = corpus.count_train()
|
||||
|
||||
optimizer = nlp.begin_training(lambda: corpus.train_tuples, device=use_gpu)
|
||||
|
||||
|
@ -83,7 +84,7 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
|
|||
for i in range(n_iter):
|
||||
if resume:
|
||||
i += 20
|
||||
with tqdm.tqdm(total=corpus.count_train(), leave=False) as pbar:
|
||||
with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
|
||||
train_docs = corpus.train_docs(nlp, projectivize=True,
|
||||
gold_preproc=False, max_length=0)
|
||||
losses = {}
|
||||
|
@ -91,7 +92,7 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
|
|||
docs, golds = zip(*batch)
|
||||
nlp.update(docs, golds, sgd=optimizer,
|
||||
drop=next(dropout_rates), losses=losses)
|
||||
pbar.update(len(docs))
|
||||
pbar.update(sum(len(doc) for doc in docs))
|
||||
|
||||
with nlp.use_params(optimizer.averages):
|
||||
util.set_env_log(False)
|
||||
|
@ -105,6 +106,9 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
|
|||
corpus.dev_docs(
|
||||
nlp_loaded,
|
||||
gold_preproc=False))
|
||||
acc_loc =(output_path / ('model%d' % i) / 'accuracy.json')
|
||||
with acc_loc.open('w') as file_:
|
||||
file_.write(json_dumps(scorer.scores))
|
||||
util.set_env_log(True)
|
||||
print_progress(i, losses, scorer.scores)
|
||||
finally:
|
||||
|
|
Loading…
Reference in New Issue