From f4ca5623d2960da252073cacac9474b34b76be59 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Wed, 9 Nov 2022 19:57:53 +0530 Subject: [PATCH] Make checkpointing on train epoch end condition dynamic (#15300) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí Co-authored-by: Adrian Wälchli --- src/pytorch_lightning/CHANGELOG.md | 2 +- .../callbacks/model_checkpoint.py | 28 +++++++++++++------ .../checkpointing/test_model_checkpoint.py | 18 +++++++++--- .../connectors/test_callback_connector.py | 2 +- tests/tests_pytorch/trainer/test_trainer.py | 1 + 5 files changed, 36 insertions(+), 15 deletions(-) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 0cc85a26d9..2a9d654c99 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -31,7 +31,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - `Trainer.{validate,test,predict}(ckpt_path=...)` no longer restores the `Trainer.global_step` and `trainer.current_epoch` value from the checkpoints - From now on, only `Trainer.fit` will restore this value ([#15532](https://github.com/Lightning-AI/lightning/pull/15532)) -- +- The `ModelCheckpoint.save_on_train_epoch_end` attribute is now computed dynamically every epoch, accounting for changes to the validation dataloaders ([#15300](https://github.com/Lightning-AI/lightning/pull/15300)) ### Deprecated diff --git a/src/pytorch_lightning/callbacks/model_checkpoint.py b/src/pytorch_lightning/callbacks/model_checkpoint.py index 8d0596e3bd..1403b754be 100644 --- a/src/pytorch_lightning/callbacks/model_checkpoint.py +++ b/src/pytorch_lightning/callbacks/model_checkpoint.py @@ -262,13 +262,6 @@ class ModelCheckpoint(Checkpoint): if trainer.is_global_zero and stage == "fit": self.__warn_if_dir_not_empty(self.dirpath) - # NOTE: setting these attributes needs to happen as early as possible BEFORE reloading callback states, - # because the attributes are part of the state_key which needs to be fully defined before reloading. - if self._save_on_train_epoch_end is None: - # if the user runs validation multiple times per training epoch or multiple training epochs without - # validation, then we run after validation instead of on train epoch end - self._save_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1 - def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self._last_time_checked = time.monotonic() @@ -306,7 +299,7 @@ class ModelCheckpoint(Checkpoint): def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """Save a checkpoint at the end of the training epoch.""" - if not self._should_skip_saving_checkpoint(trainer) and self._save_on_train_epoch_end: + if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer): monitor_candidates = self._monitor_candidates(trainer) if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0: self._save_topk_checkpoint(trainer, monitor_candidates) @@ -314,7 +307,7 @@ class ModelCheckpoint(Checkpoint): def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """Save a checkpoint at the end of the validation stage.""" - if not self._should_skip_saving_checkpoint(trainer) and not self._save_on_train_epoch_end: + if not self._should_skip_saving_checkpoint(trainer) and not self._should_save_on_train_epoch_end(trainer): monitor_candidates = self._monitor_candidates(trainer) if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0: self._save_topk_checkpoint(trainer, monitor_candidates) @@ -390,6 +383,23 @@ class ModelCheckpoint(Checkpoint): or self._last_global_step_saved == trainer.global_step # already saved at the last step ) + def _should_save_on_train_epoch_end(self, trainer: "pl.Trainer") -> bool: + if self._save_on_train_epoch_end is not None: + return self._save_on_train_epoch_end + + # if `check_val_every_n_epoch != 1`, we can't say when the validation dataloader will be loaded + # so let's not enforce saving at every training epoch end + if trainer.check_val_every_n_epoch != 1: + return False + + # no validation means save on train epoch end + if sum(trainer.num_val_batches) == 0: + return True + + # if the user runs validation multiple times per training epoch, then we run after validation + # instead of on train epoch end + return trainer.val_check_interval == 1.0 + def __validate_init_configuration(self) -> None: if self.save_top_k < -1: raise MisconfigurationException(f"Invalid value for save_top_k={self.save_top_k}. Must be >= -1") diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index 54ad7c80ee..786f3181a6 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -168,7 +168,7 @@ def test_model_checkpoint_score_and_ckpt( mc_specific_data = chk["callbacks"][ f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," - " 'train_time_interval': None, 'save_on_train_epoch_end': True}" + " 'train_time_interval': None, 'save_on_train_epoch_end': None}" ] assert mc_specific_data["dirpath"] == checkpoint.dirpath assert mc_specific_data["monitor"] == monitor @@ -269,7 +269,7 @@ def test_model_checkpoint_score_and_ckpt_val_check_interval( mc_specific_data = chk["callbacks"][ f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," - " 'train_time_interval': None, 'save_on_train_epoch_end': False}" + " 'train_time_interval': None, 'save_on_train_epoch_end': None}" ] assert mc_specific_data["dirpath"] == checkpoint.dirpath assert mc_specific_data["monitor"] == monitor @@ -805,7 +805,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): ckpt_id = ( "ModelCheckpoint{'monitor': 'early_stop_on', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," - " 'train_time_interval': None, 'save_on_train_epoch_end': True}" + " 'train_time_interval': None, 'save_on_train_epoch_end': None}" ) assert ckpt_last["callbacks"][ckpt_id] == ckpt_last_epoch["callbacks"][ckpt_id] @@ -1052,7 +1052,7 @@ def test_current_score(tmpdir): ckpts = [ ckpt["callbacks"][ "ModelCheckpoint{'monitor': 'foo', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," - " 'train_time_interval': None, 'save_on_train_epoch_end': True}" + " 'train_time_interval': None, 'save_on_train_epoch_end': None}" ] for ckpt in ckpts ] @@ -1360,3 +1360,13 @@ def test_save_last_every_n_epochs_interaction(tmpdir, every_n_epochs): trainer.fit(model) assert mc.last_model_path # a "last" ckpt was saved assert save_mock.call_count == trainer.max_epochs + + +def test_train_epoch_end_ckpt_with_no_validation(): + trainer = Trainer(val_check_interval=0.5) + trainer.num_val_batches = [0] + assert trainer.checkpoint_callback._should_save_on_train_epoch_end(trainer) + trainer.num_val_batches = [1] + assert not trainer.checkpoint_callback._should_save_on_train_epoch_end(trainer) + trainer.val_check_interval = 0.8 + assert not trainer.checkpoint_callback._should_save_on_train_epoch_end(trainer) diff --git a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py index b7ecab6998..8cd1777901 100644 --- a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py @@ -157,7 +157,7 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmpdir): assert "content1" in state2 and state2["content1"] == "two" assert ( "ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," - " 'train_time_interval': None, 'save_on_train_epoch_end': True}" in ckpt["callbacks"] + " 'train_time_interval': None, 'save_on_train_epoch_end': None}" in ckpt["callbacks"] ) diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py index 21d2525f21..64b52a5a1c 100644 --- a/tests/tests_pytorch/trainer/test_trainer.py +++ b/tests/tests_pytorch/trainer/test_trainer.py @@ -343,6 +343,7 @@ def test_model_checkpoint_options(tmpdir, save_top_k, save_last, expected_files) save_top_k=save_top_k, save_last=save_last, verbose=True, + save_on_train_epoch_end=False, ) trainer = Trainer() trainer.state.fn = TrainerFn.FITTING