mirror of https://github.com/explosion/spaCy.git
Make spacy train respect LOG_FRIENDLY
parent 6936ca1664
commit b1c8731b4d
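
The change lets spacy train suppress its tqdm progress bar when the LOG_FRIENDLY environment variable is set to a non-zero value, so logs piped to a file or a CI system are not flooded with progress-bar redraws. A minimal sketch of how a run might opt in, assuming the spaCy v2 command-line interface; the language code, output directory, and corpus paths are placeholders, not part of this commit:

    import os
    import subprocess

    # Launch training with LOG_FRIENDLY set so the progress bar is skipped.
    # The "en" language code and all paths are hypothetical placeholders.
    env = dict(os.environ, LOG_FRIENDLY="1")
    subprocess.run(
        ["python", "-m", "spacy", "train", "en", "output_dir",
         "train.json", "dev.json"],
        env=env,
        check=True,
    )
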
spacy/cli/train.py
@@ -2,6 +2,7 @@
 from __future__ import unicode_literals, division, print_function
 
 import plac
+import os
 from pathlib import Path
 import tqdm
 from thinc.neural._classes.model import Model
@@ -9,7 +10,8 @@ from timeit import default_timer as timer
 import shutil
 import srsly
 from wasabi import Printer
-from thinc.rates import slanted_triangular
+import contextlib
+import random
 
 from .._ml import create_default_optimizer
 from ..attrs import PROB, IS_OOV, CLUSTER, LANG
@@ -207,7 +209,7 @@ def train(
             nlp, noise_level=noise_level, gold_preproc=gold_preproc, max_length=0
         )
         words_seen = 0
-        with tqdm.tqdm(total=n_train_words, leave=False) as pbar:
+        with _create_progress_bar(n_train_words) as pbar:
             losses = {}
             for batch in util.minibatch_by_words(train_docs, size=batch_sizes):
                 if not batch:
@@ -220,6 +222,7 @@ def train(
                     drop=next(dropout_rates),
                     losses=losses,
                 )
-                pbar.update(sum(len(doc) for doc in docs))
+                if not int(os.environ.get('LOG_FRIENDLY', 0)):
+                    pbar.update(sum(len(doc) for doc in docs))
                 words_seen += sum(len(doc) for doc in docs)
         with nlp.use_params(optimizer.averages):
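
The guard above parses the switch with int(os.environ.get('LOG_FRIENDLY', 0)): an unset variable or LOG_FRIENDLY=0 keeps the per-batch pbar.update() call, while any non-zero integer suppresses it. A small illustration of that behavior (the helper name is mine, for illustration only):

    def log_friendly_enabled(environ):
        # Mirrors the commit's check: unset or "0" is falsy, "1" is truthy.
        return bool(int(environ.get("LOG_FRIENDLY", 0)))

    assert log_friendly_enabled({}) is False
    assert log_friendly_enabled({"LOG_FRIENDLY": "0"}) is False
    assert log_friendly_enabled({"LOG_FRIENDLY": "1"}) is True
    # A non-numeric value such as "true" would raise ValueError, both here
    # and in the patched training loop.
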
@@ -281,6 +284,15 @@ def train(
         msg.good("Created best model", best_model_path)
 
 
+@contextlib.contextmanager
+def _create_progress_bar(total):
+    if int(os.environ.get('LOG_FRIENDLY', 0)):
+        yield
+    else:
+        pbar = tqdm.tqdm(total=total, leave=False)
+        yield pbar
+
+
 def _load_vectors(nlp, vectors):
     util.load_model(vectors, vocab=nlp.vocab)
     for lex in nlp.vocab:
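
The new _create_progress_bar helper wraps the bar in a context manager that yields a tqdm instance normally and yields None when LOG_FRIENDLY is set. Note that the caller in train() re-reads the environment variable before each pbar.update() rather than testing the yielded value for None; both sites stay in sync because they consult the same switch. A standalone sketch of the pattern, runnable outside spaCy (the demo loop at the bottom is illustrative, not part of the commit):

    import contextlib
    import os

    import tqdm


    @contextlib.contextmanager
    def _create_progress_bar(total):
        # In LOG_FRIENDLY mode there is no bar, so yield None.
        if int(os.environ.get("LOG_FRIENDLY", 0)):
            yield
        else:
            pbar = tqdm.tqdm(total=total, leave=False)
            yield pbar


    # Illustrative caller: guard each update the same way train() does.
    with _create_progress_bar(total=100) as pbar:
        for _ in range(100):
            if not int(os.environ.get("LOG_FRIENDLY", 0)):
                pbar.update(1)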