Make checkpointing on train epoch end condition dynamic (#15300)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
parent a00dfc850d
commit f4ca5623d2
@@ -31,7 +31,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - `Trainer.{validate,test,predict}(ckpt_path=...)` no longer restores the `Trainer.global_step` and `trainer.current_epoch` value from the checkpoints - From now on, only `Trainer.fit` will restore this value ([#15532](https://github.com/Lightning-AI/lightning/pull/15532))
 
--
+- The `ModelCheckpoint.save_on_train_epoch_end` attribute is now computed dynamically every epoch, accounting for changes to the validation dataloaders ([#15300](https://github.com/Lightning-AI/lightning/pull/15300))
 
 ### Deprecated
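In user terms, the changelog entry above means that leaving `save_on_train_epoch_end` at its default of `None` now lets the callback re-decide every epoch where to save, based on the Trainer's current validation setup. A minimal sketch of that default path (the `BoringModel` import is an assumption taken from Lightning's demo utilities; any `LightningModule` works):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.demos.boring_classes import BoringModel

# `save_on_train_epoch_end=None` (the default) defers the decision to the trainer.
# With this change the condition is re-evaluated every epoch instead of being frozen
# in `setup()`, so changes to the validation dataloaders are taken into account.
ckpt = ModelCheckpoint(dirpath="checkpoints", save_on_train_epoch_end=None)

# With validation disabled there is nothing to run after the training epoch,
# so checkpoints are written at train-epoch end.
trainer = Trainer(max_epochs=2, callbacks=[ckpt], limit_val_batches=0)
trainer.fit(BoringModel())
```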
|
||||
|
|
|
@@ -262,13 +262,6 @@ class ModelCheckpoint(Checkpoint):
         if trainer.is_global_zero and stage == "fit":
             self.__warn_if_dir_not_empty(self.dirpath)
 
-        # NOTE: setting these attributes needs to happen as early as possible BEFORE reloading callback states,
-        # because the attributes are part of the state_key which needs to be fully defined before reloading.
-        if self._save_on_train_epoch_end is None:
-            # if the user runs validation multiple times per training epoch or multiple training epochs without
-            # validation, then we run after validation instead of on train epoch end
-            self._save_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1
-
     def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._last_time_checked = time.monotonic()
@@ -306,7 +299,7 @@ class ModelCheckpoint(Checkpoint):
 
     def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Save a checkpoint at the end of the training epoch."""
-        if not self._should_skip_saving_checkpoint(trainer) and self._save_on_train_epoch_end:
+        if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer):
             monitor_candidates = self._monitor_candidates(trainer)
             if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
                 self._save_topk_checkpoint(trainer, monitor_candidates)
@@ -314,7 +307,7 @@ class ModelCheckpoint(Checkpoint):
 
     def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Save a checkpoint at the end of the validation stage."""
-        if not self._should_skip_saving_checkpoint(trainer) and not self._save_on_train_epoch_end:
+        if not self._should_skip_saving_checkpoint(trainer) and not self._should_save_on_train_epoch_end(trainer):
             monitor_candidates = self._monitor_candidates(trainer)
             if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
                 self._save_topk_checkpoint(trainer, monitor_candidates)
@@ -390,6 +383,23 @@ class ModelCheckpoint(Checkpoint):
             or self._last_global_step_saved == trainer.global_step  # already saved at the last step
         )
 
+    def _should_save_on_train_epoch_end(self, trainer: "pl.Trainer") -> bool:
+        if self._save_on_train_epoch_end is not None:
+            return self._save_on_train_epoch_end
+
+        # if `check_val_every_n_epoch != 1`, we can't say when the validation dataloader will be loaded
+        # so let's not enforce saving at every training epoch end
+        if trainer.check_val_every_n_epoch != 1:
+            return False
+
+        # no validation means save on train epoch end
+        if sum(trainer.num_val_batches) == 0:
+            return True
+
+        # if the user runs validation multiple times per training epoch, then we run after validation
+        # instead of on train epoch end
+        return trainer.val_check_interval == 1.0
+
     def __validate_init_configuration(self) -> None:
         if self.save_top_k < -1:
             raise MisconfigurationException(f"Invalid value for save_top_k={self.save_top_k}. Must be >= -1")
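To make the decision order in the new `_should_save_on_train_epoch_end` method easier to scan, here is a small self-contained sketch that mirrors its branches for a few representative configurations. The `FakeTrainer` stand-in is an assumption for illustration only, not library code:

```python
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class FakeTrainer:
    # Minimal stand-in carrying only the attributes the decision reads.
    val_check_interval: Union[int, float]
    check_val_every_n_epoch: int
    num_val_batches: List[int]


def should_save_on_train_epoch_end(user_setting: Optional[bool], trainer: FakeTrainer) -> bool:
    # Mirrors ModelCheckpoint._should_save_on_train_epoch_end from the hunk above.
    if user_setting is not None:
        return user_setting                       # explicit user choice always wins
    if trainer.check_val_every_n_epoch != 1:
        return False                              # validation timing is not per-epoch
    if sum(trainer.num_val_batches) == 0:
        return True                               # no validation at all
    return trainer.val_check_interval == 1.0      # validation exactly once per epoch


assert should_save_on_train_epoch_end(None, FakeTrainer(1.0, 1, [10]))       # one val pass per epoch
assert should_save_on_train_epoch_end(None, FakeTrainer(0.5, 1, [0]))        # no val batches at all
assert not should_save_on_train_epoch_end(None, FakeTrainer(0.5, 1, [10]))   # mid-epoch validation
assert not should_save_on_train_epoch_end(None, FakeTrainer(1.0, 2, [10]))   # validation every 2nd epoch
```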
@@ -168,7 +168,7 @@ def test_model_checkpoint_score_and_ckpt(
 
         mc_specific_data = chk["callbacks"][
             f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
-            " 'train_time_interval': None, 'save_on_train_epoch_end': True}"
+            " 'train_time_interval': None, 'save_on_train_epoch_end': None}"
         ]
         assert mc_specific_data["dirpath"] == checkpoint.dirpath
         assert mc_specific_data["monitor"] == monitor
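This and the next few test hunks all update the expected state key for the same reason: `ModelCheckpoint.state_key` is built from the constructor arguments, and because `save_on_train_epoch_end` is no longer resolved in `setup()`, the key now keeps the unresolved default `None` instead of a computed `True`/`False`. A quick way to see this, assuming the standard import path:

```python
from pytorch_lightning.callbacks import ModelCheckpoint

ckpt = ModelCheckpoint(monitor="foo")
# After this change the key ends with "'save_on_train_epoch_end': None}" because the
# attribute is left untouched until a checkpoint actually has to be saved.
print(ckpt.state_key)
```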
@@ -269,7 +269,7 @@ def test_model_checkpoint_score_and_ckpt_val_check_interval(
 
         mc_specific_data = chk["callbacks"][
             f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
-            " 'train_time_interval': None, 'save_on_train_epoch_end': False}"
+            " 'train_time_interval': None, 'save_on_train_epoch_end': None}"
         ]
         assert mc_specific_data["dirpath"] == checkpoint.dirpath
         assert mc_specific_data["monitor"] == monitor
@@ -805,7 +805,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
 
     ckpt_id = (
         "ModelCheckpoint{'monitor': 'early_stop_on', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
-        " 'train_time_interval': None, 'save_on_train_epoch_end': True}"
+        " 'train_time_interval': None, 'save_on_train_epoch_end': None}"
     )
     assert ckpt_last["callbacks"][ckpt_id] == ckpt_last_epoch["callbacks"][ckpt_id]
@@ -1052,7 +1052,7 @@ def test_current_score(tmpdir):
     ckpts = [
         ckpt["callbacks"][
             "ModelCheckpoint{'monitor': 'foo', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
-            " 'train_time_interval': None, 'save_on_train_epoch_end': True}"
+            " 'train_time_interval': None, 'save_on_train_epoch_end': None}"
         ]
         for ckpt in ckpts
     ]
@@ -1360,3 +1360,13 @@ def test_save_last_every_n_epochs_interaction(tmpdir, every_n_epochs):
     trainer.fit(model)
     assert mc.last_model_path  # a "last" ckpt was saved
     assert save_mock.call_count == trainer.max_epochs
+
+
+def test_train_epoch_end_ckpt_with_no_validation():
+    trainer = Trainer(val_check_interval=0.5)
+    trainer.num_val_batches = [0]
+    assert trainer.checkpoint_callback._should_save_on_train_epoch_end(trainer)
+    trainer.num_val_batches = [1]
+    assert not trainer.checkpoint_callback._should_save_on_train_epoch_end(trainer)
+    trainer.val_check_interval = 0.8
+    assert not trainer.checkpoint_callback._should_save_on_train_epoch_end(trainer)
@@ -157,7 +157,7 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmpdir):
     assert "content1" in state2 and state2["content1"] == "two"
     assert (
         "ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
-        " 'train_time_interval': None, 'save_on_train_epoch_end': True}" in ckpt["callbacks"]
+        " 'train_time_interval': None, 'save_on_train_epoch_end': None}" in ckpt["callbacks"]
     )
@@ -343,6 +343,7 @@ def test_model_checkpoint_options(tmpdir, save_top_k, save_last, expected_files)
         save_top_k=save_top_k,
         save_last=save_last,
         verbose=True,
+        save_on_train_epoch_end=False,
     )
     trainer = Trainer()
     trainer.state.fn = TrainerFn.FITTING