Fix default logging levels for train step specific hooks (#10756)

Rohit Gupta 2021-11-30 01:21:17 +05:30 committed by GitHub
parent 088818fbc6
commit 753cc4dfad
6 changed files with 260 additions and 69 deletions

View File

@@ -196,6 +196,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a consolidation error in Lite when attempting to save the state dict of a sharded optimizer ([#10746](https://github.com/PyTorchLightning/pytorch-lightning/pull/10746))
- Fixed the default logging level for batch hooks associated with training from `on_step=False, on_epoch=True` to `on_step=True, on_epoch=False` ([#10756](https://github.com/PyTorchLightning/pytorch-lightning/pull/10756))
-

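To make the changelog entry above concrete, here is a minimal sketch (illustrative only, not part of this commit) of which `self.log` calls are affected when `on_step`/`on_epoch` are left unset. The class and metric names are placeholders; the hook names and signatures are taken from the diff below.

from pytorch_lightning import LightningModule

class LitModel(LightningModule):
    def on_train_batch_end(self, outputs, batch, batch_idx):
        # Train-specific batch hook: previously a bare call defaulted to
        # on_step=False, on_epoch=True; after this commit it defaults to
        # on_step=True, on_epoch=False.
        self.log("train_batch_metric", 1.0)

    def validation_step(self, batch, batch_idx):
        # Validation (and test) hooks are unchanged and keep
        # on_step=False, on_epoch=True as the default.
        self.log("val_metric", 1.0)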
View File

@@ -353,10 +353,6 @@ class LightningModule(
value, object, self.__check_allowed, name, value, wrong_dtype=(numbers.Number, Metric, Tensor, dict)
)
# set the default depending on the fx_name
on_step = self.__auto_choose_log_on_step(on_step)
on_epoch = self.__auto_choose_log_on_epoch(on_epoch)
if self.trainer is None:
# not an error to support testing the `*_step` methods without a `Trainer` reference
rank_zero_warn(
@@ -375,7 +371,10 @@ class LightningModule(
raise MisconfigurationException(
"You are trying to `self.log()` but it is not managed by the `Trainer` control flow"
)
_FxValidator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch)
on_step, on_epoch = _FxValidator.check_logging_and_get_default_levels(
self._current_fx_name, on_step=on_step, on_epoch=on_epoch
)
# make sure user doesn't introduce logic for multi-dataloaders
if "/dataloader_idx_" in name:
@@ -530,18 +529,6 @@ class LightningModule(
"""
self.log_dict(grad_norm_dict, on_step=True, on_epoch=True, prog_bar=True, logger=True)
def __auto_choose_log_on_step(self, on_step: Optional[bool]) -> bool:
if on_step is None:
on_step = False
on_step |= self._current_fx_name in ("training_step", "training_step_end")
return on_step
def __auto_choose_log_on_epoch(self, on_epoch: Optional[bool]) -> bool:
if on_epoch is None:
on_epoch = True
on_epoch &= self._current_fx_name not in ("training_step", "training_step_end")
return on_epoch
def all_gather(
self, data: Union[torch.Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False
):

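The two private helpers removed above encoded a single global rule for the defaults, with the side effect that an explicit `on_step=False` or `on_epoch=True` inside `training_step`/`training_step_end` was silently flipped by the bitwise or/and. A compact restatement of that old behaviour (illustrative only, not part of this commit; `fx_name` stands in for `self._current_fx_name`):

def old_default_levels(fx_name, on_step=None, on_epoch=None):
    # __auto_choose_log_on_step: step-level logging forced on inside the train step
    if on_step is None:
        on_step = False
    on_step |= fx_name in ("training_step", "training_step_end")
    # __auto_choose_log_on_epoch: epoch-level logging forced off inside the train step
    if on_epoch is None:
        on_epoch = True
    on_epoch &= fx_name not in ("training_step", "training_step_end")
    return on_step, on_epoch

After this commit the defaults instead come from the per-hook `default_on_step`/`default_on_epoch` entries in `_FxValidator.functions` (next file), so train-specific batch hooks such as `on_train_batch_end` can default to step-level logging without special-casing hook names here.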
View File

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple, Union
from typing import Optional, Tuple, Union
from typing_extensions import TypedDict
@@ -20,50 +20,98 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
class _FxValidator:
class _LogOptions(TypedDict):
on_step: Union[Tuple[bool], Tuple[bool, bool]]
on_epoch: Union[Tuple[bool], Tuple[bool, bool]]
allowed_on_step: Union[Tuple[bool], Tuple[bool, bool]]
allowed_on_epoch: Union[Tuple[bool], Tuple[bool, bool]]
default_on_step: bool
default_on_epoch: bool
functions = {
"on_before_accelerator_backend_setup": None,
"on_configure_sharded_model": None,
"on_before_backward": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"on_after_backward": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"on_before_optimizer_step": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"on_before_zero_grad": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"on_before_backward": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
"on_after_backward": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
"on_before_optimizer_step": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
"on_before_zero_grad": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
"on_init_start": None,
"on_init_end": None,
"on_fit_start": None,
"on_fit_end": None,
"on_sanity_check_start": None,
"on_sanity_check_end": None,
"on_train_start": _LogOptions(on_step=(False,), on_epoch=(True,)),
"on_train_start": _LogOptions(
allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
),
"on_train_end": None,
"on_validation_start": _LogOptions(on_step=(False,), on_epoch=(True,)),
"on_validation_start": _LogOptions(
allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
),
"on_validation_end": None,
"on_test_start": _LogOptions(on_step=(False,), on_epoch=(True,)),
"on_test_start": _LogOptions(
allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
),
"on_test_end": None,
"on_predict_start": None,
"on_predict_end": None,
"on_pretrain_routine_start": None,
"on_pretrain_routine_end": None,
"on_train_epoch_start": _LogOptions(on_step=(False,), on_epoch=(True,)),
"on_train_epoch_end": _LogOptions(on_step=(False,), on_epoch=(True,)),
"on_validation_epoch_start": _LogOptions(on_step=(False,), on_epoch=(True,)),
"on_validation_epoch_end": _LogOptions(on_step=(False,), on_epoch=(True,)),
"on_test_epoch_start": _LogOptions(on_step=(False,), on_epoch=(True,)),
"on_test_epoch_end": _LogOptions(on_step=(False,), on_epoch=(True,)),
"on_train_epoch_start": _LogOptions(
allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
),
"on_train_epoch_end": _LogOptions(
allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
),
"on_validation_epoch_start": _LogOptions(
allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
),
"on_validation_epoch_end": _LogOptions(
allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
),
"on_test_epoch_start": _LogOptions(
allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
),
"on_test_epoch_end": _LogOptions(
allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
),
"on_predict_epoch_start": None,
"on_predict_epoch_end": None,
"on_epoch_start": _LogOptions(on_step=(False,), on_epoch=(True,)),
"on_epoch_end": _LogOptions(on_step=(False,), on_epoch=(True,)),
"on_batch_start": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"on_batch_end": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"on_train_batch_start": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"on_train_batch_end": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"on_validation_batch_start": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"on_validation_batch_end": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"on_test_batch_start": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"on_test_batch_end": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"on_epoch_start": _LogOptions(
allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
),
"on_epoch_end": _LogOptions(
allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
),
"on_batch_start": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
"on_batch_end": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
"on_train_batch_start": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
"on_train_batch_end": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
"on_validation_batch_start": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=False, default_on_epoch=True
),
"on_validation_batch_end": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=False, default_on_epoch=True
),
"on_test_batch_start": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=False, default_on_epoch=True
),
"on_test_batch_end": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=False, default_on_epoch=True
),
"on_predict_batch_start": None,
"on_predict_batch_end": None,
"on_keyboard_interrupt": None,
@@ -73,16 +121,34 @@ class _FxValidator:
"setup": None,
"teardown": None,
"configure_sharded_model": None,
"training_step": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"validation_step": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"test_step": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"training_step": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
"validation_step": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=False, default_on_epoch=True
),
"test_step": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=False, default_on_epoch=True
),
"predict_step": None,
"training_step_end": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"validation_step_end": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"test_step_end": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
"training_epoch_end": _LogOptions(on_step=(False,), on_epoch=(True,)),
"validation_epoch_end": _LogOptions(on_step=(False,), on_epoch=(True,)),
"test_epoch_end": _LogOptions(on_step=(False,), on_epoch=(True,)),
"training_step_end": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
"validation_step_end": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=False, default_on_epoch=True
),
"test_step_end": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=False, default_on_epoch=True
),
"training_epoch_end": _LogOptions(
allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
),
"validation_epoch_end": _LogOptions(
allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
),
"test_epoch_end": _LogOptions(
allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
),
"configure_optimizers": None,
"on_train_dataloader": None,
"train_dataloader": None,
@@ -97,22 +163,48 @@ class _FxValidator:
}
@classmethod
def check_logging(cls, fx_name: str, on_step: bool, on_epoch: bool) -> None:
"""Check if the given function name is allowed to log."""
def check_logging(cls, fx_name: str) -> None:
"""Check if the given hook is allowed to log."""
if fx_name not in cls.functions:
raise RuntimeError(
f"Logging inside `{fx_name}` is not implemented."
" Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`"
" Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`."
)
allowed = cls.functions[fx_name]
if allowed is None:
raise MisconfigurationException(f"You can't `self.log()` inside `{fx_name}`")
m = "You can't `self.log({}={})` inside `{}`, must be one of {}"
if on_step not in allowed["on_step"]:
msg = m.format("on_step", on_step, fx_name, allowed["on_step"])
if cls.functions[fx_name] is None:
raise MisconfigurationException(f"You can't `self.log()` inside `{fx_name}`.")
@classmethod
def get_default_logging_levels(
cls, fx_name: str, on_step: Optional[bool], on_epoch: Optional[bool]
) -> Tuple[bool, bool]:
"""Return default logging levels for given hook."""
fx_config = cls.functions[fx_name]
assert fx_config is not None
on_step = fx_config["default_on_step"] if on_step is None else on_step
on_epoch = fx_config["default_on_epoch"] if on_epoch is None else on_epoch
return on_step, on_epoch
@classmethod
def check_logging_levels(cls, fx_name: str, on_step: bool, on_epoch: bool) -> None:
"""Check if the logging levels are allowed in the given hook."""
fx_config = cls.functions[fx_name]
assert fx_config is not None
m = "You can't `self.log({}={})` inside `{}`, must be one of {}."
if on_step not in fx_config["allowed_on_step"]:
msg = m.format("on_step", on_step, fx_name, fx_config["allowed_on_step"])
raise MisconfigurationException(msg)
if on_epoch not in allowed["on_epoch"]:
msg = m.format("on_epoch", on_epoch, fx_name, allowed["on_epoch"])
if on_epoch not in fx_config["allowed_on_epoch"]:
msg = m.format("on_epoch", on_epoch, fx_name, fx_config["allowed_on_epoch"])
raise MisconfigurationException(msg)
@classmethod
def check_logging_and_get_default_levels(
cls, fx_name: str, on_step: Optional[bool], on_epoch: Optional[bool]
) -> Tuple[bool, bool]:
"""Check if the given hook name is allowed to log and return logging levels."""
cls.check_logging(fx_name)
on_step, on_epoch = cls.get_default_logging_levels(fx_name, on_step, on_epoch)
cls.check_logging_levels(fx_name, on_step, on_epoch)
return on_step, on_epoch

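A short usage sketch of the new validator API defined above (illustrative only, not part of the commit); the expected values follow directly from the `functions` table:

from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator

# Passing None lets the per-hook defaults fill in the levels.
assert _FxValidator.check_logging_and_get_default_levels(
    "training_step", on_step=None, on_epoch=None
) == (True, False)
assert _FxValidator.check_logging_and_get_default_levels(
    "validation_step", on_step=None, on_epoch=None
) == (False, True)

# Explicit values are still validated against the allowed_* tuples;
# e.g. on_step=True inside `on_train_start` raises a MisconfigurationException.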
View File

@@ -141,17 +141,17 @@ def test_fx_validator(tmpdir):
and func_name not in ["on_train_end", "on_test_end", "on_validation_end"]
)
if allowed:
validator.check_logging(fx_name=func_name, on_step=on_step, on_epoch=on_epoch)
validator.check_logging_levels(fx_name=func_name, on_step=on_step, on_epoch=on_epoch)
if not is_start and is_stage:
with pytest.raises(MisconfigurationException, match="must be one of"):
validator.check_logging(fx_name=func_name, on_step=True, on_epoch=on_epoch)
validator.check_logging_levels(fx_name=func_name, on_step=True, on_epoch=on_epoch)
else:
assert func_name in not_supported
with pytest.raises(MisconfigurationException, match="You can't"):
validator.check_logging(fx_name=func_name, on_step=on_step, on_epoch=on_epoch)
validator.check_logging(fx_name=func_name)
with pytest.raises(RuntimeError, match="Logging inside `foo` is not implemented"):
validator.check_logging("foo", False, False)
validator.check_logging("foo")
class HookedCallback(Callback):

View File

@@ -0,0 +1,108 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test logging in the training loop."""
import inspect
from unittest import mock
from unittest.mock import ANY
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from tests.helpers.boring_model import BoringModel
def test_default_level_for_hooks_that_support_logging():
def _make_assertion(model, hooks, result_mock, on_step, on_epoch, extra_kwargs):
for hook in hooks:
model._current_fx_name = hook
model.log(hook, 1)
result_mock.assert_called_with(
hook, hook, torch.tensor(1), on_step=on_step, on_epoch=on_epoch, **extra_kwargs
)
trainer = Trainer()
model = BoringModel()
model.trainer = trainer
extra_kwargs = {
k: ANY
for k in inspect.signature(ResultCollection.log).parameters
if k not in ["self", "fx", "name", "value", "on_step", "on_epoch"]
}
all_logging_hooks = {k for k in _FxValidator.functions if _FxValidator.functions[k]}
with mock.patch(
"pytorch_lightning.trainer.connectors.logger_connector.result.ResultCollection.log", return_value=None
) as result_mock:
trainer.state.stage = RunningStage.TRAINING
hooks = [
"on_before_backward",
"on_after_backward",
"on_before_optimizer_step",
"on_before_zero_grad",
"training_step",
"training_step_end",
"on_batch_start",
"on_batch_end",
"on_train_batch_start",
"on_train_batch_end",
]
all_logging_hooks = all_logging_hooks - set(hooks)
_make_assertion(model, hooks, result_mock, on_step=True, on_epoch=False, extra_kwargs=extra_kwargs)
hooks = [
"on_train_start",
"on_train_epoch_start",
"on_train_epoch_end",
"on_epoch_start",
"on_epoch_end",
"training_epoch_end",
]
all_logging_hooks = all_logging_hooks - set(hooks)
_make_assertion(model, hooks, result_mock, on_step=False, on_epoch=True, extra_kwargs=extra_kwargs)
trainer.state.stage = RunningStage.VALIDATING
trainer.state.fn = TrainerFn.VALIDATING
hooks = [
"on_validation_start",
"on_validation_epoch_start",
"on_validation_epoch_end",
"on_validation_batch_start",
"on_validation_batch_end",
"validation_step",
"validation_step_end",
"validation_epoch_end",
]
all_logging_hooks = all_logging_hooks - set(hooks)
_make_assertion(model, hooks, result_mock, on_step=False, on_epoch=True, extra_kwargs=extra_kwargs)
trainer.state.stage = RunningStage.TESTING
trainer.state.fn = TrainerFn.TESTING
hooks = [
"on_test_start",
"on_test_epoch_start",
"on_test_epoch_end",
"on_test_batch_start",
"on_test_batch_end",
"test_step",
"test_step_end",
"test_epoch_end",
]
all_logging_hooks = all_logging_hooks - set(hooks)
_make_assertion(model, hooks, result_mock, on_step=False, on_epoch=True, extra_kwargs=extra_kwargs)
# just to ensure we checked all possible logging hooks here
assert len(all_logging_hooks) == 0

View File

@@ -716,10 +716,10 @@ def test_on_epoch_logging_with_sum_and_on_batch_start(tmpdir):
assert all(v == 3 for v in self.trainer.callback_metrics.values())
def on_train_batch_start(self, batch, batch_idx):
self.log("on_train_batch_start", 1.0, reduce_fx="sum")
self.log("on_train_batch_start", 1.0, on_step=False, on_epoch=True, reduce_fx="sum")
def on_train_batch_end(self, outputs, batch, batch_idx):
self.log("on_train_batch_end", 1.0, reduce_fx="sum")
self.log("on_train_batch_end", 1.0, on_step=False, on_epoch=True, reduce_fx="sum")
def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
self.log("on_validation_batch_start", 1.0, reduce_fx="sum")