diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1d8dc8647b..e77df63712 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -81,6 +81,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added trainer stage hooks for Training Plugins and Accelerators ([#7864](https://github.com/PyTorchLightning/pytorch-lightning/pull/7864))
 
 
+- Added the `on_before_optimizer_step` hook ([#8048](https://github.com/PyTorchLightning/pytorch-lightning/pull/8048))
+
+
 - Added IPU Accelerator ([#7867](https://github.com/PyTorchLightning/pytorch-lightning/pull/7867))
 
 
@@ -244,10 +247,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Moved profilers to their own file ([#7822](https://github.com/PyTorchLightning/pytorch-lightning/pull/7822))
 
 
-- The `on_after_backward` hook is now called on accumulating iterations ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))
+- The `on_after_backward` hook is now called on accumulating iterations. Use the `on_before_optimizer_step` hook to mimic the old behaviour ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))
 
 
-- The mixed precision loss is no longer unscaled before the `on_after_backward` hook ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))
+- The mixed precision loss is no longer unscaled before the `on_after_backward` hook. Use the `on_before_optimizer_step` hook to mimic the old behaviour ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))
 
 
 - The `TrainingTypePlugin.{pre,post}_backward` hooks no longer take the `optimizer, opt_idx, should_accumulate` arguments ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))
diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst
index 40c0ef92a8..84ffb7cec8 100644
--- a/docs/source/common/lightning_module.rst
+++ b/docs/source/common/lightning_module.rst
@@ -1195,6 +1195,7 @@ for more information.
 
             backward()
             on_after_backward()
+            on_before_optimizer_step()
             optimizer_step()
 
         on_train_batch_end()
@@ -1451,6 +1452,12 @@ on_test_model_train
 .. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_model_train
     :noindex:
 
+on_before_optimizer_step
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_before_optimizer_step
+    :noindex:
+
 optimizer_step
 ~~~~~~~~~~~~~~
 
diff --git a/docs/source/extensions/callbacks.rst b/docs/source/extensions/callbacks.rst
index a905958aac..88527f1177 100644
--- a/docs/source/extensions/callbacks.rst
+++ b/docs/source/extensions/callbacks.rst
@@ -363,6 +363,12 @@ on_after_backward
 .. automethod:: pytorch_lightning.callbacks.Callback.on_after_backward
     :noindex:
 
+on_before_optimizer_step
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. automethod:: pytorch_lightning.callbacks.Callback.on_before_optimizer_step
+    :noindex:
+
 on_before_zero_grad
 ^^^^^^^^^^^^^^^^^^^
 
diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index 5f3fc5b8d0..1db7b20b92 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -302,7 +302,13 @@ class Callback(abc.ABC):
         pass
 
     def on_after_backward(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
-        """Called after ``loss.backward()`` and before optimizers do anything."""
+        """Called after ``loss.backward()`` and before optimizers are stepped."""
+        pass
+
+    def on_before_optimizer_step(
+        self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', optimizer: Optimizer, opt_idx: int
+    ) -> None:
+        """Called before ``optimizer.step()``."""
         pass
 
     def on_before_zero_grad(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', optimizer: Optimizer) -> None:
diff --git a/pytorch_lightning/callbacks/lambda_function.py b/pytorch_lightning/callbacks/lambda_function.py
index e10274bd59..ca9af484db 100644
--- a/pytorch_lightning/callbacks/lambda_function.py
+++ b/pytorch_lightning/callbacks/lambda_function.py
@@ -79,6 +79,7 @@ class LambdaCallback(Callback):
         on_load_checkpoint: Optional[Callable] = None,
         on_before_backward: Optional[Callable] = None,
         on_after_backward: Optional[Callable] = None,
+        on_before_optimizer_step: Optional[Callable] = None,
         on_before_zero_grad: Optional[Callable] = None,
         on_predict_start: Optional[Callable] = None,
         on_predict_end: Optional[Callable] = None,
diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index c99350879a..a30f699c70 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -306,19 +306,36 @@ class ModelHooks:
 
     def on_after_backward(self) -> None:
         """
-        Called in the training loop after loss.backward() and before optimizers do anything.
-        This is the ideal place to inspect or log gradient information.
+        Called after ``loss.backward()`` and before optimizers are stepped.
+
+        Note:
+            If using native AMP, the gradients will not be unscaled at this point.
+            Use the ``on_before_optimizer_step`` hook if you need the unscaled gradients.
+        """
+
+    def on_before_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
+        """
+        Called before ``optimizer.step()``.
+
+        The hook is only called if gradients do not need to be accumulated.
+        See: :paramref:`~pytorch_lightning.trainer.Trainer.accumulate_grad_batches`.
+        If using native AMP, the gradients will be unscaled before calling this hook.
+        See these `docs <https://pytorch.org/docs/stable/notes/amp_examples.html#working-with-unscaled-gradients>`__
+        for more information on the scaling of gradients.
+
+        Args:
+            optimizer: Current optimizer being used.
+            optimizer_idx: Index of the current optimizer being used.
 
         Example::
 
-            def on_after_backward(self):
+            def on_before_optimizer_step(self, optimizer, optimizer_idx):
                 # example to inspect gradient information in tensorboard
                 if self.trainer.global_step % 25 == 0:  # don't make the tf file huge
                     for k, v in self.named_parameters():
                         self.logger.experiment.add_histogram(
                             tag=k, values=v.grad, global_step=self.trainer.global_step
                         )
-
         """
 
     def on_post_move_to_device(self) -> None:
diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py
index 022dd3ee39..1eae67b992 100644
--- a/pytorch_lightning/plugins/precision/apex_amp.py
+++ b/pytorch_lightning/plugins/precision/apex_amp.py
@@ -90,17 +90,16 @@ class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
 
     def pre_optimizer_step(
         self,
-        pl_module: 'pl.LightningModule',
+        model: 'pl.LightningModule',
         optimizer: Optimizer,
         optimizer_idx: int,
         lambda_closure: Callable,
         **kwargs: Any,
     ) -> bool:
-        """
-        always called before the optimizer step.
-        """
-        # apex amp does not support closures.
-        lambda_closure()
+        """Hook to do something before each optimizer step."""
+        super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
+        # the following should be in an `optimizer_step` hook but we don't have one in the precision plugin.
+        lambda_closure()  # APEX amp does not support closures
         optimizer.step(**kwargs)
         return False
 
diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed_precision.py
index 4809b4e8c2..53e813e75a 100644
--- a/pytorch_lightning/plugins/precision/deepspeed_precision.py
+++ b/pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -35,15 +35,17 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin):
 
     def pre_optimizer_step(
         self,
-        pl_module: 'pl.LightningModule',
+        model: 'pl.LightningModule',
         optimizer: Optimizer,
         optimizer_idx: int,
         lambda_closure: Callable,
         **kwargs: Any,
     ) -> bool:
-        # DeepSpeed not support closures.
-        lambda_closure()
-        deepspeed_engine = pl_module.trainer.model
+        """Hook to do something before each optimizer step."""
+        super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
+        # the following should be in an `optimizer_step` hook but we don't have one in the precision plugin.
+        lambda_closure()  # DeepSpeed does not support closures
+        deepspeed_engine = model.trainer.model
         deepspeed_engine.step()
         return False
 
diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py
index c40025dd1d..7cf4f089f1 100644
--- a/pytorch_lightning/plugins/precision/native_amp.py
+++ b/pytorch_lightning/plugins/precision/native_amp.py
@@ -47,7 +47,7 @@ class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
 
     def pre_optimizer_step(
         self,
-        pl_module: 'pl.LightningModule',
+        model: 'pl.LightningModule',
         optimizer: Optimizer,
         optimizer_idx: int,
         lambda_closure: Callable,
@@ -58,16 +58,15 @@ class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
                 f"native PyTorch amp and lbfgs are not compatible (optimizer {optimizer_idx})."
                 " To request, please file a Github issue in PyTorch and tag @mcarilli"
             )
-        # TODO: Add `on_before_optimizer_step`
-        # self.scaler.unscale_(optimizer)
-        # pl_module.trainer.call_hook("on_before_optimizer_step")
-        if pl_module.automatic_optimization:
+        result = True
+        if model.automatic_optimization:
             result = lambda_closure()
-            if result is None:
-                # lambda_closure returning None indicates that backward has been skipped
-                return False
-        self.scaler.step(optimizer)
-        self.scaler.update()
+        self.scaler.unscale_(optimizer)
+        super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
+        # lambda_closure returning None indicates that backward has been skipped
+        if result is not None:
+            self.scaler.step(optimizer)
+            self.scaler.update()
         return False
 
     @contextmanager
diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py
index fea02f87ba..ae806ff25e 100644
--- a/pytorch_lightning/plugins/precision/precision_plugin.py
+++ b/pytorch_lightning/plugins/precision/precision_plugin.py
@@ -104,13 +104,14 @@ class PrecisionPlugin(Plugin, CheckpointHooks):
 
     def pre_optimizer_step(
         self,
-        pl_module: 'pl.LightningModule',
+        model: 'pl.LightningModule',
         optimizer: Optimizer,
         optimizer_idx: int,
         lambda_closure: Callable,
         **kwargs: Any,
     ) -> bool:
         """Hook to do something before each optimizer step."""
+        model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
         return True
 
     def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 0985075694..63c23d50fa 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -327,6 +327,13 @@ class TrainerCallbackHookMixin(ABC):
         for callback in self.callbacks:
             callback.on_after_backward(self, self.lightning_module)
 
+    def on_before_optimizer_step(self, optimizer, optimizer_idx):
+        """
+        Called after on_after_backward() once gradients have been accumulated and before optimizer.step().
+        """
+        for callback in self.callbacks:
+            callback.on_before_optimizer_step(self, self.lightning_module, optimizer, optimizer_idx)
+
     def on_before_zero_grad(self, optimizer):
         """
         Called after optimizer.step() and before optimizer.zero_grad().
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py
index d66b069817..3604574fd1 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py
@@ -23,6 +23,7 @@ class FxValidator:
         on_configure_sharded_model=None,
         on_before_backward=dict(on_step=(False, True), on_epoch=(False, True)),
         on_after_backward=dict(on_step=(False, True), on_epoch=(False, True)),
+        on_before_optimizer_step=dict(on_step=(False, True), on_epoch=(False, True)),
         on_before_zero_grad=dict(on_step=(False, True), on_epoch=(False, True)),
         on_init_start=None,
         on_init_end=None,
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 6987977e46..d89fc090c4 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -299,6 +299,10 @@ class HookedModel(BoringModel):
         using_native_amp = kwargs.get('amp_backend') == 'native'
         using_deepspeed = kwargs.get('plugins') == 'deepspeed'
         out = []
+        on_before_optimizer_step = [
+            dict(name='Callback.on_before_optimizer_step', args=(trainer, model, ANY, 0)),
+            dict(name='on_before_optimizer_step', args=(ANY, 0)),
+        ]
         for i in range(batches):
             out.extend([
                 dict(name='on_before_batch_transfer', args=(ANY, 0)),
@@ -308,7 +312,10 @@ class HookedModel(BoringModel):
                 dict(name='Callback.on_batch_start', args=(trainer, model)),
                 dict(name='Callback.on_train_batch_start', args=(trainer, model, ANY, i, 0)),
                 dict(name='on_train_batch_start', args=(ANY, i, 0)),
-                # TODO: `on_before_optimizer_step`
+                # these hooks appear before the training step because they are not part of the
+                # `training_step_and_backward` closure; with native amp, however, the closure is
+                # run first and then the optimizer step.
+                *(on_before_optimizer_step if not using_native_amp else []),
                 dict(name='forward', args=(ANY, )),
                 dict(name='training_step', args=(ANY, i)),
                 dict(name='training_step_end', args=(dict(loss=ANY), )),
@@ -321,6 +328,7 @@ class HookedModel(BoringModel):
                 *([dict(name='backward', args=(ANY, ANY, 0))] if not using_deepspeed else []),
                 dict(name='Callback.on_after_backward', args=(trainer, model)),
                 dict(name='on_after_backward'),
+                *(on_before_optimizer_step if using_native_amp else []),
                 dict(
                     name='optimizer_step',
                     args=(current_epoch, i, ANY, 0, ANY),
@@ -354,7 +362,8 @@ class HookedModel(BoringModel):
                 dict(name='on_after_backward'),
                 # `manual_backward` calls the previous 3
                 dict(name='manual_backward', args=(ANY, )),
-                # TODO: `on_before_optimizer_step`
+                dict(name='Callback.on_before_optimizer_step', args=(trainer, model, ANY, 0)),
+                dict(name='on_before_optimizer_step', args=(ANY, 0)),
                 dict(name='training_step', args=(ANY, i)),
                 dict(name='training_step_end', args=(dict(loss=ANY), )),
                 dict(name='Callback.on_train_batch_end', args=(trainer, model, dict(loss=ANY), ANY, i, 0)),
diff --git a/tests/plugins/test_amp_plugins.py b/tests/plugins/test_amp_plugins.py
index 9d9c029a65..217192d5ca 100644
--- a/tests/plugins/test_amp_plugins.py
+++ b/tests/plugins/test_amp_plugins.py
@@ -71,12 +71,7 @@ def test_amp_apex_ddp(
 
 
 class GradientUnscaleBoringModel(BoringModel):
-    def on_after_backward(self):
-        # TODO: replace with `on_before_optimizer_step` so we don't need to check accumulate and unscale manually
-        if self.trainer.fit_loop.should_accumulate():
-            return
-        opt = self.optimizers()
-        self.trainer.precision_plugin.scaler.unscale_(opt)
+    def on_before_optimizer_step(self, *_):
         norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2)
         if not (torch.isinf(norm) or torch.isnan(norm)):
             assert norm.item() < 15.
diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py
index 431ca12eca..27598b40fb 100644
--- a/tests/trainer/logging_/test_logger_connector.py
+++ b/tests/trainer/logging_/test_logger_connector.py
@@ -34,6 +34,7 @@ def test_fx_validator(tmpdir):
     callbacks_func = [
         'on_before_backward',
         'on_after_backward',
+        'on_before_optimizer_step',
         'on_batch_end',
        'on_batch_start',
        'on_before_accelerator_backend_setup',
@@ -124,6 +125,7 @@ def test_fx_validator(tmpdir):
         # creating allowed condition
         allowed = (
             is_stage or "batch" in func_name or "epoch" in func_name or "grad" in func_name or "backward" in func_name
+            or "optimizer_step" in func_name
         )
         allowed = (
             allowed and "pretrain" not in func_name and "predict" not in func_name
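For reference, a minimal usage sketch of the new hook from the callback side, not part of the diff. The callback name, the ``clip_grad_norm_``-based norm computation, and the ``"grad_norm"`` logging key are illustrative choices, not APIs introduced by this change::

    import torch
    from pytorch_lightning import Callback, Trainer

    class GradNormMonitor(Callback):  # hypothetical callback, not shipped with Lightning
        def on_before_optimizer_step(self, trainer, pl_module, optimizer, opt_idx):
            # gradients have been accumulated and, with native AMP, already unscaled here
            # clip_grad_norm_ with a very large max_norm only measures the total norm
            norm = torch.nn.utils.clip_grad_norm_(pl_module.parameters(), max_norm=1e9)
            pl_module.log("grad_norm", norm)

    trainer = Trainer(callbacks=[GradNormMonitor()])

The same inspection can be done from the ``LightningModule`` side via ``ModelHooks.on_before_optimizer_step``, as shown in the docstring example added in ``pytorch_lightning/core/hooks.py`` above.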