diff --git a/CHANGELOG.md b/CHANGELOG.md index ab42984161..6905cf4f35 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -122,6 +122,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - `Trainer.fit` now raises an error when using manual optimization with unsupported features such as `gradient_clip_val` or `accumulate_grad_batches` ([#7788](https://github.com/PyTorchLightning/pytorch-lightning/pull/7788)) +- Accelerator hooks are called regardless if `LightningModule` overrides the same hooks ([#7826](https://github.com/PyTorchLightning/pytorch-lightning/pull/7826)) + + - Moved profilers to their own file ([#7822](https://github.com/PyTorchLightning/pytorch-lightning/pull/7822)) diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index 18aeb6a451..2787ab5644 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -19,6 +19,7 @@ from torch.nn import DataParallel from pytorch_lightning.overrides.data_parallel import LightningParallelModule from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import _METRIC_COLLECTION @@ -101,10 +102,16 @@ class DataParallelPlugin(ParallelPlugin): return self.model(*args, **kwargs) def training_step_end(self, output): - return self.reduce(output) + if not is_overridden("training_step_end", self.lightning_module): + return self.reduce(output) + return output def validation_step_end(self, output): - return self.reduce(output) + if not is_overridden("validation_step_end", self.lightning_module): + return self.reduce(output) + return output def test_step_end(self, output): - return self.reduce(output) + if not is_overridden("test_step_end", self.lightning_module): + return self.reduce(output) + return output diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index ded6e3395e..b9846af644 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1237,11 +1237,14 @@ class Trainer( hook_fx = getattr(model_ref, hook_name) output = hook_fx(*args, **kwargs) - # if the PL module doesn't have the hook then call the accelerator - # used to auto-reduce things for the user with Results obj - elif hasattr(self.accelerator, hook_name): + # call the accelerator hook + if hasattr(self.accelerator, hook_name): accelerator_hook = getattr(self.accelerator, hook_name) - output = accelerator_hook(*args, **kwargs) + accelerator_output = accelerator_hook(*args, **kwargs) + # Rely on the accelerator output if lightningModule hook returns nothing + # Required for cases such as DataParallel where we reduce the output for the user + # todo: move this data parallel logic into the data parallel plugin + output = accelerator_output if output is None else output if not skip: self._cache_logged_metrics() diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 4b48cf0acb..81498cbe3b 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -634,9 +634,8 @@ class TrainLoop: else: model_ref.on_train_epoch_end() - # if the PL module doesn't have the hook then call the accelerator - # used to auto-reduce things for the user with Results obj - elif hasattr(self.trainer.accelerator, hook_name): + # call the accelerator hook + if hasattr(self.trainer.accelerator, hook_name): accelerator_hook = getattr(self.trainer.accelerator, hook_name) accelerator_hook()