[IPU] Call accelerator hooks regardless if LM hook overridden 1/n (#7826)
* Modify API to ensure hooks defined in the accelerator are called as expected
* handle step_end in dp
* Add changelog
* Update pytorch_lightning/trainer/trainer.py
* Add todo and explanation

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
parent 51d370f4c2 · commit 7c7182d334
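To make the behaviour change concrete, here is a minimal standalone sketch, not PyTorch Lightning's actual implementation: `DemoModule`, `DemoAccelerator`, and this simplified `call_hook` are hypothetical stand-ins that only mirror the dispatch order introduced here (both hooks are called, and the accelerator's return value is used only when the module hook returns nothing).

# Minimal sketch of the new hook dispatch. DemoModule / DemoAccelerator and
# this call_hook are hypothetical stand-ins, not PyTorch Lightning classes.


class DemoModule:
    def training_step_end(self, output):
        # user-overridden hook that returns nothing -> the accelerator's
        # (reduced) result should still be produced and used
        print("LightningModule hook called")
        return None


class DemoAccelerator:
    def training_step_end(self, output):
        print("accelerator hook called")
        return sum(output) / len(output)  # e.g. reduce per-device outputs


def call_hook(module, accelerator, hook_name, *args):
    output = None
    if hasattr(module, hook_name):
        output = getattr(module, hook_name)(*args)
    # new behaviour: plain `if`, not `elif` -- the accelerator hook always runs
    if hasattr(accelerator, hook_name):
        accelerator_output = getattr(accelerator, hook_name)(*args)
        # rely on the accelerator output only if the module hook returned nothing
        output = accelerator_output if output is None else output
    return output


print(call_hook(DemoModule(), DemoAccelerator(), "training_step_end", [1.0, 3.0]))
# both hooks are called, and the reduced value 2.0 is returned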
CHANGELOG.md

@@ -122,6 +122,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - `Trainer.fit` now raises an error when using manual optimization with unsupported features such as `gradient_clip_val` or `accumulate_grad_batches` ([#7788](https://github.com/PyTorchLightning/pytorch-lightning/pull/7788))

+- Accelerator hooks are called regardless if `LightningModule` overrides the same hooks ([#7826](https://github.com/PyTorchLightning/pytorch-lightning/pull/7826))
+
 - Moved profilers to their own file ([#7822](https://github.com/PyTorchLightning/pytorch-lightning/pull/7822))
pytorch_lightning/plugins/training_type/dp.py

@@ -19,6 +19,7 @@ from torch.nn import DataParallel

 from pytorch_lightning.overrides.data_parallel import LightningParallelModule
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
 from pytorch_lightning.utilities.apply_func import apply_to_collection
+from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import _METRIC_COLLECTION

@@ -101,10 +102,16 @@ class DataParallelPlugin(ParallelPlugin):
         return self.model(*args, **kwargs)

     def training_step_end(self, output):
-        return self.reduce(output)
+        if not is_overridden("training_step_end", self.lightning_module):
+            return self.reduce(output)
+        return output

     def validation_step_end(self, output):
-        return self.reduce(output)
+        if not is_overridden("validation_step_end", self.lightning_module):
+            return self.reduce(output)
+        return output

     def test_step_end(self, output):
-        return self.reduce(output)
+        if not is_overridden("test_step_end", self.lightning_module):
+            return self.reduce(output)
+        return output
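To illustrate the effect of the `is_overridden` guard above, here is a hedged standalone sketch; `BaseModule`, `UserModule`, the simplified `is_overridden`, and the list-mean reduce are all hypothetical stand-ins, not PL's actual classes. The point is that the plugin auto-reduces the per-device outputs only when the user has not supplied their own `training_step_end`.

# Sketch of the is_overridden-style guard: auto-reduce only when the user has
# NOT overridden training_step_end. All classes here are simplified stand-ins.


class BaseModule:
    def training_step_end(self, output):
        return output


class UserModule(BaseModule):
    def training_step_end(self, output):
        # user takes over post-processing, so the plugin must not reduce
        return max(output)


def is_overridden(method_name, instance, parent=BaseModule):
    return getattr(type(instance), method_name) is not getattr(parent, method_name)


def plugin_training_step_end(module, output):
    if not is_overridden("training_step_end", module):
        return sum(output) / len(output)  # plugin reduces for the user
    return output  # hand the raw per-device outputs to the user's hook


print(plugin_training_step_end(BaseModule(), [1.0, 3.0]))  # 2.0 (auto-reduced)
print(plugin_training_step_end(UserModule(), [1.0, 3.0]))  # [1.0, 3.0] (left as-is)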
pytorch_lightning/trainer/trainer.py

@@ -1237,11 +1237,14 @@ class Trainer(
                 hook_fx = getattr(model_ref, hook_name)
                 output = hook_fx(*args, **kwargs)

-            # if the PL module doesn't have the hook then call the accelerator
-            # used to auto-reduce things for the user with Results obj
-            elif hasattr(self.accelerator, hook_name):
+            # call the accelerator hook
+            if hasattr(self.accelerator, hook_name):
                 accelerator_hook = getattr(self.accelerator, hook_name)
-                output = accelerator_hook(*args, **kwargs)
+                accelerator_output = accelerator_hook(*args, **kwargs)
+                # Rely on the accelerator output if lightningModule hook returns nothing
+                # Required for cases such as DataParallel where we reduce the output for the user
+                # todo: move this data parallel logic into the data parallel plugin
+                output = accelerator_output if output is None else output

         if not skip:
             self._cache_logged_metrics()
pytorch_lightning/trainer/training_loop.py

@@ -634,9 +634,8 @@ class TrainLoop:
                 else:
                     model_ref.on_train_epoch_end()

-            # if the PL module doesn't have the hook then call the accelerator
-            # used to auto-reduce things for the user with Results obj
-            elif hasattr(self.trainer.accelerator, hook_name):
+            # call the accelerator hook
+            if hasattr(self.trainer.accelerator, hook_name):
                 accelerator_hook = getattr(self.trainer.accelerator, hook_name)
                 accelerator_hook()
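The same pattern applies in the epoch-end path: the old `elif` tied the accelerator hook to the module not overriding the hook, so a user-defined `on_train_epoch_end` silently skipped the accelerator's version. A small hypothetical sketch of the old versus new dispatch (`DemoRef` and `DemoAccel` are stand-ins, not PL classes):

# Old vs. new epoch-end dispatch, simplified. DemoRef / DemoAccel are
# hypothetical stand-ins used only to show why `elif` hid the accelerator hook.


class DemoRef:
    def on_train_epoch_end(self):
        print("module epoch-end hook")


class DemoAccel:
    def on_train_epoch_end(self):
        print("accelerator epoch-end hook")


def old_dispatch(model_ref, accelerator, hook_name):
    if hasattr(model_ref, hook_name):
        getattr(model_ref, hook_name)()
    elif hasattr(accelerator, hook_name):  # skipped whenever the module has the hook
        getattr(accelerator, hook_name)()


def new_dispatch(model_ref, accelerator, hook_name):
    if hasattr(model_ref, hook_name):
        getattr(model_ref, hook_name)()
    if hasattr(accelerator, hook_name):  # always called now
        getattr(accelerator, hook_name)()


old_dispatch(DemoRef(), DemoAccel(), "on_train_epoch_end")  # only the module hook runs
new_dispatch(DemoRef(), DemoAccel(), "on_train_epoch_end")  # both hooks run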