Make checkpointing on train epoch end condition dynamic (#15300)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
Rohit Gupta 2022-11-09 19:57:53 +05:30 committed by GitHub
parent a00dfc850d
commit f4ca5623d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 36 additions and 15 deletions

View File

@ -31,7 +31,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- `Trainer.{validate,test,predict}(ckpt_path=...)` no longer restores the `Trainer.global_step` and `trainer.current_epoch` value from the checkpoints - From now on, only `Trainer.fit` will restore this value ([#15532](https://github.com/Lightning-AI/lightning/pull/15532))
-
- The `ModelCheckpoint.save_on_train_epoch_end` attribute is now computed dynamically every epoch, accounting for changes to the validation dataloaders ([#15300](https://github.com/Lightning-AI/lightning/pull/15300))
### Deprecated

View File

@ -262,13 +262,6 @@ class ModelCheckpoint(Checkpoint):
if trainer.is_global_zero and stage == "fit":
self.__warn_if_dir_not_empty(self.dirpath)
# NOTE: setting these attributes needs to happen as early as possible BEFORE reloading callback states,
# because the attributes are part of the state_key which needs to be fully defined before reloading.
if self._save_on_train_epoch_end is None:
# if the user runs validation multiple times per training epoch or multiple training epochs without
# validation, then we run after validation instead of on train epoch end
self._save_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self._last_time_checked = time.monotonic()
@ -306,7 +299,7 @@ class ModelCheckpoint(Checkpoint):
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Save a checkpoint at the end of the training epoch."""
if not self._should_skip_saving_checkpoint(trainer) and self._save_on_train_epoch_end:
if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer):
monitor_candidates = self._monitor_candidates(trainer)
if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
self._save_topk_checkpoint(trainer, monitor_candidates)
@ -314,7 +307,7 @@ class ModelCheckpoint(Checkpoint):
def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Save a checkpoint at the end of the validation stage."""
if not self._should_skip_saving_checkpoint(trainer) and not self._save_on_train_epoch_end:
if not self._should_skip_saving_checkpoint(trainer) and not self._should_save_on_train_epoch_end(trainer):
monitor_candidates = self._monitor_candidates(trainer)
if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
self._save_topk_checkpoint(trainer, monitor_candidates)
@ -390,6 +383,23 @@ class ModelCheckpoint(Checkpoint):
or self._last_global_step_saved == trainer.global_step # already saved at the last step
)
def _should_save_on_train_epoch_end(self, trainer: "pl.Trainer") -> bool:
if self._save_on_train_epoch_end is not None:
return self._save_on_train_epoch_end
# if `check_val_every_n_epoch != 1`, we can't say when the validation dataloader will be loaded
# so let's not enforce saving at every training epoch end
if trainer.check_val_every_n_epoch != 1:
return False
# no validation means save on train epoch end
if sum(trainer.num_val_batches) == 0:
return True
# if the user runs validation multiple times per training epoch, then we run after validation
# instead of on train epoch end
return trainer.val_check_interval == 1.0
def __validate_init_configuration(self) -> None:
if self.save_top_k < -1:
raise MisconfigurationException(f"Invalid value for save_top_k={self.save_top_k}. Must be >= -1")

View File

@ -168,7 +168,7 @@ def test_model_checkpoint_score_and_ckpt(
mc_specific_data = chk["callbacks"][
f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
" 'train_time_interval': None, 'save_on_train_epoch_end': True}"
" 'train_time_interval': None, 'save_on_train_epoch_end': None}"
]
assert mc_specific_data["dirpath"] == checkpoint.dirpath
assert mc_specific_data["monitor"] == monitor
@ -269,7 +269,7 @@ def test_model_checkpoint_score_and_ckpt_val_check_interval(
mc_specific_data = chk["callbacks"][
f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
" 'train_time_interval': None, 'save_on_train_epoch_end': False}"
" 'train_time_interval': None, 'save_on_train_epoch_end': None}"
]
assert mc_specific_data["dirpath"] == checkpoint.dirpath
assert mc_specific_data["monitor"] == monitor
@ -805,7 +805,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
ckpt_id = (
"ModelCheckpoint{'monitor': 'early_stop_on', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
" 'train_time_interval': None, 'save_on_train_epoch_end': True}"
" 'train_time_interval': None, 'save_on_train_epoch_end': None}"
)
assert ckpt_last["callbacks"][ckpt_id] == ckpt_last_epoch["callbacks"][ckpt_id]
@ -1052,7 +1052,7 @@ def test_current_score(tmpdir):
ckpts = [
ckpt["callbacks"][
"ModelCheckpoint{'monitor': 'foo', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
" 'train_time_interval': None, 'save_on_train_epoch_end': True}"
" 'train_time_interval': None, 'save_on_train_epoch_end': None}"
]
for ckpt in ckpts
]
@ -1360,3 +1360,13 @@ def test_save_last_every_n_epochs_interaction(tmpdir, every_n_epochs):
trainer.fit(model)
assert mc.last_model_path # a "last" ckpt was saved
assert save_mock.call_count == trainer.max_epochs
def test_train_epoch_end_ckpt_with_no_validation():
trainer = Trainer(val_check_interval=0.5)
trainer.num_val_batches = [0]
assert trainer.checkpoint_callback._should_save_on_train_epoch_end(trainer)
trainer.num_val_batches = [1]
assert not trainer.checkpoint_callback._should_save_on_train_epoch_end(trainer)
trainer.val_check_interval = 0.8
assert not trainer.checkpoint_callback._should_save_on_train_epoch_end(trainer)

View File

@ -157,7 +157,7 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmpdir):
assert "content1" in state2 and state2["content1"] == "two"
assert (
"ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
" 'train_time_interval': None, 'save_on_train_epoch_end': True}" in ckpt["callbacks"]
" 'train_time_interval': None, 'save_on_train_epoch_end': None}" in ckpt["callbacks"]
)

View File

@ -343,6 +343,7 @@ def test_model_checkpoint_options(tmpdir, save_top_k, save_last, expected_files)
save_top_k=save_top_k,
save_last=save_last,
verbose=True,
save_on_train_epoch_end=False,
)
trainer = Trainer()
trainer.state.fn = TrainerFn.FITTING