[IPU] Call accelerator hooks regardless if LM hook overridden 1/n (#7826)

* Modify API to ensure hooks defined in the accelerator are called as expected

* handle step_end in dp

* Add changelog

* Update pytorch_lightning/trainer/trainer.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Add todo and explanation

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Sean Naren 2021-06-04 17:19:08 +01:00 committed by GitHub
parent 51d370f4c2
commit 7c7182d334
4 changed files with 22 additions and 10 deletions

CHANGELOG.md View File

@@ -122,6 +122,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - `Trainer.fit` now raises an error when using manual optimization with unsupported features such as `gradient_clip_val` or `accumulate_grad_batches` ([#7788](https://github.com/PyTorchLightning/pytorch-lightning/pull/7788))
+- Accelerator hooks are called regardless if `LightningModule` overrides the same hooks ([#7826](https://github.com/PyTorchLightning/pytorch-lightning/pull/7826))
 - Moved profilers to their own file ([#7822](https://github.com/PyTorchLightning/pytorch-lightning/pull/7822))

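To make the behavioral contract concrete, here is a minimal, hypothetical user module for the 1.3-era DP accelerator (`Trainer(gpus=2, accelerator="dp")`); the class name, layer sizes, and optimizer are illustrative only, not code from this PR. Because `training_step_end` is overridden, the DP plugin skips its automatic reduction and hands the per-device outputs to the override, while accelerator-level hooks still run, which is the point of this change for IPU support.

```python
import torch
from torch import nn
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        # under DP each replica returns its own scalar loss
        return self.layer(batch).sum()

    def training_step_end(self, step_outputs):
        # because this hook is overridden, the DP plugin no longer reduces
        # automatically; the per-device losses arrive here and are reduced
        # by user code instead
        return step_outputs.mean()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# e.g. pl.Trainer(gpus=2, accelerator="dp").fit(LitModel(), train_dataloader)
```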
pytorch_lightning/plugins/training_type/dp.py View File

@@ -19,6 +19,7 @@ from torch.nn import DataParallel
 from pytorch_lightning.overrides.data_parallel import LightningParallelModule
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
 from pytorch_lightning.utilities.apply_func import apply_to_collection
+from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import _METRIC_COLLECTION
@@ -101,10 +102,16 @@ class DataParallelPlugin(ParallelPlugin):
         return self.model(*args, **kwargs)

     def training_step_end(self, output):
-        return self.reduce(output)
+        if not is_overridden("training_step_end", self.lightning_module):
+            return self.reduce(output)
+        return output

     def validation_step_end(self, output):
-        return self.reduce(output)
+        if not is_overridden("validation_step_end", self.lightning_module):
+            return self.reduce(output)
+        return output

     def test_step_end(self, output):
-        return self.reduce(output)
+        if not is_overridden("test_step_end", self.lightning_module):
+            return self.reduce(output)
+        return output

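The new guards hinge on `is_overridden`, imported above from `pytorch_lightning.utilities.model_helpers`. As a rough sketch of the idea rather than the library's actual implementation: a hook counts as overridden when the module's class no longer uses the implementation inherited from the base `LightningModule`.

```python
import pytorch_lightning as pl


def user_overrode_hook(method_name: str, module: pl.LightningModule) -> bool:
    # the hook is "overridden" when the module's class provides its own
    # implementation instead of the one inherited from LightningModule
    base_impl = getattr(pl.LightningModule, method_name)
    module_impl = getattr(type(module), method_name, base_impl)
    return module_impl is not base_impl
```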
pytorch_lightning/trainer/trainer.py View File

@@ -1237,11 +1237,14 @@ class Trainer(
                 hook_fx = getattr(model_ref, hook_name)
                 output = hook_fx(*args, **kwargs)

-            # if the PL module doesn't have the hook then call the accelerator
-            # used to auto-reduce things for the user with Results obj
-            elif hasattr(self.accelerator, hook_name):
+            # call the accelerator hook
+            if hasattr(self.accelerator, hook_name):
                 accelerator_hook = getattr(self.accelerator, hook_name)
-                output = accelerator_hook(*args, **kwargs)
+                accelerator_output = accelerator_hook(*args, **kwargs)
+                # Rely on the accelerator output if lightningModule hook returns nothing
+                # Required for cases such as DataParallel where we reduce the output for the user
+                # todo: move this data parallel logic into the data parallel plugin
+                output = accelerator_output if output is None else output

         if not skip:
             self._cache_logged_metrics()

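The heart of the trainer change is the precedence rule on the last added line: the `LightningModule` hook result wins, and the accelerator result is used only as a fallback when the module hook returned `None`. The helper below is a hypothetical standalone restatement of that rule, not code from the repository.

```python
from typing import Any, Callable, Optional


def resolve_hook_output(
    module_hook: Optional[Callable[..., Any]],
    accelerator_hook: Optional[Callable[..., Any]],
    *args: Any,
    **kwargs: Any,
) -> Any:
    # the LightningModule hook runs first, when it exists ...
    output = module_hook(*args, **kwargs) if module_hook is not None else None
    # ... the accelerator hook always runs afterwards ...
    if accelerator_hook is not None:
        accelerator_output = accelerator_hook(*args, **kwargs)
        # ... and its result is only kept when the module produced nothing
        output = accelerator_output if output is None else output
    return output
```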
pytorch_lightning/trainer/training_loop.py View File

@@ -634,9 +634,8 @@ class TrainLoop:
                 else:
                     model_ref.on_train_epoch_end()

-            # if the PL module doesn't have the hook then call the accelerator
-            # used to auto-reduce things for the user with Results obj
-            elif hasattr(self.trainer.accelerator, hook_name):
+            # call the accelerator hook
+            if hasattr(self.trainer.accelerator, hook_name):
                 accelerator_hook = getattr(self.trainer.accelerator, hook_name)
                 accelerator_hook()
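Both this hunk and the trainer hunk use the same `hasattr`/`getattr` dispatch: look a hook up by name and call it only if the target defines it. In isolation, the pattern looks like the following hypothetical helper.

```python
from typing import Any


def call_hook_if_present(target: Any, hook_name: str, *args: Any, **kwargs: Any) -> Any:
    # mirror of the pattern above: resolve the hook by name at runtime and
    # invoke it only when the target actually defines it
    if hasattr(target, hook_name):
        hook = getattr(target, hook_name)
        return hook(*args, **kwargs)
    return None


# e.g. call_hook_if_present(trainer.accelerator, "on_train_epoch_end")
```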