mirror of https://github.com/explosion/spaCy.git
Tidy up train-from-config a bit
This commit is contained in:
parent fda7355508
commit 60e8da4813
@@ -193,10 +193,11 @@ def train_from_config(
         optimizer,
         train_batches,
         evaluate,
-        training["dropout"],
-        training["patience"],
-        training["eval_frequency"],
-        training["accumulate_gradient"]
+        dropout=training["dropout"],
+        accumulate_gradient=training["accumulate_gradient"],
+        patience=training.get("patience", 0),
+        max_steps=training.get("max_steps", 0),
+        eval_frequency=training["eval_frequency"],
     )
 
     msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
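Note: the call above now passes its settings as keyword arguments and reads the optional ones with dict.get, so configs that omit patience or max_steps fall back to 0. A minimal standalone sketch of that lookup pattern (the training dict here is a made-up example, not a real config):

training = {"dropout": 0.2, "eval_frequency": 200, "accumulate_gradient": 2}

settings = dict(
    dropout=training["dropout"],
    accumulate_gradient=training["accumulate_gradient"],
    # Optional keys: missing values fall back to 0, which downstream code
    # interprets as "check disabled".
    patience=training.get("patience", 0),
    max_steps=training.get("max_steps", 0),
    eval_frequency=training["eval_frequency"],
)
print(settings["patience"], settings["max_steps"])  # 0 0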
@@ -214,17 +215,17 @@ def train_from_config(
             progress = tqdm.tqdm(total=training["eval_frequency"], leave=False)
     finally:
         if output_path is not None:
-            with nlp.use_params(optimizer.averages):
-                final_model_path = output_path / "model-final"
+            final_model_path = output_path / "model-final"
+            if optimizer.averages:
+                with nlp.use_params(optimizer.averages):
+                    nlp.to_disk(final_model_path)
+            else:
                 nlp.to_disk(final_model_path)
             msg.good("Saved model to output directory", final_model_path)
-            # with msg.loading("Creating best model..."):
-            #     best_model_path = _collate_best_model(meta, output_path, nlp.pipe_names)
-            # msg.good("Created best model", best_model_path)
 
 
 def create_train_batches(nlp, corpus, cfg):
-    is_first = True
+    epochs_todo = cfg.get("max_epochs", 0)
     while True:
         train_examples = list(corpus.train_dataset(
             nlp,
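The saving branch above only swaps in the averaged parameters when the optimizer has actually accumulated any, so runs without averaging still produce a saved model. A standalone sketch of the same guard, using a hypothetical FakeModel stand-in rather than spaCy's nlp object:

from contextlib import contextmanager

class FakeModel:
    # Hypothetical stand-in for the nlp object; records which weights were saved.
    def __init__(self):
        self.weights = "raw"
        self.saved = None

    @contextmanager
    def use_params(self, averages):
        previous, self.weights = self.weights, "averaged"
        yield
        self.weights = previous

    def to_disk(self, path):
        self.saved = (path, self.weights)

def save_final(nlp, averages, output_path):
    final_model_path = output_path + "/model-final"
    if averages:
        with nlp.use_params(averages):
            nlp.to_disk(final_model_path)
    else:
        nlp.to_disk(final_model_path)
    return nlp.saved

print(save_final(FakeModel(), {"W": 0.5}, "out"))  # ('out/model-final', 'averaged')
print(save_final(FakeModel(), {}, "out"))          # ('out/model-final', 'raw')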
@@ -240,6 +241,11 @@ def create_train_batches(nlp, corpus, cfg):
         batches = util.minibatch_by_words(train_examples, size=cfg["batch_size"])
         for batch in batches:
             yield batch
+        epochs_todo -= 1
+        # We intentionally compare exactly to 0 here, so that max_epochs < 1
+        # will not break.
+        if epochs_todo == 0:
+            break
 
 
 def create_evaluation_callback(nlp, optimizer, corpus, cfg):
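Because the counter is decremented and compared exactly to 0, a max_epochs of 0 (the default) or any negative value never triggers the break, and batching continues indefinitely. A standalone sketch of the same counter (cycle_batches and its arguments are invented for illustration):

def cycle_batches(data, max_epochs=0):
    epochs_todo = max_epochs
    while True:
        for item in data:
            yield item
        epochs_todo -= 1
        # Comparing exactly to 0 means max_epochs < 1 never stops the loop,
        # so the default keeps streaming the data forever.
        if epochs_todo == 0:
            break

print(list(cycle_batches([1, 2, 3], max_epochs=2)))  # [1, 2, 3, 1, 2, 3]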
@@ -270,8 +276,8 @@ def create_evaluation_callback(nlp, optimizer, corpus, cfg):
 
 
 def train_while_improving(
-    nlp, optimizer, train_data, evaluate, dropout, patience, eval_frequency,
-    accumulate_gradient
+    nlp, optimizer, train_data, evaluate, *, dropout, eval_frequency,
+    accumulate_gradient=1, patience=0, max_steps=0
 ):
     """Train until an evaluation stops improving. Works as a generator,
     with each iteration yielding a tuple `(batch, info, is_best_checkpoint)`,
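The bare * in the new signature makes every argument after it keyword-only, so callers cannot pass dropout or patience positionally by accident. A tiny standalone illustration of that Python feature (the function and names here are invented):

def train(data, *, dropout, patience=0, max_steps=0):
    return dropout, patience, max_steps

print(train([1, 2], dropout=0.1))  # (0.1, 0, 0)
# train([1, 2], 0.1) raises TypeError: train() takes 1 positional argument but 2 were given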
@@ -281,6 +287,7 @@ def train_while_improving(
 
     Positional arguments:
         nlp: The spaCy pipeline to evaluate.
+        optimizer: The optimizer callable.
         train_data (Iterable[Batch]): A generator of batches, with the training
             data. Each batch should be a Sized[Tuple[Input, Annot]]. The training
             data iterable needs to take care of iterating over the epochs and
@@ -344,9 +351,12 @@ def train_while_improving(
         yield batch, info, is_best_checkpoint
         if is_best_checkpoint is not None:
            losses = {}
-        # Stop if no improvement in `patience` updates
+        # Stop if no improvement in `patience` updates (if specified)
         best_score, best_step = max(results)
-        if (step - best_step) >= patience:
+        if patience and (step - best_step) >= patience:
+            break
+        # Stop if we've exhausted our max steps (if specified)
+        if max_steps and (step * accumulate_gradient) >= max_steps:
             break
 
 
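Both new checks are guarded by truthiness, so the zero defaults for patience and max_steps leave them disabled. A standalone sketch of that logic (should_stop is an invented helper, not part of the source):

def should_stop(step, best_step, patience=0, max_steps=0, accumulate_gradient=1):
    # Stop if no improvement in `patience` evaluation steps (if specified).
    if patience and (step - best_step) >= patience:
        return True
    # Stop once the configured number of optimizer steps is exhausted (if specified).
    if max_steps and (step * accumulate_gradient) >= max_steps:
        return True
    return False

print(should_stop(step=10, best_step=2, patience=5))                            # True
print(should_stop(step=10, best_step=2))                                        # False (both checks disabled)
print(should_stop(step=10, best_step=9, max_steps=16, accumulate_gradient=2))   # True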