Support option to not batch by number of words

This commit is contained in:
Matthew Honnibal 2020-07-08 11:26:54 +02:00
parent 433dc3c9c9
commit 42e1109def
1 changed file with 14 additions and 6 deletions

View File

@ -203,7 +203,8 @@ def train(
msg.info(f"Initializing the nlp pipeline: {nlp.pipe_names}")
train_examples = list(
corpus.train_dataset(
nlp, shuffle=False, gold_preproc=training["gold_preproc"]
nlp, shuffle=False, gold_preproc=training["gold_preproc"],
max_length=training["max_length"]
)
)
nlp.begin_training(lambda: train_examples)
@ -306,11 +307,18 @@ def create_train_batches(nlp, corpus, cfg):
if len(train_examples) == 0:
raise ValueError(Errors.E988)
epoch += 1
if cfg.get("batch_by_words"):
batches = util.minibatch_by_words(
train_examples,
size=cfg["batch_size"],
discard_oversize=cfg["discard_oversize"],
)
else:
batches = util.minibatch(
train_examples,
size=cfg["batch_size"],
)
# make sure the batches iterator (from either minibatch_by_words or minibatch) is not empty, or we'll have an infinite training loop
try:
first = next(batches)