mirror of https://github.com/explosion/spaCy.git
Improve train CLI script
This commit is contained in:
parent
d21459f87d
commit
43353b5413
|
@ -28,15 +28,17 @@ from .. import displacy
|
|||
n_iter=("number of iterations", "option", "n", int),
|
||||
n_sents=("number of sentences", "option", "ns", int),
|
||||
use_gpu=("Use GPU", "flag", "G", bool),
|
||||
resume=("Whether to resume training", "flag", "R", bool),
|
||||
no_tagger=("Don't train tagger", "flag", "T", bool),
|
||||
no_parser=("Don't train parser", "flag", "P", bool),
|
||||
no_entities=("Don't train NER", "flag", "N", bool)
|
||||
)
|
||||
def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
|
||||
use_gpu=False, no_tagger=False, no_parser=False, no_entities=False):
|
||||
use_gpu=False, resume=False, no_tagger=False, no_parser=False, no_entities=False):
|
||||
"""
|
||||
Train a model. Expects data in spaCy's JSON format.
|
||||
"""
|
||||
util.set_env_log(True)
|
||||
n_sents = n_sents or None
|
||||
output_path = util.ensure_path(output_dir)
|
||||
train_path = util.ensure_path(train_data)
|
||||
|
@ -66,7 +68,11 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
|
|||
util.env_opt('batch_to', 64),
|
||||
util.env_opt('batch_compound', 1.001))
|
||||
|
||||
nlp = lang_class(pipeline=pipeline)
|
||||
if resume:
|
||||
prints(output_path / 'model19.pickle', title="Resuming training")
|
||||
nlp = dill.load((output_path / 'model19.pickle').open('rb'))
|
||||
else:
|
||||
nlp = lang_class(pipeline=pipeline)
|
||||
corpus = GoldCorpus(train_path, dev_path, limit=n_sents)
|
||||
n_train_docs = corpus.count_train()
|
||||
|
||||
|
@ -75,6 +81,8 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
|
|||
print("Itn.\tLoss\tUAS\tNER P.\tNER R.\tNER F.\tTag %\tToken %")
|
||||
try:
|
||||
for i in range(n_iter):
|
||||
if resume:
|
||||
i += 20
|
||||
with tqdm.tqdm(total=corpus.count_train(), leave=False) as pbar:
|
||||
train_docs = corpus.train_docs(nlp, projectivize=True,
|
||||
gold_preproc=False, max_length=0)
|
||||
|
@ -86,14 +94,18 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
|
|||
pbar.update(len(docs))
|
||||
|
||||
with nlp.use_params(optimizer.averages):
|
||||
util.set_env_log(False)
|
||||
epoch_model_path = output_path / ('model%d' % i)
|
||||
nlp.to_disk(epoch_model_path)
|
||||
with (output_path / ('model%d.pickle' % i)).open('wb') as file_:
|
||||
dill.dump(nlp, file_, -1)
|
||||
with (output_path / ('model%d.bin' % i)).open('wb') as file_:
|
||||
file_.write(nlp.to_bytes())
|
||||
with (output_path / ('model%d.bin' % i)).open('rb') as file_:
|
||||
nlp_loaded = lang_class(pipeline=pipeline)
|
||||
nlp_loaded.from_bytes(file_.read())
|
||||
scorer = nlp_loaded.evaluate(corpus.dev_docs(nlp_loaded, gold_preproc=False))
|
||||
nlp_loaded = lang_class(pipeline=pipeline)
|
||||
nlp_loaded = nlp_loaded.from_disk(epoch_model_path)
|
||||
scorer = nlp_loaded.evaluate(
|
||||
corpus.dev_docs(
|
||||
nlp_loaded,
|
||||
gold_preproc=False))
|
||||
util.set_env_log(True)
|
||||
print_progress(i, losses, scorer.scores)
|
||||
finally:
|
||||
print("Saving model...")
|
||||
|
|
Loading…
Reference in New Issue