From 536c1323b0e6715fb5919196ea48b0fcddddcd66 Mon Sep 17 00:00:00 2001
From: Shuying Sun
Date: Wed, 24 Mar 2021 01:17:20 -0700
Subject: [PATCH] checkpoint consolidation

---
 pytorch_lightning/callbacks/base.py           |  4 +++
 pytorch_lightning/callbacks/early_stopping.py | 15 ++++++++
 .../callbacks/lambda_function.py              |  3 ++
 .../callbacks/model_checkpoint.py             | 30 ++++++++++++++++
 pytorch_lightning/trainer/callback_hook.py    |  7 ++++
 .../callback_hook_validator.py                |  5 +++
 pytorch_lightning/trainer/training_loop.py    | 35 ++-----------------
 tests/checkpointing/test_model_checkpoint.py  | 35 +++++++++++++++----
 tests/helpers/utils.py                        |  2 +-
 .../trainer/logging_/test_logger_connector.py |  1 +
 10 files changed, 98 insertions(+), 39 deletions(-)

diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index db507fa991..ffb26f38ca 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -109,6 +109,10 @@ class Callback(abc.ABC):
         """Called when the epoch ends."""
         pass
 
+    def on_train_epoch_final_end(self, trainer, pl_module: LightningModule) -> None:
+        """Called at the very end of the training epoch."""
+        pass
+
     def on_batch_start(self, trainer, pl_module: LightningModule) -> None:
         """Called when the training batch begins."""
         pass
diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 4448de8e48..0de8ff6f0b 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -143,6 +143,21 @@ class EarlyStopping(Callback):
 
         self._run_early_stopping_check(trainer)
 
+    def on_train_epoch_final_end(self, trainer, pl_module):
+        from pytorch_lightning.trainer.states import TrainerState
+        if (
+            trainer.state != TrainerState.FITTING or trainer.sanity_checking
+            or not trainer.checkpoint_connector.has_trained
+        ):
+            return
+        # if validation is disabled or should skip, we run early stopping
+        # at end of the training epoch
+        if (
+            trainer.disable_validation
+            or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches)
+        ):
+            self._run_early_stopping_check(trainer)
+
     def _run_early_stopping_check(self, trainer):
         """
         Checks whether the early stopping condition is met
diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py
index 58324e363c..2a56e1c8ac 100644
--- a/pytorch_lightning/callbacks/lambda_function.py
+++ b/pytorch_lightning/callbacks/lambda_function.py
@@ -53,6 +53,7 @@ class LambdaCallback(Callback):
         on_train_batch_end: Optional[Callable] = None,
         on_train_epoch_start: Optional[Callable] = None,
         on_train_epoch_end: Optional[Callable] = None,
+        on_train_epoch_final_end: Optional[Callable] = None,
         on_validation_epoch_start: Optional[Callable] = None,
         on_validation_epoch_end: Optional[Callable] = None,
         on_test_epoch_start: Optional[Callable] = None,
@@ -155,3 +156,5 @@ class LambdaCallback(Callback):
             self.on_after_backward = on_after_backward
         if on_before_zero_grad is not None:
             self.on_before_zero_grad = on_before_zero_grad
+        if on_train_epoch_final_end is not None:
+            self.on_train_epoch_final_end = on_train_epoch_final_end
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 2a0c108ba7..9436720e38 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -238,6 +238,36 @@ class ModelCheckpoint(Callback):
             return
         self.save_checkpoint(trainer)
 
+    def on_train_epoch_final_end(self, trainer, pl_module):
+        """
+        At the end of each training epoch, checkpoint only when validation is skipped or disabled.
+        """
+        if (
+            self._should_skip_saving_checkpoint(trainer)
+            or not trainer.checkpoint_connector.has_trained
+        ):
+            return
+        # if validation is disabled or should skip, we checkpoint at end of the training epoch
+        if (
+            trainer.disable_validation
+            or trainer.evaluation_loop.should_skip_evaluation(trainer.num_val_batches)
+        ):
+            self.save_checkpoint(trainer)
+
+    def on_train_end(self, trainer, *args, **kwargs) -> None:
+        """
+        Checkpoints can be saved at the end of training.
+        """
+        trainer.global_step -= 1
+        if (
+            not self._should_skip_saving_checkpoint(trainer)
+            and trainer.checkpoint_connector.has_trained
+        ):
+            if self.save_last and self.verbose:
+                rank_zero_info("Saving latest checkpoint...")
+            self.save_checkpoint(trainer)
+        trainer.global_step += 1
+
     def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
         return {
             "monitor": self.monitor,
diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 8823d48a78..c53c21ad04 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -92,6 +92,13 @@ class TrainerCallbackHookMixin(ABC):
         for callback in self.callbacks:
             callback.on_train_epoch_end(self, self.lightning_module, outputs)
 
+    def on_train_epoch_final_end(self) -> None:
+        """
+        Called at the very end of the training epoch.
+        """
+        for callback in self.callbacks:
+            callback.on_train_epoch_final_end(self, self.lightning_module)
+
     def on_validation_epoch_start(self):
         """Called when the epoch begins."""
         for callback in self.callbacks:
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py
index 534dad5199..e7884124df 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py
@@ -100,6 +100,11 @@ class CallbackHookNameValidator:
         """Called when the epoch ends."""
         return {"on_step": [False], "on_epoch": [False, True]}
 
+    @staticmethod
+    def _on_train_epoch_final_end_log():
+        """Called at the very end of the training epoch."""
+        return {"on_step": [False], "on_epoch": [False, True]}
+
     @staticmethod
     def _on_validation_epoch_start_log():
         """Called when the epoch begins."""
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index c3ba34ca66..1d498a0a9f 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -121,12 +121,6 @@ class TrainLoop:
             return
         self._teardown_already_run = True
 
-        # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
-        # when a checkpoint was saved at the last step
-        self.trainer.global_step -= 1
-        self.check_checkpoint_callback(should_update=True, is_last=True)
-        self.trainer.global_step += 1
-
         # hook
         self.trainer.call_hook("on_train_end")
 
@@ -145,28 +139,6 @@ class TrainLoop:
         # reset bookkeeping
         self.trainer._running_stage = None
 
-    def check_checkpoint_callback(self, should_update, is_last=False):
-        # TODO bake this logic into the ModelCheckpoint callback
-        if should_update and self.trainer.checkpoint_connector.has_trained:
-            callbacks = self.trainer.checkpoint_callbacks
-
-            if is_last and any(cb.save_last and cb.verbose for cb in callbacks):
-                rank_zero_info("Saving latest checkpoint...")
-
-            model = self.trainer.lightning_module
-
-            for cb in callbacks:
-                cb.on_validation_end(self.trainer, model)
-
-    def check_early_stopping_callback(self, should_update):
-        # TODO bake this logic into the EarlyStopping callback
-        if should_update and self.trainer.checkpoint_connector.has_trained:
-            callbacks = [c for c in self.trainer.callbacks if isinstance(c, EarlyStopping)]
-            model = self.trainer.lightning_module
-
-            for cb in callbacks:
-                cb.on_validation_end(self.trainer, model)
-
     def on_train_epoch_start(self, epoch):
 
         # update training progress in trainer
@@ -562,15 +534,14 @@ class TrainLoop:
         if (val_loop_called and not should_check_val) or should_train_only:
             self.trainer.optimizer_connector.update_learning_rates(interval='epoch')
 
-        if should_train_only:
-            self.check_checkpoint_callback(True)
-            self.check_early_stopping_callback(True)
-
         if should_check_val:
             self.trainer.validating = True
             self.trainer.run_evaluation(on_epoch=True)
             self.trainer.training = True
 
+        if should_train_only:
+            self.trainer.call_hook('on_train_epoch_final_end')
+
         # increment the global step once
         # progress global step according to grads progress
         self.increment_accumulated_grad_global_step()
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 75f25b90fa..e0c295a843 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -609,7 +609,13 @@ def test_model_checkpoint_period(tmpdir, period: int):
     trainer.fit(model)
 
     # check that the correct ckpts were created
-    expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % period] if period > 0 else []
+    final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1)
+    expected = (
+        [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % period and e + 1 != epochs]
+        if period > 0
+        else []
+    )
+    expected.append(final_epoch_ckpt)
     assert set(os.listdir(tmpdir)) == set(expected)
 
 
@@ -631,8 +637,14 @@ def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs):
     trainer.fit(model)
 
     # check that the correct ckpts were created
-    expected = [f'epoch={e}.ckpt' for e in range(epochs)
-                if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
+    # check that the correct ckpts were created
+    final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1)
+    expected = (
+        [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs]
+        if every_n_val_epochs > 0
+        else []
+    )
+    expected.append(final_epoch_ckpt)
     assert set(os.listdir(tmpdir)) == set(expected)
 
 
@@ -659,8 +671,14 @@ def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epoc
     trainer.fit(model)
 
     # check that the correct ckpts were created
-    expected = [f'epoch={e}.ckpt' for e in range(epochs)
-                if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
+    # check that the correct ckpts were created
+    final_epoch_ckpt = "epoch={e}.ckpt".format(e=epochs-1)
+    expected = (
+        [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_val_epochs and e + 1 != epochs]
+        if every_n_val_epochs > 0
+        else []
+    )
+    expected.append(final_epoch_ckpt)
     assert set(os.listdir(tmpdir)) == set(expected)
 
 
@@ -816,10 +834,15 @@ def test_model_checkpoint_save_last_warning(
         default_root_dir=tmpdir,
         callbacks=[ckpt],
        max_epochs=max_epochs,
+        val_check_interval=0.1,
     )
     with caplog.at_level(logging.INFO):
         trainer.fit(model)
-    assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last)
+    if verbose and save_last and not should_validate:
+        # no validation, hence checkpoint triggered at the end of each training epoch
+        assert caplog.messages.count('Saving latest checkpoint...') == 0
+    else:
+        assert caplog.messages.count('Saving latest checkpoint...') == (verbose and save_last)
 
 
 def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py
index f5c1726a42..493d32d3fe 100644
--- a/tests/helpers/utils.py
+++ b/tests/helpers/utils.py
@@ -76,7 +76,7 @@ def reset_seed(seed=0):
 def set_random_master_port():
     reset_seed()
     port = RANDOM_PORTS.pop()
-    os.environ['MASTER_PORT'] = str(port)
+    os.environ['MASTER_PORT'] = "29501"
 
 
 def init_checkpoint_callback(logger):
diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py
index 3db0a8eaa0..b2727177bc 100644
--- a/tests/trainer/logging_/test_logger_connector.py
+++ b/tests/trainer/logging_/test_logger_connector.py
@@ -300,6 +300,7 @@ def test_call_back_validator(tmpdir):
         'on_train_batch_start',
         'on_train_end',
         'on_train_epoch_end',
+        'on_train_epoch_final_end',
         'on_train_epoch_start',
         'on_train_start',
         'on_validation_batch_end',