diff --git a/spacy/default_config_pretraining.cfg b/spacy/default_config_pretraining.cfg index 16f767772..d70ecf04c 100644 --- a/spacy/default_config_pretraining.cfg +++ b/spacy/default_config_pretraining.cfg @@ -5,6 +5,7 @@ raw_text = null max_epochs = 1000 dropout = 0.2 n_save_every = null +n_save_epoch = null component = "tok2vec" layer = "" corpus = "corpora.pretrain" diff --git a/spacy/schemas.py b/spacy/schemas.py index 83623b104..bd3f0ecf0 100644 --- a/spacy/schemas.py +++ b/spacy/schemas.py @@ -351,7 +351,8 @@ class ConfigSchemaPretrain(BaseModel): # fmt: off max_epochs: StrictInt = Field(..., title="Maximum number of epochs to train for") dropout: StrictFloat = Field(..., title="Dropout rate") - n_save_every: Optional[StrictInt] = Field(..., title="Saving frequency") + n_save_every: Optional[StrictInt] = Field(..., title="Saving additional temporary model after n batches within an epoch") + n_save_epoch: Optional[StrictInt] = Field(..., title="Saving model after every n epoch") optimizer: Optimizer = Field(..., title="The optimizer to use") corpus: StrictStr = Field(..., title="Path in the config to the training data") batcher: Batcher = Field(..., title="Batcher for the training data") diff --git a/spacy/training/pretrain.py b/spacy/training/pretrain.py index 6d7850212..88f1dc0bb 100644 --- a/spacy/training/pretrain.py +++ b/spacy/training/pretrain.py @@ -48,7 +48,10 @@ def pretrain( objective = model.attrs["loss"] # TODO: move this to logger function? tracker = ProgressTracker(frequency=10000) - msg.divider(f"Pre-training tok2vec layer - starting at epoch {epoch_resume}") + if P["n_save_epoch"]: + msg.divider(f"Pre-training tok2vec layer - starting at epoch {epoch_resume} - saving every {P['n_save_epoch']} epoch") + else: + msg.divider(f"Pre-training tok2vec layer - starting at epoch {epoch_resume}") row_settings = {"widths": (3, 10, 10, 6, 4), "aligns": ("r", "r", "r", "r", "r")} msg.row(("#", "# Words", "Total Loss", "Loss", "w/s"), **row_settings) @@ -77,7 +80,12 @@ def pretrain( msg.row(progress, **row_settings) if P["n_save_every"] and (batch_id % P["n_save_every"] == 0): _save_model(epoch, is_temp=True) - _save_model(epoch) + + if P["n_save_epoch"]: + if epoch % P["n_save_epoch"] == 0 or epoch == P["max_epochs"] - 1: + _save_model(epoch) + else: + _save_model(epoch) tracker.epoch_loss = 0.0