Initialize ModelCheckpoint state as early as possible (#11108)

Author: Adrian Wälchli, 2021-12-17 00:18:29 +01:00 (committed by GitHub)
Parent: 262aefc8df
Commit: e19d93f69e
2 changed files with 8 additions and 4 deletions


@@ -247,13 +247,16 @@ class ModelCheckpoint(Callback):
             save_on_train_epoch_end=self._save_on_train_epoch_end,
         )
 
-    def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        """When pretrain routine starts we resolve the ckpt dir on the fly."""
+    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
+        # NOTE: setting these attributes needs to happen as early as possible BEFORE reloading callback states,
+        # because the attributes are part of the state_key which needs to be fully defined before reloading.
         if self._save_on_train_epoch_end is None:
             # if the user runs validation multiple times per training epoch or multiple training epochs without
             # validation, then we run after validation instead of on train epoch end
             self._save_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1
 
+    def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """When pretrain routine starts we build the ckpt dir on the fly."""
        self.__resolve_ckpt_dir(trainer)
         if trainer.is_global_zero:
             self.__warn_if_dir_not_empty(self.dirpath)
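
For context, a minimal runnable sketch of why the ordering matters (FakeCheckpoint and saved_states are illustrative stand-ins, not Lightning's actual implementation): the state_key embeds the callback's constructor arguments, so a state saved after _save_on_train_epoch_end was resolved cannot be looked up by a fresh callback whose attribute is still None.

class FakeCheckpoint:
    def __init__(self, save_on_train_epoch_end=None):
        self._save_on_train_epoch_end = save_on_train_epoch_end

    @property
    def state_key(self):
        # the key embeds the attribute value, so it changes once None is resolved
        return f"FakeCheckpoint{{'save_on_train_epoch_end': {self._save_on_train_epoch_end}}}"


# state as a previous run saved it, after the attribute was resolved to True
saved_states = {"FakeCheckpoint{'save_on_train_epoch_end': True}": {"best_model_score": 0.42}}

cb = FakeCheckpoint()                    # fresh run: attribute is still None
assert cb.state_key not in saved_states  # reloading now would silently skip the state

cb._save_on_train_epoch_end = True       # what setup() now resolves before reloading
assert cb.state_key in saved_states      # the saved state is found again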


@@ -246,10 +246,11 @@ def test_callbacks_state_fit_ckpt_path(tmpdir):
     checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True)
     trainer_args = dict(
         default_root_dir=tmpdir,
-        max_steps=1,
+        limit_train_batches=1,
+        limit_val_batches=2,
+        max_epochs=1,
         logger=False,
         callbacks=[checkpoint, callback_capture],
-        limit_val_batches=2,
     )
     assert checkpoint.best_model_path == ""
     assert checkpoint.best_model_score is None
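
The new arguments bound the run by epochs and batch counts rather than a single optimizer step, so one full epoch with validation executes both on the initial fit and on the restored fit. A minimal, self-contained sketch of the scenario this test exercises (TinyModel, the dirpath, and the second fit call are illustrative assumptions, not the test's actual code):

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        return self.layer(batch[0]).sum()

    def validation_step(self, batch, batch_idx):
        # logs the metric the checkpoint callback monitors
        self.log("val_loss", self.layer(batch[0]).sum())

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

loader = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=4)
checkpoint = ModelCheckpoint(dirpath="checkpoints", monitor="val_loss", save_last=True)
trainer_args = dict(
    default_root_dir="checkpoints",
    limit_train_batches=1,
    limit_val_batches=2,
    max_epochs=1,
    logger=False,
    callbacks=[checkpoint],
)

trainer = Trainer(**trainer_args)
trainer.fit(TinyModel(), train_dataloaders=loader, val_dataloaders=loader)

# a second run reloads the callback state from last.ckpt; because setup()
# resolves _save_on_train_epoch_end before the reload, the state_key matches
trainer = Trainer(**trainer_args)
trainer.fit(TinyModel(), train_dataloaders=loader, val_dataloaders=loader,
            ckpt_path=checkpoint.last_model_path)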