Fix default logging levels for train step specific hooks (#10756)
parent 088818fbc6
commit 753cc4dfad
@@ -196,6 +196,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed a consolidation error in Lite when attempting to save the state dict of a sharded optimizer ([#10746](https://github.com/PyTorchLightning/pytorch-lightning/pull/10746))

+- Fixed the default logging level for batch hooks associated with training from `on_step=False, on_epoch=True` to `on_step=True, on_epoch=False` ([#10756](https://github.com/PyTorchLightning/pytorch-lightning/pull/10756))
+
 -
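In user code, the changelog entry above amounts to the following. This is a minimal sketch, not code from the PR: the `LitModel` class and metric names are hypothetical, and only the default `on_step`/`on_epoch` behavior of `self.log` is illustrated.

import torch
from pytorch_lightning import LightningModule


class LitModel(LightningModule):  # hypothetical module, for illustration only
    def training_step(self, batch, batch_idx):
        loss = torch.tensor(0.0)  # placeholder loss
        # No on_step/on_epoch given: after this fix the default here is
        # on_step=True, on_epoch=False (previously on_step=False, on_epoch=True).
        self.log("train_loss", loss)
        return loss

    def on_train_batch_end(self, outputs, batch, batch_idx):
        # The new per-step default also applies to train batch hooks.
        self.log("batch_metric", 1.0)

    def validation_step(self, batch, batch_idx):
        # Validation and test hooks keep the on_step=False, on_epoch=True default.
        self.log("val_metric", 1.0)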
@@ -353,10 +353,6 @@ class LightningModule(
             value, object, self.__check_allowed, name, value, wrong_dtype=(numbers.Number, Metric, Tensor, dict)
         )

-        # set the default depending on the fx_name
-        on_step = self.__auto_choose_log_on_step(on_step)
-        on_epoch = self.__auto_choose_log_on_epoch(on_epoch)
-
         if self.trainer is None:
             # not an error to support testing the `*_step` methods without a `Trainer` reference
             rank_zero_warn(
@@ -375,7 +371,10 @@ class LightningModule(
             raise MisconfigurationException(
                 "You are trying to `self.log()` but it is not managed by the `Trainer` control flow"
             )
-        _FxValidator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch)
+
+        on_step, on_epoch = _FxValidator.check_logging_and_get_default_levels(
+            self._current_fx_name, on_step=on_step, on_epoch=on_epoch
+        )

         # make sure user doesn't introduce logic for multi-dataloaders
         if "/dataloader_idx_" in name:
@@ -530,18 +529,6 @@ class LightningModule(
         """
         self.log_dict(grad_norm_dict, on_step=True, on_epoch=True, prog_bar=True, logger=True)

-    def __auto_choose_log_on_step(self, on_step: Optional[bool]) -> bool:
-        if on_step is None:
-            on_step = False
-            on_step |= self._current_fx_name in ("training_step", "training_step_end")
-        return on_step
-
-    def __auto_choose_log_on_epoch(self, on_epoch: Optional[bool]) -> bool:
-        if on_epoch is None:
-            on_epoch = True
-            on_epoch &= self._current_fx_name not in ("training_step", "training_step_end")
-        return on_epoch
-
     def all_gather(
         self, data: Union[torch.Tensor, Dict, List, Tuple], group: Optional[Any] = None, sync_grads: bool = False
     ):
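The two removed helpers above were the source of the bug: they special-cased only `training_step` and `training_step_end`, so every other training-related batch hook (for example `on_train_batch_end`) silently fell back to `on_step=False, on_epoch=True`. A standalone sketch of the old defaulting logic, rewritten as free functions with `fx_name` passed explicitly (for illustration only):

from typing import Optional


def auto_choose_log_on_step(fx_name: str, on_step: Optional[bool]) -> bool:
    # Old behavior: only the two training-step hooks defaulted to per-step logging.
    if on_step is None:
        on_step = fx_name in ("training_step", "training_step_end")
    return on_step


def auto_choose_log_on_epoch(fx_name: str, on_epoch: Optional[bool]) -> bool:
    # Old behavior: every other hook defaulted to per-epoch logging.
    if on_epoch is None:
        on_epoch = fx_name not in ("training_step", "training_step_end")
    return on_epoch


# A train batch hook therefore wrongly defaulted to epoch-level logging:
assert auto_choose_log_on_step("on_train_batch_end", None) is False
assert auto_choose_log_on_epoch("on_train_batch_end", None) is True

The replacement below moves both the allowed values and the defaults into the `_FxValidator` table, so each hook declares its own default explicitly.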
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union

 from typing_extensions import TypedDict
@@ -20,50 +20,98 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException

 class _FxValidator:
     class _LogOptions(TypedDict):
-        on_step: Union[Tuple[bool], Tuple[bool, bool]]
-        on_epoch: Union[Tuple[bool], Tuple[bool, bool]]
+        allowed_on_step: Union[Tuple[bool], Tuple[bool, bool]]
+        allowed_on_epoch: Union[Tuple[bool], Tuple[bool, bool]]
+        default_on_step: bool
+        default_on_epoch: bool

     functions = {
         "on_before_accelerator_backend_setup": None,
         "on_configure_sharded_model": None,
-        "on_before_backward": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
-        "on_after_backward": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
-        "on_before_optimizer_step": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
-        "on_before_zero_grad": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
+        "on_before_backward": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
+        ),
+        "on_after_backward": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
+        ),
+        "on_before_optimizer_step": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
+        ),
+        "on_before_zero_grad": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
+        ),
         "on_init_start": None,
         "on_init_end": None,
         "on_fit_start": None,
         "on_fit_end": None,
         "on_sanity_check_start": None,
         "on_sanity_check_end": None,
-        "on_train_start": _LogOptions(on_step=(False,), on_epoch=(True,)),
+        "on_train_start": _LogOptions(
+            allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
+        ),
         "on_train_end": None,
-        "on_validation_start": _LogOptions(on_step=(False,), on_epoch=(True,)),
+        "on_validation_start": _LogOptions(
+            allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
+        ),
         "on_validation_end": None,
-        "on_test_start": _LogOptions(on_step=(False,), on_epoch=(True,)),
+        "on_test_start": _LogOptions(
+            allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
+        ),
         "on_test_end": None,
         "on_predict_start": None,
         "on_predict_end": None,
         "on_pretrain_routine_start": None,
         "on_pretrain_routine_end": None,
-        "on_train_epoch_start": _LogOptions(on_step=(False,), on_epoch=(True,)),
-        "on_train_epoch_end": _LogOptions(on_step=(False,), on_epoch=(True,)),
-        "on_validation_epoch_start": _LogOptions(on_step=(False,), on_epoch=(True,)),
-        "on_validation_epoch_end": _LogOptions(on_step=(False,), on_epoch=(True,)),
-        "on_test_epoch_start": _LogOptions(on_step=(False,), on_epoch=(True,)),
-        "on_test_epoch_end": _LogOptions(on_step=(False,), on_epoch=(True,)),
+        "on_train_epoch_start": _LogOptions(
+            allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
+        ),
+        "on_train_epoch_end": _LogOptions(
+            allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
+        ),
+        "on_validation_epoch_start": _LogOptions(
+            allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
+        ),
+        "on_validation_epoch_end": _LogOptions(
+            allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
+        ),
+        "on_test_epoch_start": _LogOptions(
+            allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
+        ),
+        "on_test_epoch_end": _LogOptions(
+            allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
+        ),
         "on_predict_epoch_start": None,
         "on_predict_epoch_end": None,
-        "on_epoch_start": _LogOptions(on_step=(False,), on_epoch=(True,)),
-        "on_epoch_end": _LogOptions(on_step=(False,), on_epoch=(True,)),
-        "on_batch_start": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
-        "on_batch_end": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
-        "on_train_batch_start": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
-        "on_train_batch_end": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
-        "on_validation_batch_start": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
-        "on_validation_batch_end": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
-        "on_test_batch_start": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
-        "on_test_batch_end": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
+        "on_epoch_start": _LogOptions(
+            allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
+        ),
+        "on_epoch_end": _LogOptions(
+            allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
+        ),
+        "on_batch_start": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
+        ),
+        "on_batch_end": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
+        ),
+        "on_train_batch_start": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
+        ),
+        "on_train_batch_end": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
+        ),
+        "on_validation_batch_start": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=False, default_on_epoch=True
+        ),
+        "on_validation_batch_end": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=False, default_on_epoch=True
+        ),
+        "on_test_batch_start": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=False, default_on_epoch=True
+        ),
+        "on_test_batch_end": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=False, default_on_epoch=True
+        ),
         "on_predict_batch_start": None,
         "on_predict_batch_end": None,
         "on_keyboard_interrupt": None,
@@ -73,16 +121,34 @@ class _FxValidator:
         "setup": None,
         "teardown": None,
         "configure_sharded_model": None,
-        "training_step": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
-        "validation_step": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
-        "test_step": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
+        "training_step": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
+        ),
+        "validation_step": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=False, default_on_epoch=True
+        ),
+        "test_step": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=False, default_on_epoch=True
+        ),
         "predict_step": None,
-        "training_step_end": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
-        "validation_step_end": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
-        "test_step_end": _LogOptions(on_step=(False, True), on_epoch=(False, True)),
-        "training_epoch_end": _LogOptions(on_step=(False,), on_epoch=(True,)),
-        "validation_epoch_end": _LogOptions(on_step=(False,), on_epoch=(True,)),
-        "test_epoch_end": _LogOptions(on_step=(False,), on_epoch=(True,)),
+        "training_step_end": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
+        ),
+        "validation_step_end": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=False, default_on_epoch=True
+        ),
+        "test_step_end": _LogOptions(
+            allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=False, default_on_epoch=True
+        ),
+        "training_epoch_end": _LogOptions(
+            allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
+        ),
+        "validation_epoch_end": _LogOptions(
+            allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
+        ),
+        "test_epoch_end": _LogOptions(
+            allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
+        ),
         "configure_optimizers": None,
         "on_train_dataloader": None,
         "train_dataloader": None,
@ -97,22 +163,48 @@ class _FxValidator:
|
|||
}
|
||||
|
||||
@classmethod
|
||||
def check_logging(cls, fx_name: str, on_step: bool, on_epoch: bool) -> None:
|
||||
"""Check if the given function name is allowed to log."""
|
||||
def check_logging(cls, fx_name: str) -> None:
|
||||
"""Check if the given hook is allowed to log."""
|
||||
if fx_name not in cls.functions:
|
||||
raise RuntimeError(
|
||||
f"Logging inside `{fx_name}` is not implemented."
|
||||
" Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`"
|
||||
" Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`."
|
||||
)
|
||||
allowed = cls.functions[fx_name]
|
||||
if allowed is None:
|
||||
raise MisconfigurationException(f"You can't `self.log()` inside `{fx_name}`")
|
||||
|
||||
m = "You can't `self.log({}={})` inside `{}`, must be one of {}"
|
||||
if on_step not in allowed["on_step"]:
|
||||
msg = m.format("on_step", on_step, fx_name, allowed["on_step"])
|
||||
if cls.functions[fx_name] is None:
|
||||
raise MisconfigurationException(f"You can't `self.log()` inside `{fx_name}`.")
|
||||
|
||||
@classmethod
|
||||
def get_default_logging_levels(
|
||||
cls, fx_name: str, on_step: Optional[bool], on_epoch: Optional[bool]
|
||||
) -> Tuple[bool, bool]:
|
||||
"""Return default logging levels for given hook."""
|
||||
fx_config = cls.functions[fx_name]
|
||||
assert fx_config is not None
|
||||
on_step = fx_config["default_on_step"] if on_step is None else on_step
|
||||
on_epoch = fx_config["default_on_epoch"] if on_epoch is None else on_epoch
|
||||
return on_step, on_epoch
|
||||
|
||||
@classmethod
|
||||
def check_logging_levels(cls, fx_name: str, on_step: bool, on_epoch: bool) -> None:
|
||||
"""Check if the logging levels are allowed in the given hook."""
|
||||
fx_config = cls.functions[fx_name]
|
||||
assert fx_config is not None
|
||||
m = "You can't `self.log({}={})` inside `{}`, must be one of {}."
|
||||
if on_step not in fx_config["allowed_on_step"]:
|
||||
msg = m.format("on_step", on_step, fx_name, fx_config["allowed_on_step"])
|
||||
raise MisconfigurationException(msg)
|
||||
|
||||
if on_epoch not in allowed["on_epoch"]:
|
||||
msg = m.format("on_epoch", on_epoch, fx_name, allowed["on_epoch"])
|
||||
if on_epoch not in fx_config["allowed_on_epoch"]:
|
||||
msg = m.format("on_epoch", on_epoch, fx_name, fx_config["allowed_on_epoch"])
|
||||
raise MisconfigurationException(msg)
|
||||
|
||||
@classmethod
|
||||
def check_logging_and_get_default_levels(
|
||||
cls, fx_name: str, on_step: Optional[bool], on_epoch: Optional[bool]
|
||||
) -> Tuple[bool, bool]:
|
||||
"""Check if the given hook name is allowed to log and return logging levels."""
|
||||
cls.check_logging(fx_name)
|
||||
on_step, on_epoch = cls.get_default_logging_levels(fx_name, on_step, on_epoch)
|
||||
cls.check_logging_levels(fx_name, on_step, on_epoch)
|
||||
return on_step, on_epoch
|
||||
|
|
|
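Taken together, the new `_FxValidator` methods first resolve `None` into the per-hook defaults from the table and then validate the result. A minimal usage sketch, assuming the import path used by the new test file in this PR; the expected values follow directly from the `functions` table above:

from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# None values are replaced by the hook's defaults before validation.
on_step, on_epoch = _FxValidator.check_logging_and_get_default_levels(
    "on_train_batch_end", on_step=None, on_epoch=None
)
assert (on_step, on_epoch) == (True, False)  # train batch hooks now log per step

on_step, on_epoch = _FxValidator.check_logging_and_get_default_levels(
    "validation_step", on_step=None, on_epoch=None
)
assert (on_step, on_epoch) == (False, True)  # eval hooks keep the per-epoch default

# Explicit values are still validated against the allowed tuples:
# `training_epoch_end` only allows on_step=False, so this raises.
try:
    _FxValidator.check_logging_and_get_default_levels("training_epoch_end", on_step=True, on_epoch=True)
except MisconfigurationException as err:
    print(err)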
@@ -141,17 +141,17 @@ def test_fx_validator(tmpdir):
             and func_name not in ["on_train_end", "on_test_end", "on_validation_end"]
         )
         if allowed:
-            validator.check_logging(fx_name=func_name, on_step=on_step, on_epoch=on_epoch)
+            validator.check_logging_levels(fx_name=func_name, on_step=on_step, on_epoch=on_epoch)
             if not is_start and is_stage:
                 with pytest.raises(MisconfigurationException, match="must be one of"):
-                    validator.check_logging(fx_name=func_name, on_step=True, on_epoch=on_epoch)
+                    validator.check_logging_levels(fx_name=func_name, on_step=True, on_epoch=on_epoch)
         else:
            assert func_name in not_supported
            with pytest.raises(MisconfigurationException, match="You can't"):
-                validator.check_logging(fx_name=func_name, on_step=on_step, on_epoch=on_epoch)
+                validator.check_logging(fx_name=func_name)

     with pytest.raises(RuntimeError, match="Logging inside `foo` is not implemented"):
-        validator.check_logging("foo", False, False)
+        validator.check_logging("foo")


 class HookedCallback(Callback):
@@ -0,0 +1,108 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Test logging in the training loop."""
+import inspect
+from unittest import mock
+from unittest.mock import ANY
+
+import torch
+
+from pytorch_lightning import Trainer
+from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
+from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
+from pytorch_lightning.trainer.states import RunningStage, TrainerFn
+from tests.helpers.boring_model import BoringModel
+
+
+def test_default_level_for_hooks_that_support_logging():
+    def _make_assertion(model, hooks, result_mock, on_step, on_epoch, extra_kwargs):
+        for hook in hooks:
+            model._current_fx_name = hook
+            model.log(hook, 1)
+            result_mock.assert_called_with(
+                hook, hook, torch.tensor(1), on_step=on_step, on_epoch=on_epoch, **extra_kwargs
+            )
+
+    trainer = Trainer()
+    model = BoringModel()
+    model.trainer = trainer
+    extra_kwargs = {
+        k: ANY
+        for k in inspect.signature(ResultCollection.log).parameters
+        if k not in ["self", "fx", "name", "value", "on_step", "on_epoch"]
+    }
+    all_logging_hooks = {k for k in _FxValidator.functions if _FxValidator.functions[k]}
+
+    with mock.patch(
+        "pytorch_lightning.trainer.connectors.logger_connector.result.ResultCollection.log", return_value=None
+    ) as result_mock:
+        trainer.state.stage = RunningStage.TRAINING
+        hooks = [
+            "on_before_backward",
+            "on_after_backward",
+            "on_before_optimizer_step",
+            "on_before_zero_grad",
+            "training_step",
+            "training_step_end",
+            "on_batch_start",
+            "on_batch_end",
+            "on_train_batch_start",
+            "on_train_batch_end",
+        ]
+        all_logging_hooks = all_logging_hooks - set(hooks)
+        _make_assertion(model, hooks, result_mock, on_step=True, on_epoch=False, extra_kwargs=extra_kwargs)
+
+        hooks = [
+            "on_train_start",
+            "on_train_epoch_start",
+            "on_train_epoch_end",
+            "on_epoch_start",
+            "on_epoch_end",
+            "training_epoch_end",
+        ]
+        all_logging_hooks = all_logging_hooks - set(hooks)
+        _make_assertion(model, hooks, result_mock, on_step=False, on_epoch=True, extra_kwargs=extra_kwargs)
+
+        trainer.state.stage = RunningStage.VALIDATING
+        trainer.state.fn = TrainerFn.VALIDATING
+        hooks = [
+            "on_validation_start",
+            "on_validation_epoch_start",
+            "on_validation_epoch_end",
+            "on_validation_batch_start",
+            "on_validation_batch_end",
+            "validation_step",
+            "validation_step_end",
+            "validation_epoch_end",
+        ]
+        all_logging_hooks = all_logging_hooks - set(hooks)
+        _make_assertion(model, hooks, result_mock, on_step=False, on_epoch=True, extra_kwargs=extra_kwargs)
+
+        trainer.state.stage = RunningStage.TESTING
+        trainer.state.fn = TrainerFn.TESTING
+        hooks = [
+            "on_test_start",
+            "on_test_epoch_start",
+            "on_test_epoch_end",
+            "on_test_batch_start",
+            "on_test_batch_end",
+            "test_step",
+            "test_step_end",
+            "test_epoch_end",
+        ]
+        all_logging_hooks = all_logging_hooks - set(hooks)
+        _make_assertion(model, hooks, result_mock, on_step=False, on_epoch=True, extra_kwargs=extra_kwargs)
+
+    # just to ensure we checked all possible logging hooks here
+    assert len(all_logging_hooks) == 0
@@ -716,10 +716,10 @@ def test_on_epoch_logging_with_sum_and_on_batch_start(tmpdir):
             assert all(v == 3 for v in self.trainer.callback_metrics.values())

         def on_train_batch_start(self, batch, batch_idx):
-            self.log("on_train_batch_start", 1.0, reduce_fx="sum")
+            self.log("on_train_batch_start", 1.0, on_step=False, on_epoch=True, reduce_fx="sum")

         def on_train_batch_end(self, outputs, batch, batch_idx):
-            self.log("on_train_batch_end", 1.0, reduce_fx="sum")
+            self.log("on_train_batch_end", 1.0, on_step=False, on_epoch=True, reduce_fx="sum")

         def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
             self.log("on_validation_batch_start", 1.0, reduce_fx="sum")