Make spacy train respect LOG_FRIENDLY

Matthew Honnibal 2018-12-10 09:46:53 +01:00
parent 6936ca1664
commit b1c8731b4d
1 changed file with 15 additions and 3 deletions

spacy/cli/train.py

@@ -2,6 +2,7 @@
 from __future__ import unicode_literals, division, print_function
 
 import plac
+import os
 from pathlib import Path
 import tqdm
 from thinc.neural._classes.model import Model
@@ -9,7 +10,8 @@ from timeit import default_timer as timer
 import shutil
 import srsly
 from wasabi import Printer
-from thinc.rates import slanted_triangular
+import contextlib
+import random
 
 from .._ml import create_default_optimizer
 from ..attrs import PROB, IS_OOV, CLUSTER, LANG
@@ -207,7 +209,7 @@ def train(
             nlp, noise_level=noise_level, gold_preproc=gold_preproc, max_length=0
         )
         words_seen = 0
-        with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
+        with _create_progress_bar(n_train_words) as pbar:
             losses = {}
             for batch in util.minibatch_by_words(train_docs, size=batch_sizes):
                 if not batch:
@@ -220,6 +222,7 @@ def train(
                     drop=next(dropout_rates),
                     losses=losses,
                 )
-                pbar.update(sum(len(doc) for doc in docs))
+                if not int(os.environ.get('LOG_FRIENDLY', 0)):
+                    pbar.update(sum(len(doc) for doc in docs))
                 words_seen += sum(len(doc) for doc in docs)
         with nlp.use_params(optimizer.averages):
@@ -281,6 +284,15 @@ def train(
         msg.good("Created best model", best_model_path)
 
 
+@contextlib.contextmanager
+def _create_progress_bar(total):
+    if int(os.environ.get('LOG_FRIENDLY', 0)):
+        yield
+    else:
+        pbar = tqdm.tqdm(total=total, leave=False)
+        yield pbar
+
+
 def _load_vectors(nlp, vectors):
     util.load_model(vectors, vocab=nlp.vocab)
     for lex in nlp.vocab:
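The change is small enough to sketch in isolation: a context manager either yields a normal tqdm progress bar, or, when the LOG_FRIENDLY environment variable is set to a truthy integer (e.g. LOG_FRIENDLY=1), yields nothing, and the update call is guarded by the same check, presumably so that redirected or captured logs are not flooded with progress-bar output. The sketch below mirrors that pattern outside of spaCy; the process() helper and its items argument are hypothetical, and it only assumes tqdm is installed.

import contextlib
import os

import tqdm


@contextlib.contextmanager
def _create_progress_bar(total):
    # Yield None when LOG_FRIENDLY is set, otherwise a normal tqdm bar.
    if int(os.environ.get('LOG_FRIENDLY', 0)):
        yield
    else:
        pbar = tqdm.tqdm(total=total, leave=False)
        yield pbar


def process(items):
    # Hypothetical work loop mirroring how the training loop uses the bar:
    # skip the update when LOG_FRIENDLY is set, since pbar is then None.
    total = sum(len(item) for item in items)
    with _create_progress_bar(total) as pbar:
        for item in items:
            # ... do work on item ...
            if not int(os.environ.get('LOG_FRIENDLY', 0)):
                pbar.update(len(item))

Running with LOG_FRIENDLY=1 set in the environment skips the progress bar entirely; unsetting it restores the usual tqdm output.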