diff --git a/CHANGELOG.md b/CHANGELOG.md
index 08c0fc05c2..733c77957c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -196,6 +196,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed a consolidation error in Lite when attempting to save the state dict of a sharded optimizer ([#10746](https://github.com/PyTorchLightning/pytorch-lightning/pull/10746))
 
+- Fixed the default logging level for batch hooks associated with training from `on_step=False, on_epoch=True` to `on_step=True, on_epoch=False` ([#10756](https://github.com/PyTorchLightning/pytorch-lightning/pull/10756))
+
+
+
 -
 
 
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 85afdca05d..87fa042475 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -353,10 +353,6 @@ class LightningModule(
             value, object, self.__check_allowed, name, value, wrong_dtype=(numbers.Number, Metric, Tensor, dict)
         )
 
-        # set the default depending on the fx_name
-        on_step = self.__auto_choose_log_on_step(on_step)
-        on_epoch = self.__auto_choose_log_on_epoch(on_epoch)
-
        if self.trainer is None:
             # not an error to support testing the `*_step` methods without a `Trainer` reference
             rank_zero_warn(
@@ -375,7 +371,10 @@ class LightningModule(
             raise MisconfigurationException(
                 "You are trying to `self.log()` but it is not managed by the `Trainer` control flow"
             )
-        _FxValidator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch)
+
+        on_step, on_epoch = _FxValidator.check_logging_and_get_default_levels(
+            self._current_fx_name, on_step=on_step, on_epoch=on_epoch
+        )
 
         # make sure user doesn't introduce logic for multi-dataloaders
         if "/dataloader_idx_" in name:
@@ -530,18 +529,6 @@ class LightningModule(
         """
         self.log_dict(grad_norm_dict, on_step=True, on_epoch=True, prog_bar=True, logger=True)
 
-    def __auto_choose_log_on_step(self, on_step: Optional[bool]) -> bool:
-        if on_step is None:
-            on_step = False
-        on_step |= self._current_fx_name in ("training_step", "training_step_end")
-        return on_step
-
-    def __auto_choose_log_on_epoch(self, on_epoch: Optional[bool]) -> bool:
-        if on_epoch is None:
-            on_epoch = True
-        on_epoch &= self._current_fx_name not in ("training_step", "training_step_end")
-        return on_epoch
-
     def all_gather(
         self, data: Union[torch.Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False
     ):
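For orientation, a minimal sketch (not part of this diff) of how the new defaults surface to users once `LightningModule.log` defers to `_FxValidator`. `DemoModel` below is a hypothetical subclass of the `BoringModel` test helper:

class DemoModel(BoringModel):
    def on_train_batch_end(self, outputs, batch, batch_idx):
        # a bare `self.log` in a training batch-level hook now defaults to
        # on_step=True, on_epoch=False (it used to be on_step=False, on_epoch=True)
        self.log("batch_value", 1.0)

    def training_step(self, batch, batch_idx):
        out = super().training_step(batch, batch_idx)
        # unchanged: `training_step` already defaulted to on_step=True, on_epoch=False
        self.log("train_loss", out["loss"])
        return out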
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py
index cc91476518..ad3dce3c12 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 
 from typing_extensions import TypedDict
 
@@ -20,50 +20,98 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 class _FxValidator:
     class _LogOptions(TypedDict):
-        on_step: Union[Tuple[bool], Tuple[bool, bool]]
-        on_epoch: Union[Tuple[bool], Tuple[bool, bool]]
+        allowed_on_step: Union[Tuple[bool], Tuple[bool, bool]]
+        allowed_on_epoch: Union[Tuple[bool], Tuple[bool, bool]]
+        default_on_step: bool
+        default_on_epoch: bool
 
     functions = {
         "on_before_accelerator_backend_setup": None,
         "on_configure_sharded_model": None,
-        "on_before_backward": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
-        "on_after_backward": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
-        "on_before_optimizer_step": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
-        "on_before_zero_grad": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
+        "on_before_backward": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
+        ),
+        "on_after_backward": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
+        ),
+        "on_before_optimizer_step": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
+        ),
+        "on_before_zero_grad": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
+        ),
         "on_init_start": None,
         "on_init_end": None,
         "on_fit_start": None,
         "on_fit_end": None,
         "on_sanity_check_start": None,
         "on_sanity_check_end": None,
-        "on_train_start": _LogOptions(on_step=(False,), on_epoch=(True,)),
+        "on_train_start": _LogOptions(
+            allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
+        ),
         "on_train_end": None,
-        "on_validation_start": _LogOptions(on_step=(False,), on_epoch=(True,)),
+        "on_validation_start": _LogOptions(
+            allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
+        ),
         "on_validation_end": None,
-        "on_test_start": _LogOptions(on_step=(False,), on_epoch=(True,)),
+        "on_test_start": _LogOptions(
+            allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
+        ),
         "on_test_end": None,
         "on_predict_start": None,
         "on_predict_end": None,
         "on_pretrain_routine_start": None,
         "on_pretrain_routine_end": None,
-        "on_train_epoch_start": _LogOptions(on_step=(False,), on_epoch=(True,)),
-        "on_train_epoch_end": _LogOptions(on_step=(False,), on_epoch=(True,)),
-        "on_validation_epoch_start": _LogOptions(on_step=(False,), on_epoch=(True,)),
-        "on_validation_epoch_end": _LogOptions(on_step=(False,), on_epoch=(True,)),
-        "on_test_epoch_start": _LogOptions(on_step=(False,), on_epoch=(True,)),
-        "on_test_epoch_end": _LogOptions(on_step=(False,), on_epoch=(True,)),
+        "on_train_epoch_start": _LogOptions(
+            allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
+        ),
+        "on_train_epoch_end": _LogOptions(
+            allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
+        ),
+        "on_validation_epoch_start": _LogOptions(
+            allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
+        ),
+        "on_validation_epoch_end": _LogOptions(
+            allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
+        ),
+        "on_test_epoch_start": _LogOptions(
+            allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
+        ),
+        "on_test_epoch_end": _LogOptions(
+            allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
+        ),
         "on_predict_epoch_start": None,
         "on_predict_epoch_end": None,
-        "on_epoch_start": _LogOptions(on_step=(False,), on_epoch=(True,)),
-        "on_epoch_end": _LogOptions(on_step=(False,), on_epoch=(True,)),
-        "on_batch_start": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
-        "on_batch_end": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
-        "on_train_batch_start": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
-        "on_train_batch_end": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
-        "on_validation_batch_start": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
-        "on_validation_batch_end": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
-        "on_test_batch_start": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
-        "on_test_batch_end": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
+        "on_epoch_start": _LogOptions(
+            allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
+        ),
+        "on_epoch_end": _LogOptions(
+            allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
+        ),
+        "on_batch_start": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
+        ),
+        "on_batch_end": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
+        ),
+        "on_train_batch_start": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
+        ),
+        "on_train_batch_end": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
+        ),
+        "on_validation_batch_start": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=False, default_on_epoch=True
+        ),
+        "on_validation_batch_end": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=False, default_on_epoch=True
+        ),
+        "on_test_batch_start": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=False, default_on_epoch=True
+        ),
+        "on_test_batch_end": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=False, default_on_epoch=True
+        ),
         "on_predict_batch_start": None,
         "on_predict_batch_end": None,
         "on_keyboard_interrupt": None,
@@ -73,16 +121,34 @@ class _FxValidator:
         "setup": None,
         "teardown": None,
         "configure_sharded_model": None,
-        "training_step": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
-        "validation_step": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
-        "test_step": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
+        "training_step": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
+        ),
+        "validation_step": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=False, default_on_epoch=True
+        ),
+        "test_step": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=False, default_on_epoch=True
+        ),
         "predict_step": None,
-        "training_step_end": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
-        "validation_step_end": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
-        "test_step_end": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
-        "training_epoch_end": _LogOptions(on_step=(False,), on_epoch=(True,)),
-        "validation_epoch_end": _LogOptions(on_step=(False,), on_epoch=(True,)),
-        "test_epoch_end": _LogOptions(on_step=(False,), on_epoch=(True,)),
+        "training_step_end": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
+        ),
+        "validation_step_end": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=False, default_on_epoch=True
+        ),
+        "test_step_end": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=False, default_on_epoch=True
+        ),
+        "training_epoch_end": _LogOptions(
+            allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
+        ),
+        "validation_epoch_end": _LogOptions(
+            allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
+        ),
+        "test_epoch_end": _LogOptions(
+            allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
+        ),
         "configure_optimizers": None,
         "on_train_dataloader": None,
         "train_dataloader": None,
@@ -97,22 +163,48 @@
     }
 
     @classmethod
-    def check_logging(cls, fx_name: str, on_step: bool, on_epoch: bool) -> None:
-        """Check if the given function name is allowed to log."""
+    def check_logging(cls, fx_name: str) -> None:
+        """Check if the given hook is allowed to log."""
         if fx_name not in cls.functions:
             raise RuntimeError(
                 f"Logging inside `{fx_name}` is not implemented."
-                " Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`"
+                " Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`."
             )
-        allowed = cls.functions[fx_name]
-        if allowed is None:
-            raise MisconfigurationException(f"You can't `self.log()` inside `{fx_name}`")
-        m = "You can't `self.log({}={})` inside `{}`, must be one of {}"
-        if on_step not in allowed["on_step"]:
-            msg = m.format("on_step", on_step, fx_name, allowed["on_step"])
+        if cls.functions[fx_name] is None:
+            raise MisconfigurationException(f"You can't `self.log()` inside `{fx_name}`.")
+
+    @classmethod
+    def get_default_logging_levels(
+        cls, fx_name: str, on_step: Optional[bool], on_epoch: Optional[bool]
+    ) -> Tuple[bool, bool]:
+        """Return default logging levels for given hook."""
+        fx_config = cls.functions[fx_name]
+        assert fx_config is not None
+        on_step = fx_config["default_on_step"] if on_step is None else on_step
+        on_epoch = fx_config["default_on_epoch"] if on_epoch is None else on_epoch
+        return on_step, on_epoch
+
+    @classmethod
+    def check_logging_levels(cls, fx_name: str, on_step: bool, on_epoch: bool) -> None:
+        """Check if the logging levels are allowed in the given hook."""
+        fx_config = cls.functions[fx_name]
+        assert fx_config is not None
+        m = "You can't `self.log({}={})` inside `{}`, must be one of {}."
+        if on_step not in fx_config["allowed_on_step"]:
+            msg = m.format("on_step", on_step, fx_name, fx_config["allowed_on_step"])
             raise MisconfigurationException(msg)
-        if on_epoch not in allowed["on_epoch"]:
-            msg = m.format("on_epoch", on_epoch, fx_name, allowed["on_epoch"])
+        if on_epoch not in fx_config["allowed_on_epoch"]:
+            msg = m.format("on_epoch", on_epoch, fx_name, fx_config["allowed_on_epoch"])
             raise MisconfigurationException(msg)
+
+    @classmethod
+    def check_logging_and_get_default_levels(
+        cls, fx_name: str, on_step: Optional[bool], on_epoch: Optional[bool]
+    ) -> Tuple[bool, bool]:
+        """Check if the given hook name is allowed to log and return logging levels."""
+        cls.check_logging(fx_name)
+        on_step, on_epoch = cls.get_default_logging_levels(fx_name, on_step, on_epoch)
+        cls.check_logging_levels(fx_name, on_step, on_epoch)
+        return on_step, on_epoch
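As a quick sanity check of the new `_FxValidator` API introduced above (illustrative only, mirroring what `LightningModule.log` now calls; the expected values are read off the defaults table in this diff):

from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator

# `None` means "not specified by the user", so the per-hook default is applied
assert _FxValidator.check_logging_and_get_default_levels(
    "on_train_batch_end", on_step=None, on_epoch=None
) == (True, False)
assert _FxValidator.check_logging_and_get_default_levels(
    "validation_step", on_step=None, on_epoch=None
) == (False, True)

# explicit values are validated against the allowed tuples but otherwise passed through
assert _FxValidator.check_logging_and_get_default_levels(
    "training_step", on_step=False, on_epoch=True
) == (False, True)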
+"""Test logging in the training loop.""" +import inspect +from unittest import mock +from unittest.mock import ANY + +import torch + +from pytorch_lightning import Trainer +from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator +from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection +from pytorch_lightning.trainer.states import RunningStage, TrainerFn +from tests.helpers.boring_model import BoringModel + + +def test_default_level_for_hooks_that_support_logging(): + def _make_assertion(model, hooks, result_mock, on_step, on_epoch, extra_kwargs): + for hook in hooks: + model._current_fx_name = hook + model.log(hook, 1) + result_mock.assert_called_with( + hook, hook, torch.tensor(1), on_step=on_step, on_epoch=on_epoch, **extra_kwargs + ) + + trainer = Trainer() + model = BoringModel() + model.trainer = trainer + extra_kwargs = { + k: ANY + for k in inspect.signature(ResultCollection.log).parameters + if k not in ["self", "fx", "name", "value", "on_step", "on_epoch"] + } + all_logging_hooks = {k for k in _FxValidator.functions if _FxValidator.functions[k]} + + with mock.patch( + "pytorch_lightning.trainer.connectors.logger_connector.result.ResultCollection.log", return_value=None + ) as result_mock: + trainer.state.stage = RunningStage.TRAINING + hooks = [ + "on_before_backward", + "on_after_backward", + "on_before_optimizer_step", + "on_before_zero_grad", + "training_step", + "training_step_end", + "on_batch_start", + "on_batch_end", + "on_train_batch_start", + "on_train_batch_end", + ] + all_logging_hooks = all_logging_hooks - set(hooks) + _make_assertion(model, hooks, result_mock, on_step=True, on_epoch=False, extra_kwargs=extra_kwargs) + + hooks = [ + "on_train_start", + "on_train_epoch_start", + "on_train_epoch_end", + "on_epoch_start", + "on_epoch_end", + "training_epoch_end", + ] + all_logging_hooks = all_logging_hooks - set(hooks) + _make_assertion(model, hooks, result_mock, on_step=False, on_epoch=True, extra_kwargs=extra_kwargs) + + trainer.state.stage = RunningStage.VALIDATING + trainer.state.fn = TrainerFn.VALIDATING + hooks = [ + "on_validation_start", + "on_validation_epoch_start", + "on_validation_epoch_end", + "on_validation_batch_start", + "on_validation_batch_end", + "validation_step", + "validation_step_end", + "validation_epoch_end", + ] + all_logging_hooks = all_logging_hooks - set(hooks) + _make_assertion(model, hooks, result_mock, on_step=False, on_epoch=True, extra_kwargs=extra_kwargs) + + trainer.state.stage = RunningStage.TESTING + trainer.state.fn = TrainerFn.TESTING + hooks = [ + "on_test_start", + "on_test_epoch_start", + "on_test_epoch_end", + "on_test_batch_start", + "on_test_batch_end", + "test_step", + "test_step_end", + "test_epoch_end", + ] + all_logging_hooks = all_logging_hooks - set(hooks) + _make_assertion(model, hooks, result_mock, on_step=False, on_epoch=True, extra_kwargs=extra_kwargs) + + # just to ensure we checked all possible logging hooks here + assert len(all_logging_hooks) == 0 diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 6bfbaa9a7b..139714acc9 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -716,10 +716,10 @@ def test_on_epoch_logging_with_sum_and_on_batch_start(tmpdir): assert all(v == 3 for v in self.trainer.callback_metrics.values()) def on_train_batch_start(self, batch, batch_idx): - self.log("on_train_batch_start", 1.0, 
reduce_fx="sum") + self.log("on_train_batch_start", 1.0, on_step=False, on_epoch=True, reduce_fx="sum") def on_train_batch_end(self, outputs, batch, batch_idx): - self.log("on_train_batch_end", 1.0, reduce_fx="sum") + self.log("on_train_batch_end", 1.0, on_step=False, on_epoch=True, reduce_fx="sum") def on_validation_batch_start(self, batch, batch_idx, dataloader_idx): self.log("on_validation_batch_start", 1.0, reduce_fx="sum")