mirror of https://github.com/explosion/spaCy.git
Fix inference of epoch_resume (#9084)
* Fix inference of epoch_resume

  When an epoch_resume value is not specified individually, it can often be inferred from the filename. The value inference code was there, but the value wasn't passed back to the training loop.

  This also adds a specific error for the case where no epoch_resume value is provided and it can't be inferred from the filename.

* Add new error

* Always use the epoch resume value if specified

  Before this, the value in the filename was used if found.
This commit is contained in:
parent
a17b06d18b
commit
f803a84571
|
@ -869,6 +869,8 @@ class Errors:
|
|||
E1019 = ("`noun_chunks` requires the pos tagging, which requires a "
|
||||
"statistical model to be installed and loaded. For more info, see "
|
||||
"the documentation:\nhttps://spacy.io/usage/models")
|
||||
E1020 = ("No `epoch_resume` value specified and could not infer one from "
|
||||
"filename. Specify an epoch to resume from.")
|
||||
|
||||
|
||||
# Deprecated model shortcuts, only used in errors and warnings
|
||||
|
|
|
@ -41,10 +41,11 @@ def pretrain(
|
|||
optimizer = P["optimizer"]
|
||||
# Load in pretrained weights to resume from
|
||||
if resume_path is not None:
|
||||
_resume_model(model, resume_path, epoch_resume, silent=silent)
|
||||
epoch_resume = _resume_model(model, resume_path, epoch_resume, silent=silent)
|
||||
else:
|
||||
# Without '--resume-path' the '--epoch-resume' argument is ignored
|
||||
epoch_resume = 0
|
||||
|
||||
objective = model.attrs["loss"]
|
||||
# TODO: move this to logger function?
|
||||
tracker = ProgressTracker(frequency=10000)
|
||||
|
@ -93,20 +94,25 @@ def ensure_docs(examples_or_docs: Iterable[Union[Doc, Example]]) -> List[Doc]:
|
|||
|
||||
def _resume_model(
|
||||
model: Model, resume_path: Path, epoch_resume: int, silent: bool = True
|
||||
) -> None:
|
||||
) -> int:
|
||||
msg = Printer(no_print=silent)
|
||||
msg.info(f"Resume training tok2vec from: {resume_path}")
|
||||
with resume_path.open("rb") as file_:
|
||||
weights_data = file_.read()
|
||||
model.get_ref("tok2vec").from_bytes(weights_data)
|
||||
# Parse the epoch number from the given weight file
|
||||
model_name = re.search(r"model\d+\.bin", str(resume_path))
|
||||
if model_name:
|
||||
# Default weight file name so read epoch_start from it by cutting off 'model' and '.bin'
|
||||
epoch_resume = int(model_name.group(0)[5:][:-4]) + 1
|
||||
msg.info(f"Resuming from epoch: {epoch_resume}")
|
||||
else:
|
||||
msg.info(f"Resuming from epoch: {epoch_resume}")
|
||||
|
||||
if epoch_resume is None:
|
||||
# Parse the epoch number from the given weight file
|
||||
model_name = re.search(r"model\d+\.bin", str(resume_path))
|
||||
if model_name:
|
||||
# Default weight file name so read epoch_start from it by cutting off 'model' and '.bin'
|
||||
epoch_resume = int(model_name.group(0)[5:][:-4]) + 1
|
||||
else:
|
||||
# No epoch given and couldn't infer it
|
||||
raise ValueError(Errors.E1020)
|
||||
|
||||
msg.info(f"Resuming from epoch: {epoch_resume}")
|
||||
return epoch_resume
|
||||
|
||||
|
||||
def make_update(
|
||||
|
|
Loading…
Reference in New Issue