Add the `on_before_optimizer_step` hook (#8048)

Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
Dusan Drevicky 2021-07-09 13:30:52 +02:00 committed by GitHub
parent 31fca1658d
commit 1b06edf2f2
15 changed files with 91 additions and 36 deletions


@@ -81,6 +81,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added trainer stage hooks for Training Plugins and Accelerators ([#7864](https://github.com/PyTorchLightning/pytorch-lightning/pull/7864))
- Added the `on_before_optimizer_step` hook ([#8048](https://github.com/PyTorchLightning/pytorch-lightning/pull/8048))
- Added IPU Accelerator ([#7867](https://github.com/PyTorchLightning/pytorch-lightning/pull/7867))
@@ -244,10 +247,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Moved profilers to their own file ([#7822](https://github.com/PyTorchLightning/pytorch-lightning/pull/7822))
- The `on_after_backward` hook is now called on accumulating iterations ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))
- The `on_after_backward` hook is now called on accumulating iterations. Use the `on_before_optimizer_step` hook to mimic the old behaviour ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))
- The mixed precision loss is no longer unscaled before the `on_after_backward` hook ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))
- The mixed precision loss is no longer unscaled before the `on_after_backward` hook. Use the `on_before_optimizer_step` hook to mimic the old behaviour ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))
- The `TrainingTypePlugin.{pre,post}_backward` hooks no longer take the `optimizer, opt_idx, should_accumulate` arguments ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))
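The two entries above suggest moving such logic into the new hook. A minimal before/after sketch, assuming native AMP and a single optimizer (the module and clipping value are illustrative):

import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    # Old approach (before this change): guard against accumulation and unscale manually.
    # def on_after_backward(self):
    #     if self.trainer.fit_loop.should_accumulate():
    #         return
    #     self.trainer.precision_plugin.scaler.unscale_(self.optimizers())
    #     torch.nn.utils.clip_grad_norm_(self.parameters(), 2.0)

    # New approach: the hook runs only on stepping iterations and, under native AMP,
    # sees already-unscaled gradients.
    def on_before_optimizer_step(self, optimizer, optimizer_idx):
        torch.nn.utils.clip_grad_norm_(self.parameters(), 2.0)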


@@ -1195,6 +1195,7 @@ for more information.
backward()
on_after_backward()
on_before_optimizer_step()
optimizer_step()
on_train_batch_end()
@@ -1451,6 +1452,12 @@ on_test_model_train
.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_model_train
:noindex:
on_before_optimizer_step
~~~~~~~~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_before_optimizer_step
:noindex:
optimizer_step
~~~~~~~~~~~~~~


@@ -363,6 +363,12 @@ on_after_backward
.. automethod:: pytorch_lightning.callbacks.Callback.on_after_backward
:noindex:
on_before_optimizer_step
^^^^^^^^^^^^^^^^^^^^^^^^
.. automethod:: pytorch_lightning.callbacks.Callback.on_before_optimizer_step
:noindex:
on_before_zero_grad
^^^^^^^^^^^^^^^^^^^


@@ -302,7 +302,13 @@ class Callback(abc.ABC):
pass
def on_after_backward(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called after ``loss.backward()`` and before optimizers do anything."""
"""Called after ``loss.backward()`` and before optimizers are stepped."""
pass
def on_before_optimizer_step(
self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', optimizer: Optimizer, opt_idx: int
) -> None:
"""Called before ``optimizer.step()``."""
pass
def on_before_zero_grad(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', optimizer: Optimizer) -> None:
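Any user Callback can now override the hook as well; for instance, a small gradient-norm monitor might look like the following sketch (the callback name and metric key are illustrative):

import torch
import pytorch_lightning as pl
from torch.optim import Optimizer


class GradNormMonitor(pl.Callback):
    def on_before_optimizer_step(
        self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', optimizer: Optimizer, opt_idx: int
    ) -> None:
        # gradients are fully accumulated (and unscaled under native AMP) at this point
        norms = [p.grad.norm() for p in pl_module.parameters() if p.grad is not None]
        if norms:
            pl_module.log("total_grad_norm", torch.stack(norms).norm())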


@@ -79,6 +79,7 @@ class LambdaCallback(Callback):
on_load_checkpoint: Optional[Callable] = None,
on_before_backward: Optional[Callable] = None,
on_after_backward: Optional[Callable] = None,
on_before_optimizer_step: Optional[Callable] = None,
on_before_zero_grad: Optional[Callable] = None,
on_predict_start: Optional[Callable] = None,
on_predict_end: Optional[Callable] = None,
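The same hook can also be attached ad hoc through LambdaCallback; a tiny usage sketch (the printed message is illustrative):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LambdaCallback

monitor = LambdaCallback(
    on_before_optimizer_step=lambda trainer, pl_module, optimizer, opt_idx: print(
        f"stepping optimizer {opt_idx} at global step {trainer.global_step}"
    )
)
trainer = Trainer(callbacks=[monitor])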


@@ -306,19 +306,36 @@ class ModelHooks:
def on_after_backward(self) -> None:
"""
Called in the training loop after loss.backward() and before optimizers do anything.
This is the ideal place to inspect or log gradient information.
Called after ``loss.backward()`` and before optimizers are stepped.
Note:
If using native AMP, the gradients will not be unscaled at this point.
Use the ``on_before_optimizer_step`` hook if you need the unscaled gradients.
"""
def on_before_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
"""
Called before ``optimizer.step()``.
The hook is only called if gradients do not need to be accumulated.
See: :paramref:`~pytorch_lightning.trainer.Trainer.accumulate_grad_batches`.
If using native AMP, the gradients will be unscaled before this hook is called.
See these `docs <https://pytorch.org/docs/stable/notes/amp_examples.html#working-with-unscaled-gradients>`__
for more information on the scaling of gradients.
Args:
optimizer: Current optimizer being used.
optimizer_idx: Index of the current optimizer being used.
Example::
def on_after_backward(self):
def on_before_optimizer_step(self, optimizer, optimizer_idx):
# example to inspect gradient information in tensorboard
if self.trainer.global_step % 25 == 0: # don't make the tf file huge
for k, v in self.named_parameters():
self.logger.experiment.add_histogram(
tag=k, values=v.grad, global_step=self.trainer.global_step
)
"""
def on_post_move_to_device(self) -> None:
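The accumulation note above is the main behavioural difference between the two hooks. A small counting sketch (only the relevant hooks are shown; the module is illustrative) makes it concrete: with accumulate_grad_batches=4, on_after_backward fires on every batch while on_before_optimizer_step fires roughly once per four batches.

import pytorch_lightning as pl


class CountingModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.backward_calls = 0
        self.step_hook_calls = 0

    def on_after_backward(self):
        # runs after every loss.backward(), including accumulation batches
        self.backward_calls += 1

    def on_before_optimizer_step(self, optimizer, optimizer_idx):
        # runs only when the optimizer is about to step
        self.step_hook_calls += 1


# with this setting, backward_calls ends up roughly 4x step_hook_calls
trainer = pl.Trainer(accumulate_grad_batches=4)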


@@ -90,17 +90,16 @@ class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
def pre_optimizer_step(
self,
pl_module: 'pl.LightningModule',
model: 'pl.LightningModule',
optimizer: Optimizer,
optimizer_idx: int,
lambda_closure: Callable,
**kwargs: Any,
) -> bool:
"""
always called before the optimizer step.
"""
# apex amp does not support closures.
lambda_closure()
"""Hook to do something before each optimizer step."""
super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
# the following should be in an `optimizer_step` hook but we don't have one in the precision plugin.
lambda_closure() # APEX amp does not support closures
optimizer.step(**kwargs)
return False


@@ -35,15 +35,17 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin):
def pre_optimizer_step(
self,
pl_module: 'pl.LightningModule',
model: 'pl.LightningModule',
optimizer: Optimizer,
optimizer_idx: int,
lambda_closure: Callable,
**kwargs: Any,
) -> bool:
# DeepSpeed not support closures.
lambda_closure()
deepspeed_engine = pl_module.trainer.model
"""Hook to do something before each optimizer step."""
super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
# the following should be in an `optimizer_step` hook but we don't have one in the precision plugin.
lambda_closure() # DeepSpeed does not support closures
deepspeed_engine = model.trainer.model
deepspeed_engine.step()
return False


@@ -47,7 +47,7 @@ class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
def pre_optimizer_step(
self,
pl_module: 'pl.LightningModule',
model: 'pl.LightningModule',
optimizer: Optimizer,
optimizer_idx: int,
lambda_closure: Callable,
@@ -58,16 +58,15 @@ class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
f"native PyTorch amp and lbfgs are not compatible (optimizer {optimizer_idx})."
" To request, please file a Github issue in PyTorch and tag @mcarilli"
)
# TODO: Add `on_before_optimizer_step`
# self.scaler.unscale_(optimizer)
# pl_module.trainer.call_hook("on_before_optimizer_step")
if pl_module.automatic_optimization:
result = True
if model.automatic_optimization:
result = lambda_closure()
if result is None:
# lambda_closure returning None indicates that backward has been skipped
return False
self.scaler.step(optimizer)
self.scaler.update()
self.scaler.unscale_(optimizer)
super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
# lambda_closure returning None indicates that backward has been skipped
if result is not None:
self.scaler.step(optimizer)
self.scaler.update()
return False
@contextmanager


@@ -104,13 +104,14 @@ class PrecisionPlugin(Plugin, CheckpointHooks):
def pre_optimizer_step(
self,
pl_module: 'pl.LightningModule',
model: 'pl.LightningModule',
optimizer: Optimizer,
optimizer_idx: int,
lambda_closure: Callable,
**kwargs: Any,
) -> bool:
"""Hook to do something before each optimizer step."""
model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
return True
def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
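A custom precision plugin that overrides pre_optimizer_step should keep calling super() so the new hook still fires, mirroring the Apex and DeepSpeed plugins above. A rough sketch (the subclass is hypothetical):

from typing import Any, Callable

import pytorch_lightning as pl
from torch.optim import Optimizer
from pytorch_lightning.plugins import PrecisionPlugin


class MyPrecisionPlugin(PrecisionPlugin):
    def pre_optimizer_step(
        self,
        model: 'pl.LightningModule',
        optimizer: Optimizer,
        optimizer_idx: int,
        lambda_closure: Callable,
        **kwargs: Any,
    ) -> bool:
        # the base class fires `on_before_optimizer_step` via the trainer
        super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
        # run the closure and step manually, as the closure-less plugins above do
        lambda_closure()
        optimizer.step(**kwargs)
        return False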


@@ -327,6 +327,13 @@ class TrainerCallbackHookMixin(ABC):
for callback in self.callbacks:
callback.on_after_backward(self, self.lightning_module)
def on_before_optimizer_step(self, optimizer, optimizer_idx):
"""
Called after on_after_backward() once gradients have been accumulated and before optimizer.step().
"""
for callback in self.callbacks:
callback.on_before_optimizer_step(self, self.lightning_module, optimizer, optimizer_idx)
def on_before_zero_grad(self, optimizer):
"""
Called after optimizer.step() and before optimizer.zero_grad().


@@ -23,6 +23,7 @@ class FxValidator:
on_configure_sharded_model=None,
on_before_backward=dict(on_step=(False, True), on_epoch=(False, True)),
on_after_backward=dict(on_step=(False, True), on_epoch=(False, True)),
on_before_optimizer_step=dict(on_step=(False, True), on_epoch=(False, True)),
on_before_zero_grad=dict(on_step=(False, True), on_epoch=(False, True)),
on_init_start=None,
on_init_end=None,
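With this entry, self.log is permitted inside the hook with both on_step and on_epoch aggregation; for example (the metric name is illustrative):

import pytorch_lightning as pl


class LoggingModel(pl.LightningModule):
    def on_before_optimizer_step(self, optimizer, optimizer_idx):
        # both on_step and on_epoch logging are allowed for this hook
        self.log("lr", optimizer.param_groups[0]["lr"], on_step=True, on_epoch=True)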


@@ -299,6 +299,10 @@ class HookedModel(BoringModel):
using_native_amp = kwargs.get('amp_backend') == 'native'
using_deepspeed = kwargs.get('plugins') == 'deepspeed'
out = []
on_before_optimizer_step = [
dict(name='Callback.on_before_optimizer_step', args=(trainer, model, ANY, 0)),
dict(name='on_before_optimizer_step', args=(ANY, 0)),
]
for i in range(batches):
out.extend([
dict(name='on_before_batch_transfer', args=(ANY, 0)),
@@ -308,7 +312,10 @@
dict(name='Callback.on_batch_start', args=(trainer, model)),
dict(name='Callback.on_train_batch_start', args=(trainer, model, ANY, i, 0)),
dict(name='on_train_batch_start', args=(ANY, i, 0)),
# TODO: `on_before_optimizer_step`
# without native amp, these appear before `training_step` because they are not
# part of the `training_step_and_backward` closure; with native amp, the closure
# runs first, so they appear right before the optimizer step instead.
*(on_before_optimizer_step if not using_native_amp else []),
dict(name='forward', args=(ANY, )),
dict(name='training_step', args=(ANY, i)),
dict(name='training_step_end', args=(dict(loss=ANY), )),
@@ -321,6 +328,7 @@
*([dict(name='backward', args=(ANY, ANY, 0))] if not using_deepspeed else []),
dict(name='Callback.on_after_backward', args=(trainer, model)),
dict(name='on_after_backward'),
*(on_before_optimizer_step if using_native_amp else []),
dict(
name='optimizer_step',
args=(current_epoch, i, ANY, 0, ANY),
@@ -354,7 +362,8 @@
dict(name='on_after_backward'),
# `manual_backward` calls the previous 3
dict(name='manual_backward', args=(ANY, )),
# TODO: `on_before_optimizer_step`
dict(name='Callback.on_before_optimizer_step', args=(trainer, model, ANY, 0)),
dict(name='on_before_optimizer_step', args=(ANY, 0)),
dict(name='training_step', args=(ANY, i)),
dict(name='training_step_end', args=(dict(loss=ANY), )),
dict(name='Callback.on_train_batch_end', args=(trainer, model, dict(loss=ANY), ANY, i, 0)),


@@ -71,12 +71,7 @@ def test_amp_apex_ddp(
class GradientUnscaleBoringModel(BoringModel):
def on_after_backward(self):
# TODO: replace with `on_before_optimizer_step` so we don't need to check accumulate and unscale manually
if self.trainer.fit_loop.should_accumulate():
return
opt = self.optimizers()
self.trainer.precision_plugin.scaler.unscale_(opt)
def on_before_optimizer_step(self, *_):
norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2)
if not (torch.isinf(norm) or torch.isnan(norm)):
assert norm.item() < 15.


@@ -34,6 +34,7 @@ def test_fx_validator(tmpdir):
callbacks_func = [
'on_before_backward',
'on_after_backward',
'on_before_optimizer_step',
'on_batch_end',
'on_batch_start',
'on_before_accelerator_backend_setup',
@@ -124,6 +125,7 @@ test_fx_validator(tmpdir):
# creating allowed condition
allowed = (
is_stage or "batch" in func_name or "epoch" in func_name or "grad" in func_name or "backward" in func_name
or "optimizer_step" in func_name
)
allowed = (
allowed and "pretrain" not in func_name and "predict" not in func_name