From e19d93f69ee78afda68538981d0e0acbb9dd0961 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Fri, 17 Dec 2021 00:18:29 +0100
Subject: [PATCH] Initialize ModelCheckpoint state as early as possible (#11108)

---
 pytorch_lightning/callbacks/model_checkpoint.py | 7 +++++--
 tests/models/test_restore.py                    | 5 +++--
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 37db4fb62d..7707f8a377 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -247,13 +247,16 @@ class ModelCheckpoint(Callback):
             save_on_train_epoch_end=self._save_on_train_epoch_end,
         )
 
-    def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        """When pretrain routine starts we resolve the ckpt dir on the fly."""
+    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
+        # NOTE: setting these attributes needs to happen as early as possible BEFORE reloading callback states,
+        # because the attributes are part of the state_key which needs to be fully defined before reloading.
         if self._save_on_train_epoch_end is None:
             # if the user runs validation multiple times per training epoch or multiple training epochs without
             # validation, then we run after validation instead of on train epoch end
             self._save_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1
 
+    def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """When pretrain routine starts we build the ckpt dir on the fly."""
         self.__resolve_ckpt_dir(trainer)
         if trainer.is_global_zero:
             self.__warn_if_dir_not_empty(self.dirpath)
diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py
index 1139e6fb5e..6e8b6e5926 100644
--- a/tests/models/test_restore.py
+++ b/tests/models/test_restore.py
@@ -246,10 +246,11 @@ def test_callbacks_state_fit_ckpt_path(tmpdir):
     checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True)
     trainer_args = dict(
         default_root_dir=tmpdir,
-        max_steps=1,
+        limit_train_batches=1,
+        limit_val_batches=2,
+        max_epochs=1,
         logger=False,
         callbacks=[checkpoint, callback_capture],
-        limit_val_batches=2,
     )
     assert checkpoint.best_model_path == ""
     assert checkpoint.best_model_score is None
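
Note (not part of the patch): a minimal sketch of the pattern this change relies on. Any attribute that feeds a callback's state_key is resolved in setup(), which the NOTE above says runs before callback states are reloaded from a checkpoint, so the key computed at restore time matches the key stored in the checkpoint. The DecideLateCallback name and its single attribute are hypothetical; the sketch also assumes the private helper Callback._generate_state_key used by the built-in callbacks.

from typing import Optional

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback


class DecideLateCallback(Callback):
    """Hypothetical callback whose state_key depends on a lazily resolved attribute."""

    def __init__(self, run_on_train_epoch_end: Optional[bool] = None) -> None:
        self._run_on_train_epoch_end = run_on_train_epoch_end

    @property
    def state_key(self) -> str:
        # the key must be identical when saving and when reloading a checkpoint,
        # so everything it depends on has to be final before states are restored
        return self._generate_state_key(run_on_train_epoch_end=self._run_on_train_epoch_end)

    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
        # resolve the default here, mirroring the patch above: setup() happens
        # early enough to run before checkpointed callback states are reloaded
        if self._run_on_train_epoch_end is None:
            self._run_on_train_epoch_end = trainer.val_check_interval == 1.0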