Fix inference of epoch_resume (#9084)

* Fix inference of epoch_resume

When an epoch_resume value is not specified individually, it can often
be inferred from the filename. The value inference code was there but
the value wasn't passed back to the training loop.

This also adds a specific error in the case where no epoch_resume value
is provided and it can't be inferred from the filename.

* Add new error

* Always use the epoch resume value if specified

Before this the value in the filename was used if found
This commit is contained in:
Paul O'Leary McCann 2021-09-01 14:17:42 +09:00 committed by GitHub
parent a17b06d18b
commit f803a84571
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 10 deletions

View File

@ -869,6 +869,8 @@ class Errors:
E1019 = ("`noun_chunks` requires the pos tagging, which requires a " E1019 = ("`noun_chunks` requires the pos tagging, which requires a "
"statistical model to be installed and loaded. For more info, see " "statistical model to be installed and loaded. For more info, see "
"the documentation:\nhttps://spacy.io/usage/models") "the documentation:\nhttps://spacy.io/usage/models")
E1020 = ("No `epoch_resume` value specified and could not infer one from "
"filename. Specify an epoch to resume from.")
# Deprecated model shortcuts, only used in errors and warnings # Deprecated model shortcuts, only used in errors and warnings

View File

@ -41,10 +41,11 @@ def pretrain(
optimizer = P["optimizer"] optimizer = P["optimizer"]
# Load in pretrained weights to resume from # Load in pretrained weights to resume from
if resume_path is not None: if resume_path is not None:
_resume_model(model, resume_path, epoch_resume, silent=silent) epoch_resume = _resume_model(model, resume_path, epoch_resume, silent=silent)
else: else:
# Without '--resume-path' the '--epoch-resume' argument is ignored # Without '--resume-path' the '--epoch-resume' argument is ignored
epoch_resume = 0 epoch_resume = 0
objective = model.attrs["loss"] objective = model.attrs["loss"]
# TODO: move this to logger function? # TODO: move this to logger function?
tracker = ProgressTracker(frequency=10000) tracker = ProgressTracker(frequency=10000)
@ -93,20 +94,25 @@ def ensure_docs(examples_or_docs: Iterable[Union[Doc, Example]]) -> List[Doc]:
def _resume_model( def _resume_model(
model: Model, resume_path: Path, epoch_resume: int, silent: bool = True model: Model, resume_path: Path, epoch_resume: int, silent: bool = True
) -> None: ) -> int:
msg = Printer(no_print=silent) msg = Printer(no_print=silent)
msg.info(f"Resume training tok2vec from: {resume_path}") msg.info(f"Resume training tok2vec from: {resume_path}")
with resume_path.open("rb") as file_: with resume_path.open("rb") as file_:
weights_data = file_.read() weights_data = file_.read()
model.get_ref("tok2vec").from_bytes(weights_data) model.get_ref("tok2vec").from_bytes(weights_data)
# Parse the epoch number from the given weight file
model_name = re.search(r"model\d+\.bin", str(resume_path)) if epoch_resume is None:
if model_name: # Parse the epoch number from the given weight file
# Default weight file name so read epoch_start from it by cutting off 'model' and '.bin' model_name = re.search(r"model\d+\.bin", str(resume_path))
epoch_resume = int(model_name.group(0)[5:][:-4]) + 1 if model_name:
msg.info(f"Resuming from epoch: {epoch_resume}") # Default weight file name so read epoch_start from it by cutting off 'model' and '.bin'
else: epoch_resume = int(model_name.group(0)[5:][:-4]) + 1
msg.info(f"Resuming from epoch: {epoch_resume}") else:
# No epoch given and couldn't infer it
raise ValueError(Errors.E1020)
msg.info(f"Resuming from epoch: {epoch_resume}")
return epoch_resume
def make_update( def make_update(