Always use `trainer.call_hook` (#8498)

Carlos Mocholí 2021-08-20 18:22:03 +02:00 committed by GitHub
parent ad3f183bc3
commit e1442d247e
12 changed files with 215 additions and 70 deletions

View File

@@ -89,6 +89,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- The accelerator and training type plugin `setup` hooks no longer have a `model` argument ([#8536](https://github.com/PyTorchLightning/pytorch-lightning/pull/8536))
- Improve coverage of `self.log`-ing in any `LightningModule` or `Callback` hook ([#8498](https://github.com/PyTorchLightning/pytorch-lightning/pull/8498))
- Removed restrictions in the trainer that loggers can only log from rank 0. Existing logger behavior has not changed. ([#8608](https://github.com/PyTorchLightning/pytorch-lightning/pull/8608))

View File

@@ -402,9 +402,22 @@ class LightningModule(
        on_step = self.__auto_choose_log_on_step(on_step)
        on_epoch = self.__auto_choose_log_on_epoch(on_epoch)

        if self.trainer is None:
            raise MisconfigurationException(
                "You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet."
                " This is most likely because the model hasn't been passed to the `Trainer`"
            )
        results = self.trainer._results
        assert results is not None
        assert self._current_fx_name is not None
        if results is None:
            raise MisconfigurationException(
                "You are trying to `self.log()` but the loop `ResultCollection` is not registered"
                " yet. This is most likely because you are trying to log in a `predict` hook,"
                " but it doesn't support logging"
            )
        if self._current_fx_name is None:
            raise MisconfigurationException(
                "You are trying to `self.log()` but it is not managed by the `Trainer` control flow"
            )
        FxValidator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch)

        # make sure user doesn't introduce logic for multi-dataloaders
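
For illustration (not part of the diff): with these guards, misusing `self.log` on a module that was never attached to a `Trainer` now fails fast with an actionable message instead of tripping an opaque `assert`. A minimal sketch, assuming this revision of pytorch-lightning is installed:

```python
import pytest
from pytorch_lightning import LightningModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException

model = LightningModule()
# the model was never passed to a `Trainer`, so `model.trainer` is None
with pytest.raises(MisconfigurationException, match="`self.trainer` reference is not registered"):
    model.log("foo", 1)
```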

View File

@@ -179,11 +179,10 @@ class EvaluationLoop(DataLoaderLoop):
    def on_evaluation_model_eval(self) -> None:
        """Sets model to eval mode"""
        model_ref = self.trainer.lightning_module
        if self.trainer.testing:
            model_ref.on_test_model_eval()
            self.trainer.call_hook("on_test_model_eval")
        else:
            model_ref.on_validation_model_eval()
            self.trainer.call_hook("on_validation_model_eval")

    def on_evaluation_model_train(self) -> None:
        """Sets model to train mode"""

View File

@@ -139,7 +139,7 @@ class CallbackConnector:
        In addition, all :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks
        will be pushed to the end of the list, ensuring they run last.
        """
        model_callbacks = self.trainer.lightning_module.configure_callbacks()
        model_callbacks = self.trainer.call_hook("configure_callbacks")
        if not model_callbacks:
            return

        model_callback_types = {type(c) for c in model_callbacks}
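
For context, `configure_callbacks` is the `LightningModule` hook that contributes model-specific callbacks; routing it through `call_hook` means it is now profiled and covered by the logging validator like any other hook. A typical override, sketched:

```python
from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

class MyModel(LightningModule):
    def configure_callbacks(self):
        # callbacks returned here are merged with the Trainer's callbacks,
        # replacing any Trainer callback of the same type (see the tests below)
        return [EarlyStopping(monitor="val_loss"), ModelCheckpoint()]
```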

View File

@@ -73,7 +73,7 @@ class DataConnector:
        if self.can_prepare_data():
            if self.trainer.datamodule is not None:
                self.trainer.datamodule.prepare_data()
            self.trainer.lightning_module.prepare_data()
            self.trainer.call_hook("prepare_data")
            self.trainer._is_data_prepared = True

    def can_prepare_data(self):
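
Same pattern for `prepare_data`: the override still runs once for data download/preprocessing, but because it now goes through `call_hook`, logging inside it is rejected (see the `FxValidator` table in the next file). A sketch:

```python
from pytorch_lightning import LightningModule

class MyModel(LightningModule):
    def prepare_data(self) -> None:
        # download or tokenize datasets here; this hook has no logger attached,
        # and `self.log("x", 1)` would now raise
        # MisconfigurationException: You can't `self.log()` inside `prepare_data`
        ...
```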

View File

@@ -77,12 +77,17 @@ class FxValidator:
        training_epoch_end=dict(on_step=(False,), on_epoch=(True,)),
        validation_epoch_end=dict(on_step=(False,), on_epoch=(True,)),
        test_epoch_end=dict(on_step=(False,), on_epoch=(True,)),
        on_before_batch_transfer=None,
        transfer_batch_to_device=None,
        on_after_batch_transfer=None,
        backward=None,
        optimizer_step=None,
        # TODO(@carmocca): some {step,epoch}_{start,end} are missing
        configure_optimizers=None,
        on_train_dataloader=None,
        train_dataloader=None,
        on_val_dataloader=None,
        val_dataloader=None,
        on_test_dataloader=None,
        test_dataloader=None,
        prepare_data=None,
        configure_callbacks=None,
        on_validation_model_eval=None,
        on_test_model_eval=None,
    )

    @classmethod
@@ -90,12 +95,12 @@ class FxValidator:
        """Check if the given function name is allowed to log"""
        if fx_name not in cls.functions:
            raise RuntimeError(
                f"You are trying to `self.log()` inside `{fx_name}` but it is not implemented."
                f"Logging inside `{fx_name}` is not implemented."
                " Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`"
            )
        allowed = cls.functions[fx_name]
        if allowed is None:
            raise MisconfigurationException(f"{fx_name} function doesn't support logging using `self.log()`")
            raise MisconfigurationException(f"You can't `self.log()` inside `{fx_name}`")

        m = "You can't `self.log({}={})` inside `{}`, must be one of {}"
        if on_step not in allowed["on_step"]:
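
Taken together: a `None` entry in the table means the hook exists but may not log, and an unknown name is a hard error. A hypothetical direct use of the validator (import path as in this tree):

```python
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# `training_step` may log on both step and epoch level
FxValidator.check_logging("training_step", on_step=True, on_epoch=True)

try:
    FxValidator.check_logging("prepare_data", on_step=False, on_epoch=False)
except MisconfigurationException as e:
    print(e)  # You can't `self.log()` inside `prepare_data`

try:
    FxValidator.check_logging("foo", on_step=False, on_epoch=False)
except RuntimeError as e:
    print(e)  # Logging inside `foo` is not implemented. ...
```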

View File

@@ -523,8 +523,9 @@ class TrainerDataLoadingMixin(ABC):
        Returns:
            The dataloader
        """
        self.call_hook(f"on_{stage.dataloader_prefix}_dataloader")
        dataloader = getattr(model, f"{stage.dataloader_prefix}_dataloader")()
        hook = f"{stage.dataloader_prefix}_dataloader"
        self.call_hook("on_" + hook, pl_module=model)
        dataloader = self.call_hook(hook, pl_module=model)
        if isinstance(dataloader, tuple):
            dataloader = list(dataloader)
        self.accelerator.barrier("get_dataloaders")
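
The dataloader itself is now also fetched through `call_hook` (with `pl_module=model`, so this works before the module is attached to the trainer), and a tuple return value is still normalized to a list. A sketch of a hook relying on that normalization:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule

class MyModel(LightningModule):
    def val_dataloader(self):
        dataset = TensorDataset(torch.randn(8, 2))
        # returning a tuple is fine: `_request_dataloader` converts it to a list
        return DataLoader(dataset), DataLoader(dataset)
```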

View File

@@ -29,9 +29,10 @@ class TrainerOptimizersMixin(ABC):
    _lightning_optimizers: Optional[List[LightningOptimizer]]

    def init_optimizers(self, model: "pl.LightningModule") -> Tuple[List, List, List]:
    def init_optimizers(self, model: Optional["pl.LightningModule"]) -> Tuple[List, List, List]:
        pl_module = self.lightning_module or model
        self._lightning_optimizers = None
        optim_conf = model.configure_optimizers()
        optim_conf = self.call_hook("configure_optimizers", pl_module=pl_module)
        if optim_conf is None:
            rank_zero_warn(
                "`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer",
@@ -95,7 +96,7 @@ class TrainerOptimizersMixin(ABC):
                ' * A list of the previously described dict format, with an optional "frequency" key (int)'
            )

        is_manual_optimization = not model.automatic_optimization
        is_manual_optimization = not pl_module.automatic_optimization
        lr_schedulers = self.configure_schedulers(lr_schedulers, monitor, is_manual_optimization)
        _validate_scheduler_optimizer(optimizers, lr_schedulers)

View File

@@ -1103,20 +1103,14 @@ class Trainer(
        # --------------------------
        # Pre-train
        # --------------------------
        # on pretrain routine start
        ref_model = self.lightning_module

        self.on_pretrain_routine_start()
        ref_model.on_pretrain_routine_start()
        self.call_hook("on_pretrain_routine_start")

        # print model summary
        if self.is_global_zero and self.weights_summary is not None and not self.testing:
            max_depth = ModelSummary.MODES[self.weights_summary]
            summarize(ref_model, max_depth=max_depth)
            summarize(self.lightning_module, max_depth=max_depth)

        # on pretrain routine end
        self.on_pretrain_routine_end()
        ref_model.on_pretrain_routine_end()
        self.call_hook("on_pretrain_routine_end")

    def _run_train(self) -> None:
        self._pre_training_routine()
@@ -1179,8 +1173,7 @@ class Trainer(
        stage = self.state.stage
        self.sanity_checking = True

        # hook and callback
        self.on_sanity_check_start()
        self.call_hook("on_sanity_check_start")

        # reload dataloaders
        self._evaluation_loop.reload_evaluation_dataloaders()
@@ -1189,7 +1182,7 @@ class Trainer(
        with torch.no_grad():
            self._evaluation_loop.run()

        self.on_sanity_check_end()
        self.call_hook("on_sanity_check_end")

        # reset validation metrics
        self.logger_connector.reset()
@@ -1245,8 +1238,7 @@ class Trainer(
        if self.datamodule is not None:
            self.datamodule.setup(stage=fn)

        self.setup(stage=fn)
        self.lightning_module.setup(stage=fn)
        self.call_hook("setup", stage=fn)

        self.accelerator.barrier("post_setup")
@@ -1259,8 +1251,8 @@ class Trainer(
        model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False)
        if self.accelerator.call_configure_sharded_model_hook and not model_call_configure_sharded_model_hook:
            with self.accelerator.model_sharded_context():
                model.configure_sharded_model()
                self.on_configure_sharded_model()
                self.call_hook("configure_sharded_model")
                self.call_hook("on_configure_sharded_model")
            model.call_configure_sharded_model_hook = True
            self.accelerator.call_configure_sharded_model_hook = False
@@ -1272,8 +1264,7 @@ class Trainer(
        self.data_connector.detach_data(self.lightning_module)

        self.teardown(stage=fn)
        self.lightning_module.teardown(stage=fn)
        self.call_hook("teardown", stage=fn)

        self.lightning_module._current_fx_name = None
        self.lightning_module._current_dataloader_idx = None
@@ -1288,28 +1279,30 @@ class Trainer(
        # summarize profile results
        self.profiler.describe()

    def call_hook(self, hook_name: str, *args, **kwargs) -> Any:
        if self.lightning_module:
            prev_fx_name = self.lightning_module._current_fx_name
            self.lightning_module._current_fx_name = hook_name
    def call_hook(
        self, hook_name: str, *args: Any, pl_module: Optional["pl.LightningModule"] = None, **kwargs: Any
    ) -> Any:
        pl_module = self.lightning_module or pl_module
        if pl_module:
            prev_fx_name = pl_module._current_fx_name
            pl_module._current_fx_name = hook_name

        # always profile hooks
        with self.profiler.profile(hook_name):

            # first call trainer hook
            if hasattr(self, hook_name):
                trainer_hook = getattr(self, hook_name)
                trainer_hook(*args, **kwargs)
            callback_fx = getattr(self, hook_name, None)
            if callable(callback_fx):
                callback_fx(*args, **kwargs)

            # next call hook in lightningModule
            output = None
            model_ref = self.lightning_module
            if is_overridden(hook_name, model_ref):
                hook_fx = getattr(model_ref, hook_name)
                output = hook_fx(*args, **kwargs)
            model_fx = getattr(pl_module, hook_name, None)
            if callable(model_fx):
                output = model_fx(*args, **kwargs)

            # call the accelerator hook
            if hasattr(self.accelerator, hook_name):
            if hook_name not in ("setup", "teardown") and hasattr(self.accelerator, hook_name):
                accelerator_hook = getattr(self.accelerator, hook_name)
                accelerator_output = accelerator_hook(*args, **kwargs)
                # Rely on the accelerator output if lightningModule hook returns nothing
@@ -1317,9 +1310,9 @@ class Trainer(
                # todo: move this data parallel logic into the data parallel plugin
                output = accelerator_output if output is None else output

        if self.lightning_module:
        if pl_module:
            # restore current_fx when nested context
            self.lightning_module._current_fx_name = prev_fx_name
            pl_module._current_fx_name = prev_fx_name
        return output
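
To summarize the new contract (as read from the code above, not separately documented): `call_hook` prefers the attached `self.lightning_module` but accepts an explicit `pl_module` for hooks that fire before a module is attached; it then calls the trainer/callback hook, the module hook (whose return value takes precedence), and the accelerator hook, skipping the accelerator for `setup`/`teardown` now that those plugin hooks no longer take a `model` argument (#8536). A sketch of a call site, mirroring the optimizers mixin above:

```python
from pytorch_lightning import Trainer
from tests.helpers import BoringModel  # repo-internal helper, any LightningModule works

trainer = Trainer()
model = BoringModel()
# `model` is not attached to the trainer yet, so it is passed explicitly:
optim_conf = trainer.call_hook("configure_optimizers", pl_module=model)
```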

View File

@@ -34,11 +34,8 @@ from tests.helpers.utils import reset_seed
@mock.patch("pytorch_lightning.trainer.trainer.Trainer.node_rank", new_callable=PropertyMock)
@mock.patch("pytorch_lightning.trainer.trainer.Trainer.local_rank", new_callable=PropertyMock)
def test_can_prepare_data(local_rank, node_rank):
    model = BoringModel()
    dm = BoringDataModule()
    trainer = Trainer()
    trainer.model = model
    trainer.datamodule = dm

    # 1 no DM

View File

@@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from unittest.mock import Mock

import torch

from pytorch_lightning import Callback, Trainer
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.callbacks import (
    EarlyStopping,
    GradientAccumulationScheduler,
@@ -36,18 +35,22 @@ def test_checkpoint_callbacks_are_last(tmpdir):
    lr_monitor = LearningRateMonitor()
    progress_bar = ProgressBar()

    # no model callbacks
    model = Mock()
    model.configure_callbacks.return_value = []
    # no model reference
    trainer = Trainer(callbacks=[checkpoint1, progress_bar, lr_monitor, checkpoint2])
    trainer.model = model
    cb_connector = CallbackConnector(trainer)
    cb_connector._attach_model_callbacks()
    assert trainer.callbacks == [progress_bar, lr_monitor, checkpoint1, checkpoint2]

    # no model callbacks
    model = LightningModule()
    model.configure_callbacks = lambda: []
    trainer.model = model
    cb_connector._attach_model_callbacks()
    assert trainer.callbacks == [progress_bar, lr_monitor, checkpoint1, checkpoint2]

    # with model-specific callbacks that substitute ones in Trainer
    model = Mock()
    model.configure_callbacks.return_value = [checkpoint1, early_stopping, checkpoint2]
    model = LightningModule()
    model.configure_callbacks = lambda: [checkpoint1, early_stopping, checkpoint2]
    trainer = Trainer(callbacks=[progress_bar, lr_monitor, ModelCheckpoint(tmpdir)])
    trainer.model = model
    cb_connector = CallbackConnector(trainer)
@@ -89,8 +92,8 @@ def test_attach_model_callbacks():
    """Test that the callbacks defined in the model and through Trainer get merged correctly."""

    def assert_composition(trainer_callbacks, model_callbacks, expected):
        model = Mock()
        model.configure_callbacks.return_value = model_callbacks
        model = LightningModule()
        model.configure_callbacks = lambda: model_callbacks
        trainer = Trainer(checkpoint_callback=False, progress_bar_refresh_rate=0, callbacks=trainer_callbacks)
        trainer.model = model
        cb_connector = CallbackConnector(trainer)
@@ -140,8 +143,8 @@ def test_attach_model_callbacks():
def test_attach_model_callbacks_override_info(caplog):
    """Test that the logs contain the info about overriding callbacks returned by configure_callbacks."""
    model = Mock()
    model.configure_callbacks.return_value = [LearningRateMonitor(), EarlyStopping()]
    model = LightningModule()
    model.configure_callbacks = lambda: [LearningRateMonitor(), EarlyStopping()]
    trainer = Trainer(checkpoint_callback=False, callbacks=[EarlyStopping(), LearningRateMonitor(), ProgressBar()])
    trainer.model = model
    cb_connector = CallbackConnector(trainer)

View File

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from unittest import mock

import pytest
@@ -140,17 +141,145 @@ def test_fx_validator(tmpdir):
            if allowed:
                validator.check_logging(fx_name=func_name, on_step=on_step, on_epoch=on_epoch)
                if not is_start and is_stage:
                    with pytest.raises(MisconfigurationException, match="You can't"):
                    with pytest.raises(MisconfigurationException, match="must be one of"):
                        validator.check_logging(fx_name=func_name, on_step=True, on_epoch=on_epoch)
            else:
                assert func_name in not_supported
                with pytest.raises(MisconfigurationException, match="function doesn't support"):
                with pytest.raises(MisconfigurationException, match="You can't"):
                    validator.check_logging(fx_name=func_name, on_step=on_step, on_epoch=on_epoch)

    with pytest.raises(RuntimeError, match="`foo` but it is not implemented"):
    with pytest.raises(RuntimeError, match="Logging inside `foo` is not implemented"):
        validator.check_logging("foo", False, False)


class HookedCallback(Callback):
    def __init__(self, not_supported):
        def call(hook, trainer, model=None, *_, **__):
            lightning_module = trainer.lightning_module or model
            if lightning_module is None:
                # `on_init_{start,end}` do not have the `LightningModule` available
                assert hook in ("on_init_start", "on_init_end")
                return

            if hook in not_supported:
                with pytest.raises(MisconfigurationException, match=not_supported[hook]):
                    lightning_module.log("anything", 1)
            else:
                lightning_module.log(hook, 1)

        for h in get_members(Callback):
            setattr(self, h, partial(call, h))


class HookedModel(BoringModel):
    def __init__(self, not_supported):
        super().__init__()
        pl_module_hooks = get_members(LightningModule)
        pl_module_hooks.difference_update(
            {
                "log",
                "log_dict",
                # the following are problematic as they do have `self._current_fx_name` defined some times but
                # not others depending on where they were called. So we cannot reliably `self.log` in them
                "on_before_batch_transfer",
                "transfer_batch_to_device",
                "on_after_batch_transfer",
                "get_progress_bar_dict",
            }
        )
        # remove `nn.Module` hooks
        module_hooks = get_members(torch.nn.Module)
        pl_module_hooks.difference_update(module_hooks)

        def call(hook, fn, *args, **kwargs):
            out = fn(*args, **kwargs)
            if hook in not_supported:
                with pytest.raises(MisconfigurationException, match=not_supported[hook]):
                    self.log("anything", 1)
            else:
                self.log(hook, 1)
            return out

        for h in pl_module_hooks:
            attr = getattr(self, h)
            setattr(self, h, partial(call, h, attr))


def test_fx_validator_integration(tmpdir):
    """Tries to log inside all `LightningModule` and `Callback` hooks to check any expected errors"""
    not_supported = {
        None: "`self.trainer` reference is not registered",
        "on_before_accelerator_backend_setup": "You can't",
        "setup": "You can't",
        "configure_sharded_model": "You can't",
        "on_configure_sharded_model": "You can't",
        "configure_optimizers": "You can't",
        "on_fit_start": "You can't",
        "on_pretrain_routine_start": "You can't",
        "on_pretrain_routine_end": "You can't",
        "on_train_dataloader": "You can't",
        "train_dataloader": "You can't",
        "on_val_dataloader": "You can't",
        "val_dataloader": "You can't",
        "on_validation_end": "You can't",
        "on_train_end": "You can't",
        "on_fit_end": "You can't",
        "teardown": "You can't",
        "on_sanity_check_start": "You can't",
        "on_sanity_check_end": "You can't",
        "prepare_data": "You can't",
        "configure_callbacks": "You can't",
        "on_validation_model_eval": "You can't",
        "summarize": "not managed by the `Trainer",
    }
    model = HookedModel(not_supported)

    with pytest.raises(MisconfigurationException, match=not_supported[None]):
        model.log("foo", 1)

    callback = HookedCallback(not_supported)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        limit_predict_batches=1,
        callbacks=callback,
    )
    trainer.fit(model)

    not_supported.update(
        {
            # `lightning_module` ref is now present from the `fit` call
            "on_before_accelerator_backend_setup": "You can't",
            "on_test_dataloader": "You can't",
            "test_dataloader": "You can't",
            "on_test_model_eval": "You can't",
            "on_test_end": "You can't",
        }
    )
    trainer.test(model, verbose=False)

    not_supported.update({k: "ResultCollection` is not registered yet" for k in not_supported})
    not_supported.update(
        {
            "on_predict_dataloader": "ResultCollection` is not registered yet",
            "predict_dataloader": "ResultCollection` is not registered yet",
            "on_predict_model_eval": "ResultCollection` is not registered yet",
            "on_predict_start": "ResultCollection` is not registered yet",
            "on_predict_epoch_start": "ResultCollection` is not registered yet",
            "on_predict_batch_start": "ResultCollection` is not registered yet",
            "predict_step": "ResultCollection` is not registered yet",
            "on_predict_batch_end": "ResultCollection` is not registered yet",
            "on_predict_epoch_end": "ResultCollection` is not registered yet",
            "on_predict_end": "ResultCollection` is not registered yet",
        }
    )
    trainer.predict(model)

@RunIf(min_gpus=2)
def test_epoch_results_cache_dp(tmpdir):