mirror of https://github.com/explosion/spaCy.git
Support option to not batch by number of words
commit 42e1109def
parent 433dc3c9c9
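This change adds a `batch_by_words` flag so training can batch by a fixed number of examples instead of by total word count, and threads `max_length` through to dataset loading. As a rough, hypothetical sketch of the config keys involved (only the key names appear in the diff below; the section name and default values here are illustrative assumptions, not taken from this commit):

    # Hypothetical config excerpt; only the key names are grounded in the diff.
    [training]
    batch_by_words = true       # false -> util.minibatch (fixed examples per batch)
    batch_size = 1000           # words per batch if batch_by_words, else examples
    discard_oversize = false    # in word mode, drop docs longer than batch_size
    max_length = 0              # cap on training-doc length; 0 = no limit
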
@@ -203,7 +203,8 @@ def train(
     msg.info(f"Initializing the nlp pipeline: {nlp.pipe_names}")
     train_examples = list(
         corpus.train_dataset(
-            nlp, shuffle=False, gold_preproc=training["gold_preproc"]
+            nlp, shuffle=False, gold_preproc=training["gold_preproc"],
+            max_length=training["max_length"]
         )
     )
     nlp.begin_training(lambda: train_examples)
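The `max_length` argument above lets `corpus.train_dataset` limit example length, e.g. by skipping overly long documents before they ever reach batching. A minimal sketch of that skip-based idea (not spaCy's actual implementation; `filter_by_length` is an invented name):

    # Sketch only: drop examples longer than the cap; 0 disables the cap.
    def filter_by_length(examples, max_length):
        for example in examples:
            if max_length == 0 or len(example) <= max_length:
                yield example
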
@@ -306,11 +307,18 @@ def create_train_batches(nlp, corpus, cfg):
         if len(train_examples) == 0:
             raise ValueError(Errors.E988)
         epoch += 1
-        batches = util.minibatch_by_words(
-            train_examples,
-            size=cfg["batch_size"],
-            discard_oversize=cfg["discard_oversize"],
-        )
+        if cfg.get("batch_by_words"):
+            batches = util.minibatch_by_words(
+                train_examples,
+                size=cfg["batch_size"],
+                discard_oversize=cfg["discard_oversize"],
+            )
+        else:
+            batches = util.minibatch(
+                train_examples,
+                size=cfg["batch_size"],
+            )
+
         # make sure the minibatch_by_words result is not empty, or we'll have an infinite training loop
         try:
             first = next(batches)
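With the new flag, `util.minibatch_by_words` still caps each batch by total word count (optionally discarding examples that are oversize on their own), while `util.minibatch` simply groups a fixed number of examples per batch. A simplified sketch of the difference between the two strategies (not spaCy's exact code; the real helpers also accept compounding size schedules, so these signatures are reduced):

    # Simplified sketches of the two batching strategies.
    def minibatch(items, size):
        # Fixed-count batching: every batch holds up to `size` items.
        batch = []
        for item in items:
            batch.append(item)
            if len(batch) == size:
                yield batch
                batch = []
        if batch:
            yield batch

    def minibatch_by_words(docs, size, discard_oversize=False):
        # Word-count batching: each batch holds at most `size` words in total.
        batch, n_words = [], 0
        for doc in docs:
            if len(doc) > size:
                if discard_oversize:
                    continue          # drop docs that exceed the cap outright
                if batch:
                    yield batch
                yield [doc]           # otherwise emit the oversize doc alone
                batch, n_words = [], 0
            elif n_words + len(doc) > size:
                yield batch
                batch, n_words = [doc], len(doc)
            else:
                batch.append(doc)
                n_words += len(doc)
        if batch:
            yield batch

Note that `batch_size` changes meaning between the two modes (words per batch vs. examples per batch), so configs that turn `batch_by_words` off will typically want a much smaller value.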