From e1442d247e0e4967dd2772bdcf5166226c974f89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Fri, 20 Aug 2021 18:22:03 +0200 Subject: [PATCH] Always use `trainer.call_hook` (#8498) --- CHANGELOG.md | 4 + pytorch_lightning/core/lightning.py | 17 ++- .../loops/dataloader/evaluation_loop.py | 5 +- .../trainer/connectors/callback_connector.py | 2 +- .../trainer/connectors/data_connector.py | 2 +- .../logger_connector/fx_validator.py | 21 +-- pytorch_lightning/trainer/data_loading.py | 5 +- pytorch_lightning/trainer/optimizers.py | 7 +- pytorch_lightning/trainer/trainer.py | 57 ++++---- tests/core/test_datamodules.py | 3 - .../connectors/test_callback_connector.py | 27 ++-- .../trainer/logging_/test_logger_connector.py | 135 +++++++++++++++++- 12 files changed, 215 insertions(+), 70 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2e3d9382a4..00ab2503a8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -89,6 +89,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - The accelerator and training type plugin `setup` hooks no longer have a `model` argument ([#8536](https://github.com/PyTorchLightning/pytorch-lightning/pull/8536)) + +- Improve coverage of `self.log`-ing in any `LightningModule` or `Callback` hook ([#8498](https://github.com/PyTorchLightning/pytorch-lightning/pull/8498)) + + - Removed restrictions in the trainer that loggers can only log from rank 0. Existing logger behavior has not changed. ([#8608] (https://github.com/PyTorchLightning/pytorch-lightning/pull/8608)) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index fe924ed147..c847eea57c 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -402,9 +402,22 @@ class LightningModule( on_step = self.__auto_choose_log_on_step(on_step) on_epoch = self.__auto_choose_log_on_epoch(on_epoch) + if self.trainer is None: + raise MisconfigurationException( + "You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet." + " This is most likely because the model hasn't been passed to the `Trainer`" + ) results = self.trainer._results - assert results is not None - assert self._current_fx_name is not None + if results is None: + raise MisconfigurationException( + "You are trying to `self.log()` but the loop `ResultCollection` is not registered" + " yet. This is most likely because you are trying to log in a `predict` hook," + " but it doesn't support logging" + ) + if self._current_fx_name is None: + raise MisconfigurationException( + "You are trying to `self.log()` but it is not managed by the `Trainer` control flow" + ) FxValidator.check_logging(self._current_fx_name, on_step=on_step, on_epoch=on_epoch) # make sure user doesn't introduce logic for multi-dataloaders diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index d310d42f93..52998676e1 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -179,11 +179,10 @@ class EvaluationLoop(DataLoaderLoop): def on_evaluation_model_eval(self) -> None: """Sets model to eval mode""" - model_ref = self.trainer.lightning_module if self.trainer.testing: - model_ref.on_test_model_eval() + self.trainer.call_hook("on_test_model_eval") else: - model_ref.on_validation_model_eval() + self.trainer.call_hook("on_validation_model_eval") def on_evaluation_model_train(self) -> None: """Sets model to train mode""" diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index cd8183b68e..4bdafbf976 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -139,7 +139,7 @@ class CallbackConnector: In addition, all :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks will be pushed to the end of the list, ensuring they run last. """ - model_callbacks = self.trainer.lightning_module.configure_callbacks() + model_callbacks = self.trainer.call_hook("configure_callbacks") if not model_callbacks: return model_callback_types = {type(c) for c in model_callbacks} diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 629be97f29..5c98eb6878 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -73,7 +73,7 @@ class DataConnector: if self.can_prepare_data(): if self.trainer.datamodule is not None: self.trainer.datamodule.prepare_data() - self.trainer.lightning_module.prepare_data() + self.trainer.call_hook("prepare_data") self.trainer._is_data_prepared = True def can_prepare_data(self): diff --git a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py index f2ad8f1130..871ba0fa86 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py @@ -77,12 +77,17 @@ class FxValidator: training_epoch_end=dict(on_step=(False,), on_epoch=(True,)), validation_epoch_end=dict(on_step=(False,), on_epoch=(True,)), test_epoch_end=dict(on_step=(False,), on_epoch=(True,)), - on_before_batch_transfer=None, - transfer_batch_to_device=None, - on_after_batch_transfer=None, - backward=None, - optimizer_step=None, - # TODO(@carmocca): some {step,epoch}_{start,end} are missing + configure_optimizers=None, + on_train_dataloader=None, + train_dataloader=None, + on_val_dataloader=None, + val_dataloader=None, + on_test_dataloader=None, + test_dataloader=None, + prepare_data=None, + configure_callbacks=None, + on_validation_model_eval=None, + on_test_model_eval=None, ) @classmethod @@ -90,12 +95,12 @@ class FxValidator: """Check if the given function name is allowed to log""" if fx_name not in cls.functions: raise RuntimeError( - f"You are trying to `self.log()` inside `{fx_name}` but it is not implemented." + f"Logging inside `{fx_name}` is not implemented." " Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`" ) allowed = cls.functions[fx_name] if allowed is None: - raise MisconfigurationException(f"{fx_name} function doesn't support logging using `self.log()`") + raise MisconfigurationException(f"You can't `self.log()` inside `{fx_name}`") m = "You can't `self.log({}={})` inside `{}`, must be one of {}" if on_step not in allowed["on_step"]: diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 2ea2a74c7a..26c295b89d 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -523,8 +523,9 @@ class TrainerDataLoadingMixin(ABC): Returns: The dataloader """ - self.call_hook(f"on_{stage.dataloader_prefix}_dataloader") - dataloader = getattr(model, f"{stage.dataloader_prefix}_dataloader")() + hook = f"{stage.dataloader_prefix}_dataloader" + self.call_hook("on_" + hook, pl_module=model) + dataloader = self.call_hook(hook, pl_module=model) if isinstance(dataloader, tuple): dataloader = list(dataloader) self.accelerator.barrier("get_dataloaders") diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py index 782cf0ee6c..0701eab390 100644 --- a/pytorch_lightning/trainer/optimizers.py +++ b/pytorch_lightning/trainer/optimizers.py @@ -29,9 +29,10 @@ class TrainerOptimizersMixin(ABC): _lightning_optimizers: Optional[List[LightningOptimizer]] - def init_optimizers(self, model: "pl.LightningModule") -> Tuple[List, List, List]: + def init_optimizers(self, model: Optional["pl.LightningModule"]) -> Tuple[List, List, List]: + pl_module = self.lightning_module or model self._lightning_optimizers = None - optim_conf = model.configure_optimizers() + optim_conf = self.call_hook("configure_optimizers", pl_module=pl_module) if optim_conf is None: rank_zero_warn( "`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer", @@ -95,7 +96,7 @@ class TrainerOptimizersMixin(ABC): ' * A list of the previously described dict format, with an optional "frequency" key (int)' ) - is_manual_optimization = not model.automatic_optimization + is_manual_optimization = not pl_module.automatic_optimization lr_schedulers = self.configure_schedulers(lr_schedulers, monitor, is_manual_optimization) _validate_scheduler_optimizer(optimizers, lr_schedulers) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 887cdd46a9..becf4d4cf2 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1103,20 +1103,14 @@ class Trainer( # -------------------------- # Pre-train # -------------------------- - # on pretrain routine start - ref_model = self.lightning_module - - self.on_pretrain_routine_start() - ref_model.on_pretrain_routine_start() + self.call_hook("on_pretrain_routine_start") # print model summary if self.is_global_zero and self.weights_summary is not None and not self.testing: max_depth = ModelSummary.MODES[self.weights_summary] - summarize(ref_model, max_depth=max_depth) + summarize(self.lightning_module, max_depth=max_depth) - # on pretrain routine end - self.on_pretrain_routine_end() - ref_model.on_pretrain_routine_end() + self.call_hook("on_pretrain_routine_end") def _run_train(self) -> None: self._pre_training_routine() @@ -1179,8 +1173,7 @@ class Trainer( stage = self.state.stage self.sanity_checking = True - # hook and callback - self.on_sanity_check_start() + self.call_hook("on_sanity_check_start") # reload dataloaders self._evaluation_loop.reload_evaluation_dataloaders() @@ -1189,7 +1182,7 @@ class Trainer( with torch.no_grad(): self._evaluation_loop.run() - self.on_sanity_check_end() + self.call_hook("on_sanity_check_end") # reset validation metrics self.logger_connector.reset() @@ -1245,8 +1238,7 @@ class Trainer( if self.datamodule is not None: self.datamodule.setup(stage=fn) - self.setup(stage=fn) - self.lightning_module.setup(stage=fn) + self.call_hook("setup", stage=fn) self.accelerator.barrier("post_setup") @@ -1259,8 +1251,8 @@ class Trainer( model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False) if self.accelerator.call_configure_sharded_model_hook and not model_call_configure_sharded_model_hook: with self.accelerator.model_sharded_context(): - model.configure_sharded_model() - self.on_configure_sharded_model() + self.call_hook("configure_sharded_model") + self.call_hook("on_configure_sharded_model") model.call_configure_sharded_model_hook = True self.accelerator.call_configure_sharded_model_hook = False @@ -1272,8 +1264,7 @@ class Trainer( self.data_connector.detach_data(self.lightning_module) - self.teardown(stage=fn) - self.lightning_module.teardown(stage=fn) + self.call_hook("teardown", stage=fn) self.lightning_module._current_fx_name = None self.lightning_module._current_dataloader_idx = None @@ -1288,28 +1279,30 @@ class Trainer( # summarize profile results self.profiler.describe() - def call_hook(self, hook_name: str, *args, **kwargs) -> Any: - if self.lightning_module: - prev_fx_name = self.lightning_module._current_fx_name - self.lightning_module._current_fx_name = hook_name + def call_hook( + self, hook_name: str, *args: Any, pl_module: Optional["pl.LightningModule"] = None, **kwargs: Any + ) -> Any: + pl_module = self.lightning_module or pl_module + if pl_module: + prev_fx_name = pl_module._current_fx_name + pl_module._current_fx_name = hook_name # always profile hooks with self.profiler.profile(hook_name): # first call trainer hook - if hasattr(self, hook_name): - trainer_hook = getattr(self, hook_name) - trainer_hook(*args, **kwargs) + callback_fx = getattr(self, hook_name, None) + if callable(callback_fx): + callback_fx(*args, **kwargs) # next call hook in lightningModule output = None - model_ref = self.lightning_module - if is_overridden(hook_name, model_ref): - hook_fx = getattr(model_ref, hook_name) - output = hook_fx(*args, **kwargs) + model_fx = getattr(pl_module, hook_name, None) + if callable(model_fx): + output = model_fx(*args, **kwargs) # call the accelerator hook - if hasattr(self.accelerator, hook_name): + if hook_name not in ("setup", "teardown") and hasattr(self.accelerator, hook_name): accelerator_hook = getattr(self.accelerator, hook_name) accelerator_output = accelerator_hook(*args, **kwargs) # Rely on the accelerator output if lightningModule hook returns nothing @@ -1317,9 +1310,9 @@ class Trainer( # todo: move this data parallel logic into the data parallel plugin output = accelerator_output if output is None else output - if self.lightning_module: + if pl_module: # restore current_fx when nested context - self.lightning_module._current_fx_name = prev_fx_name + pl_module._current_fx_name = prev_fx_name return output diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index f7275f87f8..5c33a2f68a 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -34,11 +34,8 @@ from tests.helpers.utils import reset_seed @mock.patch("pytorch_lightning.trainer.trainer.Trainer.node_rank", new_callable=PropertyMock) @mock.patch("pytorch_lightning.trainer.trainer.Trainer.local_rank", new_callable=PropertyMock) def test_can_prepare_data(local_rank, node_rank): - - model = BoringModel() dm = BoringDataModule() trainer = Trainer() - trainer.model = model trainer.datamodule = dm # 1 no DM diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py index 45efa3c82b..43158865f9 100644 --- a/tests/trainer/connectors/test_callback_connector.py +++ b/tests/trainer/connectors/test_callback_connector.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from unittest.mock import Mock import torch -from pytorch_lightning import Callback, Trainer +from pytorch_lightning import Callback, LightningModule, Trainer from pytorch_lightning.callbacks import ( EarlyStopping, GradientAccumulationScheduler, @@ -36,18 +35,22 @@ def test_checkpoint_callbacks_are_last(tmpdir): lr_monitor = LearningRateMonitor() progress_bar = ProgressBar() - # no model callbacks - model = Mock() - model.configure_callbacks.return_value = [] + # no model reference trainer = Trainer(callbacks=[checkpoint1, progress_bar, lr_monitor, checkpoint2]) - trainer.model = model cb_connector = CallbackConnector(trainer) cb_connector._attach_model_callbacks() assert trainer.callbacks == [progress_bar, lr_monitor, checkpoint1, checkpoint2] + # no model callbacks + model = LightningModule() + model.configure_callbacks = lambda: [] + trainer.model = model + cb_connector._attach_model_callbacks() + assert trainer.callbacks == [progress_bar, lr_monitor, checkpoint1, checkpoint2] + # with model-specific callbacks that substitute ones in Trainer - model = Mock() - model.configure_callbacks.return_value = [checkpoint1, early_stopping, checkpoint2] + model = LightningModule() + model.configure_callbacks = lambda: [checkpoint1, early_stopping, checkpoint2] trainer = Trainer(callbacks=[progress_bar, lr_monitor, ModelCheckpoint(tmpdir)]) trainer.model = model cb_connector = CallbackConnector(trainer) @@ -89,8 +92,8 @@ def test_attach_model_callbacks(): """Test that the callbacks defined in the model and through Trainer get merged correctly.""" def assert_composition(trainer_callbacks, model_callbacks, expected): - model = Mock() - model.configure_callbacks.return_value = model_callbacks + model = LightningModule() + model.configure_callbacks = lambda: model_callbacks trainer = Trainer(checkpoint_callback=False, progress_bar_refresh_rate=0, callbacks=trainer_callbacks) trainer.model = model cb_connector = CallbackConnector(trainer) @@ -140,8 +143,8 @@ def test_attach_model_callbacks(): def test_attach_model_callbacks_override_info(caplog): """Test that the logs contain the info about overriding callbacks returned by configure_callbacks.""" - model = Mock() - model.configure_callbacks.return_value = [LearningRateMonitor(), EarlyStopping()] + model = LightningModule() + model.configure_callbacks = lambda: [LearningRateMonitor(), EarlyStopping()] trainer = Trainer(checkpoint_callback=False, callbacks=[EarlyStopping(), LearningRateMonitor(), ProgressBar()]) trainer.model = model cb_connector = CallbackConnector(trainer) diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 518a401a72..ed7711b32f 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial from unittest import mock import pytest @@ -140,17 +141,145 @@ def test_fx_validator(tmpdir): if allowed: validator.check_logging(fx_name=func_name, on_step=on_step, on_epoch=on_epoch) if not is_start and is_stage: - with pytest.raises(MisconfigurationException, match="You can't"): + with pytest.raises(MisconfigurationException, match="must be one of"): validator.check_logging(fx_name=func_name, on_step=True, on_epoch=on_epoch) else: assert func_name in not_supported - with pytest.raises(MisconfigurationException, match="function doesn't support"): + with pytest.raises(MisconfigurationException, match="You can't"): validator.check_logging(fx_name=func_name, on_step=on_step, on_epoch=on_epoch) - with pytest.raises(RuntimeError, match="`foo` but it is not implemented"): + with pytest.raises(RuntimeError, match="Logging inside `foo` is not implemented"): validator.check_logging("foo", False, False) +class HookedCallback(Callback): + def __init__(self, not_supported): + def call(hook, trainer, model=None, *_, **__): + lightning_module = trainer.lightning_module or model + if lightning_module is None: + # `on_init_{start,end}` do not have the `LightningModule` available + assert hook in ("on_init_start", "on_init_end") + return + + if hook in not_supported: + with pytest.raises(MisconfigurationException, match=not_supported[hook]): + lightning_module.log("anything", 1) + else: + lightning_module.log(hook, 1) + + for h in get_members(Callback): + setattr(self, h, partial(call, h)) + + +class HookedModel(BoringModel): + def __init__(self, not_supported): + super().__init__() + pl_module_hooks = get_members(LightningModule) + pl_module_hooks.difference_update( + { + "log", + "log_dict", + # the following are problematic as they do have `self._current_fx_name` defined some times but + # not others depending on where they were called. So we cannot reliably `self.log` in them + "on_before_batch_transfer", + "transfer_batch_to_device", + "on_after_batch_transfer", + "get_progress_bar_dict", + } + ) + # remove `nn.Module` hooks + module_hooks = get_members(torch.nn.Module) + pl_module_hooks.difference_update(module_hooks) + + def call(hook, fn, *args, **kwargs): + out = fn(*args, **kwargs) + + if hook in not_supported: + with pytest.raises(MisconfigurationException, match=not_supported[hook]): + self.log("anything", 1) + else: + self.log(hook, 1) + return out + + for h in pl_module_hooks: + attr = getattr(self, h) + setattr(self, h, partial(call, h, attr)) + + +def test_fx_validator_integration(tmpdir): + """Tries to log inside all `LightningModule` and `Callback` hooks to check any expected errors""" + not_supported = { + None: "`self.trainer` reference is not registered", + "on_before_accelerator_backend_setup": "You can't", + "setup": "You can't", + "configure_sharded_model": "You can't", + "on_configure_sharded_model": "You can't", + "configure_optimizers": "You can't", + "on_fit_start": "You can't", + "on_pretrain_routine_start": "You can't", + "on_pretrain_routine_end": "You can't", + "on_train_dataloader": "You can't", + "train_dataloader": "You can't", + "on_val_dataloader": "You can't", + "val_dataloader": "You can't", + "on_validation_end": "You can't", + "on_train_end": "You can't", + "on_fit_end": "You can't", + "teardown": "You can't", + "on_sanity_check_start": "You can't", + "on_sanity_check_end": "You can't", + "prepare_data": "You can't", + "configure_callbacks": "You can't", + "on_validation_model_eval": "You can't", + "summarize": "not managed by the `Trainer", + } + model = HookedModel(not_supported) + + with pytest.raises(MisconfigurationException, match=not_supported[None]): + model.log("foo", 1) + + callback = HookedCallback(not_supported) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + limit_train_batches=1, + limit_val_batches=1, + limit_test_batches=1, + limit_predict_batches=1, + callbacks=callback, + ) + trainer.fit(model) + + not_supported.update( + { + # `lightning_module` ref is now present from the `fit` call + "on_before_accelerator_backend_setup": "You can't", + "on_test_dataloader": "You can't", + "test_dataloader": "You can't", + "on_test_model_eval": "You can't", + "on_test_end": "You can't", + } + ) + trainer.test(model, verbose=False) + + not_supported.update({k: "ResultCollection` is not registered yet" for k in not_supported}) + not_supported.update( + { + "on_predict_dataloader": "ResultCollection` is not registered yet", + "predict_dataloader": "ResultCollection` is not registered yet", + "on_predict_model_eval": "ResultCollection` is not registered yet", + "on_predict_start": "ResultCollection` is not registered yet", + "on_predict_epoch_start": "ResultCollection` is not registered yet", + "on_predict_batch_start": "ResultCollection` is not registered yet", + "predict_step": "ResultCollection` is not registered yet", + "on_predict_batch_end": "ResultCollection` is not registered yet", + "on_predict_epoch_end": "ResultCollection` is not registered yet", + "on_predict_end": "ResultCollection` is not registered yet", + } + ) + trainer.predict(model) + + @RunIf(min_gpus=2) def test_epoch_results_cache_dp(tmpdir):