Initialize ModelCheckpoint state as early as possible (#11108)
parent 262aefc8df
commit e19d93f69e
@@ -247,13 +247,16 @@ class ModelCheckpoint(Callback):
             save_on_train_epoch_end=self._save_on_train_epoch_end,
         )
 
-    def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        """When pretrain routine starts we resolve the ckpt dir on the fly."""
+    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
+        # NOTE: setting these attributes needs to happen as early as possible BEFORE reloading callback states,
+        # because the attributes are part of the state_key which needs to be fully defined before reloading.
         if self._save_on_train_epoch_end is None:
             # if the user runs validation multiple times per training epoch or multiple training epochs without
             # validation, then we run after validation instead of on train epoch end
             self._save_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1
+
+    def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """When pretrain routine starts we build the ckpt dir on the fly."""
        self.__resolve_ckpt_dir(trainer)
        if trainer.is_global_zero:
            self.__warn_if_dir_not_empty(self.dirpath)
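The reason for moving the attribute resolution into setup is visible in the context lines above: _save_on_train_epoch_end feeds into the state_key under which the callback's state is stored in a checkpoint, so the key must be final before the Trainer reloads callback states. The toy sketch below (not Lightning's implementation; the class, key format, and dict layout are illustrative only) shows how an unresolved attribute makes the key mismatch and the reload silently miss its state:

    # Toy sketch of state-keyed callback reloading. NOT Lightning's actual
    # code; all names and the key format here are illustrative assumptions.
    from typing import Any, Dict, Optional

    class ToyCheckpointCallback:
        def __init__(self) -> None:
            self._save_on_train_epoch_end: Optional[bool] = None
            self.best_model_score: Optional[float] = None

        @property
        def state_key(self) -> str:
            # The key embeds configuration, so it changes once the attribute
            # is resolved from None to a concrete bool.
            return f"ToyCheckpointCallback{{'save_on_train_epoch_end': {self._save_on_train_epoch_end}}}"

    def reload_callback_states(cb: ToyCheckpointCallback, ckpt: Dict[str, Any]) -> None:
        # The lookup succeeds only if the key computed now equals the key
        # computed when the checkpoint was written.
        state = ckpt["callbacks"].get(cb.state_key)
        if state is not None:
            cb.best_model_score = state["best_model_score"]

    # A checkpoint written by an instance whose attribute was already resolved:
    ckpt = {"callbacks": {"ToyCheckpointCallback{'save_on_train_epoch_end': True}": {"best_model_score": 0.42}}}

    cb = ToyCheckpointCallback()
    reload_callback_states(cb, ckpt)    # key still says "None" -> no match
    assert cb.best_model_score is None  # state was silently skipped

    cb._save_on_train_epoch_end = True  # what the new `setup` hook ensures
    reload_callback_states(cb, ckpt)    # keys now line up
    assert cb.best_model_score == 0.42

Because setup runs before the Trainer restores callback states, resolving the attribute there guarantees the key computed at reload time matches the one written into the checkpoint.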
@@ -246,10 +246,11 @@ def test_callbacks_state_fit_ckpt_path(tmpdir):
     checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True)
     trainer_args = dict(
         default_root_dir=tmpdir,
-        max_steps=1,
         limit_train_batches=1,
+        limit_val_batches=2,
         max_epochs=1,
         logger=False,
         callbacks=[checkpoint, callback_capture],
-        limit_val_batches=2,
     )
+    assert checkpoint.best_model_path == ""
+    assert checkpoint.best_model_score is None
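The two new assertions pin down that a freshly constructed ModelCheckpoint starts from empty state before any checkpoint is restored, so any state seen after fitting must have come from the reload path. A test with this shape typically goes on to fit from a saved checkpoint; the continuation below is a hedged sketch, where model and ckpt_path are assumed to be defined earlier in the test and are not part of this diff:

    # Hedged sketch of how such a test typically continues; `model` and
    # `ckpt_path` are assumptions, while `trainer_args` and `checkpoint`
    # come from the test body shown in the hunk above.
    from pytorch_lightning import Trainer

    trainer = Trainer(**trainer_args)
    trainer.fit(model, ckpt_path=ckpt_path)

    # After restoring, the reloaded ModelCheckpoint should carry real state
    # again rather than the empty defaults asserted above.
    assert checkpoint.best_model_path != ""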