diff --git a/spacy/errors.py b/spacy/errors.py index a206826ff..0e1a294c3 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -869,6 +869,8 @@ class Errors: E1019 = ("`noun_chunks` requires the pos tagging, which requires a " "statistical model to be installed and loaded. For more info, see " "the documentation:\nhttps://spacy.io/usage/models") + E1020 = ("No `epoch_resume` value specified and could not infer one from " + "filename. Specify an epoch to resume from.") # Deprecated model shortcuts, only used in errors and warnings diff --git a/spacy/training/pretrain.py b/spacy/training/pretrain.py index 6d7850212..0228f2947 100644 --- a/spacy/training/pretrain.py +++ b/spacy/training/pretrain.py @@ -41,10 +41,11 @@ def pretrain( optimizer = P["optimizer"] # Load in pretrained weights to resume from if resume_path is not None: - _resume_model(model, resume_path, epoch_resume, silent=silent) + epoch_resume = _resume_model(model, resume_path, epoch_resume, silent=silent) else: # Without '--resume-path' the '--epoch-resume' argument is ignored epoch_resume = 0 + objective = model.attrs["loss"] # TODO: move this to logger function? tracker = ProgressTracker(frequency=10000) @@ -93,20 +94,25 @@ def ensure_docs(examples_or_docs: Iterable[Union[Doc, Example]]) -> List[Doc]: def _resume_model( model: Model, resume_path: Path, epoch_resume: int, silent: bool = True -) -> None: +) -> int: msg = Printer(no_print=silent) msg.info(f"Resume training tok2vec from: {resume_path}") with resume_path.open("rb") as file_: weights_data = file_.read() model.get_ref("tok2vec").from_bytes(weights_data) - # Parse the epoch number from the given weight file - model_name = re.search(r"model\d+\.bin", str(resume_path)) - if model_name: - # Default weight file name so read epoch_start from it by cutting off 'model' and '.bin' - epoch_resume = int(model_name.group(0)[5:][:-4]) + 1 - msg.info(f"Resuming from epoch: {epoch_resume}") - else: - msg.info(f"Resuming from epoch: {epoch_resume}") + + if epoch_resume is None: + # Parse the epoch number from the given weight file + model_name = re.search(r"model\d+\.bin", str(resume_path)) + if model_name: + # Default weight file name so read epoch_start from it by cutting off 'model' and '.bin' + epoch_resume = int(model_name.group(0)[5:][:-4]) + 1 + else: + # No epoch given and couldn't infer it + raise ValueError(Errors.E1020) + + msg.info(f"Resuming from epoch: {epoch_resume}") + return epoch_resume def make_update(