diff --git a/CHANGELOG.md b/CHANGELOG.md index 7b2d79ba0f..c61a98bcf8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -78,6 +78,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Log epoch metrics before the `on_evaluation_end` hook ([#7272](https://github.com/PyTorchLightning/pytorch-lightning/pull/7272)) +- Explicitly disallow calling `self.log(on_epoch=False)` during epoch-only or single-call hooks ([#7874](https://github.com/PyTorchLightning/pytorch-lightning/pull/7874)) + + - Changed these `Trainer` methods to be protected: `call_setup_hook`, `call_configure_sharded_model`, `pre_dispatch`, `dispatch`, `post_dispatch`, `call_teardown_hook`, `run_train`, `run_sanity_check`, `run_evaluate`, `run_evaluation`, `run_predict`, `track_output_for_epoch_end` diff --git a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py index 3db8aace45..8d079f8b4a 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py @@ -29,26 +29,26 @@ class FxValidator: on_fit_end=None, on_sanity_check_start=None, on_sanity_check_end=None, - on_train_start=dict(on_step=(False, True), on_epoch=(False, True)), + on_train_start=dict(on_step=(False, ), on_epoch=(True, )), on_train_end=None, - on_validation_start=dict(on_step=(False, True), on_epoch=(False, True)), + on_validation_start=dict(on_step=(False, ), on_epoch=(True, )), on_validation_end=None, - on_test_start=dict(on_step=(False, True), on_epoch=(False, True)), + on_test_start=dict(on_step=(False, ), on_epoch=(True, )), on_test_end=None, on_predict_start=None, on_predict_end=None, on_pretrain_routine_start=None, on_pretrain_routine_end=None, - on_train_epoch_start=dict(on_step=(False, True), on_epoch=(False, True)), - on_train_epoch_end=dict(on_step=(False, ), on_epoch=(False, True)), - on_validation_epoch_start=dict(on_step=(False, True), on_epoch=(False, True)), - on_validation_epoch_end=dict(on_step=(False, ), on_epoch=(False, True)), - on_test_epoch_start=dict(on_step=(False, True), on_epoch=(False, True)), - on_test_epoch_end=dict(on_step=(False, ), on_epoch=(False, True)), + on_train_epoch_start=dict(on_step=(False, True), on_epoch=(True, )), + on_train_epoch_end=dict(on_step=(False, ), on_epoch=(True, )), + on_validation_epoch_start=dict(on_step=(False, True), on_epoch=(True, )), + on_validation_epoch_end=dict(on_step=(False, ), on_epoch=(True, )), + on_test_epoch_start=dict(on_step=(False, True), on_epoch=(True, )), + on_test_epoch_end=dict(on_step=(False, ), on_epoch=(True, )), on_predict_epoch_start=None, on_predict_epoch_end=None, - on_epoch_start=dict(on_step=(False, True), on_epoch=(False, True)), - on_epoch_end=dict(on_step=(False, ), on_epoch=(False, True)), + on_epoch_start=dict(on_step=(False, True), on_epoch=(True, )), + on_epoch_end=dict(on_step=(False, ), on_epoch=(True, )), on_batch_start=dict(on_step=(False, True), on_epoch=(False, True)), on_batch_end=dict(on_step=(False, True), on_epoch=(False, True)), on_train_batch_start=dict(on_step=(False, True), on_epoch=(False, True)), @@ -72,19 +72,26 @@ class FxValidator: training_step_end=dict(on_step=(False, True), on_epoch=(False, True)), validation_step_end=dict(on_step=(False, True), on_epoch=(False, True)), test_step_end=dict(on_step=(False, True), on_epoch=(False, True)), - training_epoch_end=dict(on_step=(False, ), on_epoch=(False, True)), - validation_epoch_end=dict(on_step=(False, ), on_epoch=(False, True)), - test_epoch_end=dict(on_step=(False, ), on_epoch=(False, True)), + training_epoch_end=dict(on_step=(False, ), on_epoch=(True, )), + validation_epoch_end=dict(on_step=(False, ), on_epoch=(True, )), + test_epoch_end=dict(on_step=(False, ), on_epoch=(True, )), + on_before_batch_transfer=None, + transfer_batch_to_device=None, + on_after_batch_transfer=None, + backward=None, + optimizer_step=None, # TODO(@carmocca): some {step,epoch}_{start,end} are missing ) - def check_logging(self, fx_name: str, on_step: bool, on_epoch: bool) -> None: - if fx_name not in self.functions: + @classmethod + def check_logging(cls, fx_name: str, on_step: bool, on_epoch: bool) -> None: + """Check if the given function name is allowed to log""" + if fx_name not in cls.functions: raise RuntimeError( f'You are trying to `self.log()` inside `{fx_name}` but it is not implemented.' ' Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`' ) - allowed = self.functions[fx_name] + allowed = cls.functions[fx_name] if allowed is None: raise MisconfigurationException(f"{fx_name} function doesn't support logging using `self.log()`") diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index 36539c6d30..27ff807471 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -349,33 +349,18 @@ def test_log_works_in_val_callback(tmpdir): def on_validation_start(self, trainer, pl_module): self.make_logging( - pl_module, - 'on_validation_start', - 1, - on_steps=self.choices, - on_epochs=self.choices, - prob_bars=self.choices + pl_module, 'on_validation_start', 1, on_steps=[False], on_epochs=[True], prob_bars=self.choices ) def on_epoch_start(self, trainer, pl_module): if trainer.validating: self.make_logging( - pl_module, - 'on_epoch_start', - 2, - on_steps=self.choices, - on_epochs=self.choices, - prob_bars=self.choices + pl_module, 'on_epoch_start', 2, on_steps=[False], on_epochs=[True], prob_bars=self.choices ) def on_validation_epoch_start(self, trainer, pl_module): self.make_logging( - pl_module, - 'on_validation_epoch_start', - 3, - on_steps=self.choices, - on_epochs=self.choices, - prob_bars=self.choices + pl_module, 'on_validation_epoch_start', 3, on_steps=[False], on_epochs=[True], prob_bars=self.choices ) def on_batch_end(self, trainer, pl_module): @@ -400,17 +385,12 @@ def test_log_works_in_val_callback(tmpdir): def on_epoch_end(self, trainer, pl_module): if trainer.validating: self.make_logging( - pl_module, 'on_epoch_end', 8, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices + pl_module, 'on_epoch_end', 8, on_steps=[False], on_epochs=[True], prob_bars=self.choices ) def on_validation_epoch_end(self, trainer, pl_module): self.make_logging( - pl_module, - 'on_validation_epoch_end', - 9, - on_steps=[False], - on_epochs=self.choices, - prob_bars=self.choices + pl_module, 'on_validation_epoch_end', 9, on_steps=[False], on_epochs=[True], prob_bars=self.choices ) class TestModel(BoringModel): @@ -558,18 +538,11 @@ def test_log_works_in_test_callback(tmpdir): } def on_test_start(self, trainer, pl_module): - self.make_logging( - pl_module, 'on_test_start', 1, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices - ) + self.make_logging(pl_module, 'on_test_start', 1, on_steps=[False], on_epochs=[True], prob_bars=self.choices) def on_test_epoch_start(self, trainer, pl_module): self.make_logging( - pl_module, - 'on_test_epoch_start', - 3, - on_steps=self.choices, - on_epochs=self.choices, - prob_bars=self.choices + pl_module, 'on_test_epoch_start', 3, on_steps=[False], on_epochs=[True], prob_bars=self.choices ) def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): @@ -589,7 +562,7 @@ def test_log_works_in_test_callback(tmpdir): def on_test_epoch_end(self, trainer, pl_module): self.make_logging( - pl_module, 'on_test_epoch_end', 7, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices + pl_module, 'on_test_epoch_end', 7, on_steps=[False], on_epochs=[True], prob_bars=self.choices ) max_epochs = 2 diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index bd222058cb..2e6234bb98 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -359,7 +359,8 @@ def test_fx_validator(tmpdir): # This summarizes where and what is currently possible to log using `self.log` is_stage = "train" in func_name or "test" in func_name or "validation" in func_name is_start = "start" in func_name or "batch" in func_name - on_step = is_stage and is_start + is_epoch = "epoch" in func_name + on_step = is_stage and not is_start and not is_epoch on_epoch = True # creating allowed condition allowed = ( diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 7fbbf5805b..e24d64331c 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -399,22 +399,17 @@ def test_log_works_in_train_callback(tmpdir): def on_train_start(self, trainer, pl_module): self.make_logging( - pl_module, 'on_train_start', 1, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices + pl_module, 'on_train_start', 1, on_steps=[False], on_epochs=[True], prob_bars=self.choices ) def on_epoch_start(self, trainer, pl_module): self.make_logging( - pl_module, 'on_epoch_start', 2, on_steps=self.choices, on_epochs=self.choices, prob_bars=self.choices + pl_module, 'on_epoch_start', 2, on_steps=[False], on_epochs=[True], prob_bars=self.choices ) def on_train_epoch_start(self, trainer, pl_module): self.make_logging( - pl_module, - 'on_train_epoch_start', - 3, - on_steps=self.choices, - on_epochs=self.choices, - prob_bars=self.choices + pl_module, 'on_train_epoch_start', 3, on_steps=[False], on_epochs=[True], prob_bars=self.choices ) def on_batch_end(self, trainer, pl_module): @@ -438,13 +433,11 @@ def test_log_works_in_train_callback(tmpdir): def on_train_epoch_end(self, trainer, pl_module): self.make_logging( - pl_module, 'on_train_epoch_end', 8, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices + pl_module, 'on_train_epoch_end', 8, on_steps=[False], on_epochs=[True], prob_bars=self.choices ) def on_epoch_end(self, trainer, pl_module): - self.make_logging( - pl_module, 'on_epoch_end', 9, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices - ) + self.make_logging(pl_module, 'on_epoch_end', 9, on_steps=[False], on_epochs=[True], prob_bars=self.choices) class TestModel(BoringModel):