Add the `on_before_optimizer_step` hook (#8048)
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
parent 31fca1658d
commit 1b06edf2f2
@@ -81,6 +81,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added trainer stage hooks for Training Plugins and Accelerators ([#7864](https://github.com/PyTorchLightning/pytorch-lightning/pull/7864))
 
+- Added the `on_before_optimizer_step` hook ([#8048](https://github.com/PyTorchLightning/pytorch-lightning/pull/8048))
+
 - Added IPU Accelerator ([#7867](https://github.com/PyTorchLightning/pytorch-lightning/pull/7867))
@@ -244,10 +247,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Moved profilers to their own file ([#7822](https://github.com/PyTorchLightning/pytorch-lightning/pull/7822))
 
-- The `on_after_backward` hook is now called on accumulating iterations ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))
+- The `on_after_backward` hook is now called on accumulating iterations. Use the `on_before_optimizer_step` hook to mimic the old behaviour ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))
 
-- The mixed precision loss is no longer unscaled before the `on_after_backward` hook ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))
+- The mixed precision loss is no longer unscaled before the `on_after_backward` hook. Use the `on_before_optimizer_step` hook to mimic the old behaviour ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))
 
 - The `TrainingTypePlugin.{pre,post}_backward` hooks no longer take the `optimizer, opt_idx, should_accumulate` arguments ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))
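A minimal sketch of the migration these two entries describe, assuming a user module that previously inspected gradients in ``on_after_backward`` (the module class and logged key are illustrative, not part of this commit)::

    import torch
    import pytorch_lightning as pl


    class MyModule(pl.LightningModule):  # hypothetical module, task-specific parts omitted
        # Old location: `on_after_backward` now runs on every backward call and,
        # with native AMP, still sees scaled gradients.
        # New location: runs once per optimizer step, after accumulation, with
        # unscaled gradients when using native AMP.
        def on_before_optimizer_step(self, optimizer, optimizer_idx):
            norm = torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=2.0)
            self.log("grad_norm", norm)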
@@ -1195,6 +1195,7 @@ for more information.
 backward()
 on_after_backward()
+on_before_optimizer_step()
 optimizer_step()
 on_train_batch_end()
@@ -1451,6 +1452,12 @@ on_test_model_train
 .. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_model_train
     :noindex:
 
+on_before_optimizer_step
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_before_optimizer_step
+    :noindex:
+
 optimizer_step
 ~~~~~~~~~~~~~~
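To observe the ordering documented above in practice, a small self-contained sketch (the model, data, and printed labels are illustrative, not from this commit)::

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import pytorch_lightning as pl


    class HookOrderModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.mse_loss(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

        def on_after_backward(self):
            print("on_after_backward")

        def on_before_optimizer_step(self, optimizer, optimizer_idx):
            # printed after `on_after_backward` and before `optimizer_step`, per the pseudocode above
            print("on_before_optimizer_step", optimizer_idx)


    if __name__ == "__main__":
        data = DataLoader(TensorDataset(torch.randn(64, 32), torch.randn(64, 1)), batch_size=16)
        trainer = pl.Trainer(max_epochs=1, progress_bar_refresh_rate=0)
        trainer.fit(HookOrderModel(), data)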
@@ -363,6 +363,12 @@ on_after_backward
 .. automethod:: pytorch_lightning.callbacks.Callback.on_after_backward
     :noindex:
 
+on_before_optimizer_step
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. automethod:: pytorch_lightning.callbacks.Callback.on_before_optimizer_step
+    :noindex:
+
 on_before_zero_grad
 ^^^^^^^^^^^^^^^^^^^
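A hedged sketch of a callback using the new hook documented above (the callback class and logged key are illustrative)::

    import torch
    import pytorch_lightning as pl


    class GradNormCallback(pl.Callback):
        """Illustrative callback: inspect gradients right before the optimizer steps."""

        def on_before_optimizer_step(self, trainer, pl_module, optimizer, opt_idx):
            grads = [p.grad.norm() for p in pl_module.parameters() if p.grad is not None]
            total_norm = torch.stack(grads).norm() if grads else torch.tensor(0.0)
            pl_module.log("total_grad_norm", total_norm)


    # trainer = pl.Trainer(callbacks=[GradNormCallback()])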
@@ -302,7 +302,13 @@ class Callback(abc.ABC):
         pass
 
     def on_after_backward(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
-        """Called after ``loss.backward()`` and before optimizers do anything."""
+        """Called after ``loss.backward()`` and before optimizers are stepped."""
        pass
 
+    def on_before_optimizer_step(
+        self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', optimizer: Optimizer, opt_idx: int
+    ) -> None:
+        """Called before ``optimizer.step()``."""
+        pass
+
     def on_before_zero_grad(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', optimizer: Optimizer) -> None:
@@ -79,6 +79,7 @@ class LambdaCallback(Callback):
         on_load_checkpoint: Optional[Callable] = None,
         on_before_backward: Optional[Callable] = None,
         on_after_backward: Optional[Callable] = None,
+        on_before_optimizer_step: Optional[Callable] = None,
         on_before_zero_grad: Optional[Callable] = None,
         on_predict_start: Optional[Callable] = None,
         on_predict_end: Optional[Callable] = None,
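With the new ``on_before_optimizer_step`` argument, a quick ``LambdaCallback`` sketch; the callable receives the same arguments as ``Callback.on_before_optimizer_step`` (the printed message is illustrative)::

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import LambdaCallback

    log_step = LambdaCallback(
        on_before_optimizer_step=lambda trainer, pl_module, optimizer, opt_idx: print(
            f"global_step={trainer.global_step}: stepping optimizer {opt_idx}"
        )
    )
    # trainer = pl.Trainer(callbacks=[log_step])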
@@ -306,19 +306,36 @@ class ModelHooks:
     def on_after_backward(self) -> None:
         """
-        Called in the training loop after loss.backward() and before optimizers do anything.
-        This is the ideal place to inspect or log gradient information.
+        Called after ``loss.backward()`` and before optimizers are stepped.
+
+        Note:
+            If using native AMP, the gradients will not be unscaled at this point.
+            Use the ``on_before_optimizer_step`` if you need the unscaled gradients.
+        """
+
+    def on_before_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
+        """
+        Called before ``optimizer.step()``.
+
+        The hook is only called if gradients do not need to be accumulated.
+        See: :paramref:`~pytorch_lightning.trainer.Trainer.accumulate_grad_batches`.
+        If using native AMP, the loss will be unscaled before calling this hook.
+        See these `docs <https://pytorch.org/docs/stable/notes/amp_examples.html#working-with-unscaled-gradients>`__
+        for more information on the scaling of gradients.
+
+        Args:
+            optimizer: Current optimizer being used.
+            optimizer_idx: Index of the current optimizer being used.
 
         Example::
 
-            def on_after_backward(self):
+            def on_before_optimizer_step(self, optimizer, optimizer_idx):
                 # example to inspect gradient information in tensorboard
                 if self.trainer.global_step % 25 == 0:  # don't make the tf file huge
                     for k, v in self.named_parameters():
                         self.logger.experiment.add_histogram(
                             tag=k, values=v.grad, global_step=self.trainer.global_step
                         )
 
         """
 
     def on_post_move_to_device(self) -> None:
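As the new docstring notes, the hook fires once per actual optimizer step rather than once per backward call, so it diverges from ``on_after_backward`` under gradient accumulation (the setting below is illustrative)::

    import pytorch_lightning as pl

    # With 4-batch accumulation, `on_after_backward` runs every batch while
    # `on_before_optimizer_step` runs only on every 4th batch, right before `optimizer.step()`.
    trainer = pl.Trainer(accumulate_grad_batches=4)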
@@ -90,17 +90,16 @@ class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
     def pre_optimizer_step(
         self,
-        pl_module: 'pl.LightningModule',
+        model: 'pl.LightningModule',
         optimizer: Optimizer,
         optimizer_idx: int,
         lambda_closure: Callable,
         **kwargs: Any,
     ) -> bool:
-        """
-        always called before the optimizer step.
-        """
-        # apex amp does not support closures.
-        lambda_closure()
+        """Hook to do something before each optimizer step."""
+        super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
+        # the following should be in a `optimizer_step` hook but we don't have one in the precision plugin.
+        lambda_closure()  # APEX amp does not support closures
         optimizer.step(**kwargs)
         return False
@@ -35,15 +35,17 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin):
     def pre_optimizer_step(
         self,
-        pl_module: 'pl.LightningModule',
+        model: 'pl.LightningModule',
         optimizer: Optimizer,
         optimizer_idx: int,
         lambda_closure: Callable,
         **kwargs: Any,
     ) -> bool:
-        # DeepSpeed not support closures.
-        lambda_closure()
-        deepspeed_engine = pl_module.trainer.model
+        """Hook to do something before each optimizer step."""
+        super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
+        # the following should be in a `optimizer_step` hook but we don't have one in the precision plugin.
+        lambda_closure()  # DeepSpeed does not support closures
+        deepspeed_engine = model.trainer.model
         deepspeed_engine.step()
         return False
@@ -47,7 +47,7 @@ class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
     def pre_optimizer_step(
         self,
-        pl_module: 'pl.LightningModule',
+        model: 'pl.LightningModule',
         optimizer: Optimizer,
         optimizer_idx: int,
         lambda_closure: Callable,
@@ -58,16 +58,15 @@ class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
                 f"native PyTorch amp and lbfgs are not compatible (optimizer {optimizer_idx})."
                 " To request, please file a Github issue in PyTorch and tag @mcarilli"
             )
-        # TODO: Add `on_before_optimizer_step`
-        # self.scaler.unscale_(optimizer)
-        # pl_module.trainer.call_hook("on_before_optimizer_step")
-        if pl_module.automatic_optimization:
+        result = True
+        if model.automatic_optimization:
             result = lambda_closure()
-            if result is None:
-                # lambda_closure returning None indicates that backward has been skipped
-                return False
-        self.scaler.step(optimizer)
-        self.scaler.update()
+        self.scaler.unscale_(optimizer)
+        super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
+        # lambda_closure returning None indicates that backward has been skipped
+        if result is not None:
+            self.scaler.step(optimizer)
+            self.scaler.update()
         return False
 
     @contextmanager
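Since the plugin now calls ``scaler.unscale_(optimizer)`` before dispatching the hook, gradient inspection or clipping against real (unscaled) magnitudes can live in ``on_before_optimizer_step`` when training with native AMP. A hedged sketch (the class name and threshold are illustrative; training-specific methods are omitted)::

    import torch
    import pytorch_lightning as pl


    class ClipGradsModule(pl.LightningModule):  # illustrative; training_step etc. omitted
        def on_before_optimizer_step(self, optimizer, optimizer_idx):
            # with `precision=16` and the native AMP backend, gradients are already unscaled here
            torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=2.0)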
@@ -104,13 +104,14 @@ class PrecisionPlugin(Plugin, CheckpointHooks):
     def pre_optimizer_step(
         self,
-        pl_module: 'pl.LightningModule',
+        model: 'pl.LightningModule',
         optimizer: Optimizer,
         optimizer_idx: int,
         lambda_closure: Callable,
         **kwargs: Any,
     ) -> bool:
         """Hook to do something before each optimizer step."""
+        model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
         return True
 
     def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
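The base implementation only dispatches the hook and returns ``True``; the Apex and DeepSpeed overrides above call ``super().pre_optimizer_step(...)`` so the hook still runs. A hedged sketch of a custom plugin following the same pattern (the plugin class itself is hypothetical)::

    from typing import Any, Callable

    import pytorch_lightning as pl
    from pytorch_lightning.plugins import PrecisionPlugin
    from torch.optim import Optimizer


    class LoggingPrecisionPlugin(PrecisionPlugin):  # hypothetical plugin for illustration
        def pre_optimizer_step(
            self,
            model: 'pl.LightningModule',
            optimizer: Optimizer,
            optimizer_idx: int,
            lambda_closure: Callable,
            **kwargs: Any,
        ) -> bool:
            # dispatches `on_before_optimizer_step` to the LightningModule and callbacks
            super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
            print(f"about to step optimizer {optimizer_idx}")
            return True  # let the training loop run the closure and `optimizer.step()` as usual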
@@ -327,6 +327,13 @@ class TrainerCallbackHookMixin(ABC):
         for callback in self.callbacks:
             callback.on_after_backward(self, self.lightning_module)
 
+    def on_before_optimizer_step(self, optimizer, optimizer_idx):
+        """
+        Called after on_after_backward() once the gradient is accumulated and before optimizer.step().
+        """
+        for callback in self.callbacks:
+            callback.on_before_optimizer_step(self, self.lightning_module, optimizer, optimizer_idx)
+
     def on_before_zero_grad(self, optimizer):
         """
         Called after optimizer.step() and before optimizer.zero_grad().
@@ -23,6 +23,7 @@ class FxValidator:
         on_configure_sharded_model=None,
         on_before_backward=dict(on_step=(False, True), on_epoch=(False, True)),
         on_after_backward=dict(on_step=(False, True), on_epoch=(False, True)),
+        on_before_optimizer_step=dict(on_step=(False, True), on_epoch=(False, True)),
         on_before_zero_grad=dict(on_step=(False, True), on_epoch=(False, True)),
         on_init_start=None,
         on_init_end=None,
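Because the hook is now whitelisted with both ``on_step`` and ``on_epoch`` options, ``self.log`` can be called from it. A short sketch (the metric name is illustrative; other module methods are omitted)::

    import pytorch_lightning as pl


    class MyModule(pl.LightningModule):  # illustrative; training_step etc. omitted
        def on_before_optimizer_step(self, optimizer, optimizer_idx):
            # both reductions are permitted by the validator entry above
            self.log("lr_before_step", optimizer.param_groups[0]["lr"], on_step=True, on_epoch=False)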
@@ -299,6 +299,10 @@ class HookedModel(BoringModel):
         using_native_amp = kwargs.get('amp_backend') == 'native'
         using_deepspeed = kwargs.get('plugins') == 'deepspeed'
         out = []
+        on_before_optimizer_step = [
+            dict(name='Callback.on_before_optimizer_step', args=(trainer, model, ANY, 0)),
+            dict(name='on_before_optimizer_step', args=(ANY, 0)),
+        ]
         for i in range(batches):
             out.extend([
                 dict(name='on_before_batch_transfer', args=(ANY, 0)),
@@ -308,7 +312,10 @@ class HookedModel(BoringModel):
                 dict(name='Callback.on_batch_start', args=(trainer, model)),
                 dict(name='Callback.on_train_batch_start', args=(trainer, model, ANY, i, 0)),
                 dict(name='on_train_batch_start', args=(ANY, i, 0)),
-                # TODO: `on_before_optimizer_step`
+                # these are before the training step because
+                # they are not part of the `training_step_and_backward` closure, however,
+                # with native amp, the closure is run first and then the optimizer step.
+                *(on_before_optimizer_step if not using_native_amp else []),
                 dict(name='forward', args=(ANY, )),
                 dict(name='training_step', args=(ANY, i)),
                 dict(name='training_step_end', args=(dict(loss=ANY), )),
@@ -321,6 +328,7 @@ class HookedModel(BoringModel):
                 *([dict(name='backward', args=(ANY, ANY, 0))] if not using_deepspeed else []),
                 dict(name='Callback.on_after_backward', args=(trainer, model)),
                 dict(name='on_after_backward'),
+                *(on_before_optimizer_step if using_native_amp else []),
                 dict(
                     name='optimizer_step',
                     args=(current_epoch, i, ANY, 0, ANY),
@@ -354,7 +362,8 @@ class HookedModel(BoringModel):
                 dict(name='on_after_backward'),
                 # `manual_backward` calls the previous 3
                 dict(name='manual_backward', args=(ANY, )),
-                # TODO: `on_before_optimizer_step`
+                dict(name='Callback.on_before_optimizer_step', args=(trainer, model, ANY, 0)),
+                dict(name='on_before_optimizer_step', args=(ANY, 0)),
                 dict(name='training_step', args=(ANY, i)),
                 dict(name='training_step_end', args=(dict(loss=ANY), )),
                 dict(name='Callback.on_train_batch_end', args=(trainer, model, dict(loss=ANY), ANY, i, 0)),
@@ -71,12 +71,7 @@ def test_amp_apex_ddp(
 class GradientUnscaleBoringModel(BoringModel):
 
-    def on_after_backward(self):
-        # TODO: replace with `on_before_optimizer_step` so we don't need to check accumulate and unscale manually
-        if self.trainer.fit_loop.should_accumulate():
-            return
-        opt = self.optimizers()
-        self.trainer.precision_plugin.scaler.unscale_(opt)
+    def on_before_optimizer_step(self, *_):
         norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2)
         if not (torch.isinf(norm) or torch.isnan(norm)):
             assert norm.item() < 15.
@@ -34,6 +34,7 @@ def test_fx_validator(tmpdir):
     callbacks_func = [
         'on_before_backward',
         'on_after_backward',
+        'on_before_optimizer_step',
        'on_batch_end',
         'on_batch_start',
         'on_before_accelerator_backend_setup',
@@ -124,6 +125,7 @@ def test_fx_validator(tmpdir):
     # creating allowed condition
     allowed = (
         is_stage or "batch" in func_name or "epoch" in func_name or "grad" in func_name or "backward" in func_name
+        or "optimizer_step" in func_name
     )
     allowed = (
         allowed and "pretrain" not in func_name and "predict" not in func_name