mark `OptimizerLoop.backward` method protected (#9514)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
Adrian Wälchli 2021-09-15 14:58:01 +02:00 committed by GitHub
parent 23450e2905
commit 200ed9eb9f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 7 additions and 6 deletions

View File

@ -77,6 +77,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Removed `TrainingBatchLoop.backward()`; manual optimization now calls directly into `Accelerator.backward()` and automatic optimization handles backward in new `OptimizerLoop` ([#9265](https://github.com/PyTorchLightning/pytorch-lightning/pull/9265)) * Removed `TrainingBatchLoop.backward()`; manual optimization now calls directly into `Accelerator.backward()` and automatic optimization handles backward in new `OptimizerLoop` ([#9265](https://github.com/PyTorchLightning/pytorch-lightning/pull/9265))
* Extracted `ManualOptimization` logic from `TrainingBatchLoop` into its own separate loop class ([#9266](https://github.com/PyTorchLightning/pytorch-lightning/pull/9266)) * Extracted `ManualOptimization` logic from `TrainingBatchLoop` into its own separate loop class ([#9266](https://github.com/PyTorchLightning/pytorch-lightning/pull/9266))
* Added `OutputResult` and `ManualResult` classes ([#9437](https://github.com/PyTorchLightning/pytorch-lightning/pull/9437), [#9424](https://github.com/PyTorchLightning/pytorch-lightning/pull/9424)) * Added `OutputResult` and `ManualResult` classes ([#9437](https://github.com/PyTorchLightning/pytorch-lightning/pull/9437), [#9424](https://github.com/PyTorchLightning/pytorch-lightning/pull/9424))
* Marked `OptimizerLoop.backward` as protected ([#9514](https://github.com/PyTorchLightning/pytorch-lightning/pull/9514))
- Added support for saving and loading state of multiple callbacks of the same type ([#7187](https://github.com/PyTorchLightning/pytorch-lightning/pull/7187)) - Added support for saving and loading state of multiple callbacks of the same type ([#7187](https://github.com/PyTorchLightning/pytorch-lightning/pull/7187))

View File

@ -221,7 +221,7 @@ class OptimizerLoop(Loop):
outputs, self.outputs = self.outputs, [] # free memory outputs, self.outputs = self.outputs, [] # free memory
return outputs return outputs
def backward( def _backward(
self, loss: Tensor, optimizer: torch.optim.Optimizer, opt_idx: int, *args: Any, **kwargs: Any self, loss: Tensor, optimizer: torch.optim.Optimizer, opt_idx: int, *args: Any, **kwargs: Any
) -> Tensor: ) -> Tensor:
"""Performs the backward step. """Performs the backward step.
@ -337,7 +337,7 @@ class OptimizerLoop(Loop):
return None return None
def backward_fn(loss: Tensor) -> Tensor: def backward_fn(loss: Tensor) -> Tensor:
self.backward(loss, optimizer, opt_idx) self._backward(loss, optimizer, opt_idx)
# check if model weights are nan # check if model weights are nan
if self.trainer.terminate_on_nan: if self.trainer.terminate_on_nan:

View File

@ -961,7 +961,7 @@ def test_gradient_clipping_by_norm(tmpdir, precision):
gradient_clip_val=1.0, gradient_clip_val=1.0,
) )
old_backward = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.backward old_backward = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._backward
def backward(*args, **kwargs): def backward(*args, **kwargs):
# test that gradient is clipped correctly # test that gradient is clipped correctly
@ -971,7 +971,7 @@ def test_gradient_clipping_by_norm(tmpdir, precision):
assert (grad_norm - 1.0).abs() < 0.01, f"Gradient norm != 1.0: {grad_norm}" assert (grad_norm - 1.0).abs() < 0.01, f"Gradient norm != 1.0: {grad_norm}"
return ret_val return ret_val
trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.backward = backward trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._backward = backward
trainer.fit(model) trainer.fit(model)
@ -996,7 +996,7 @@ def test_gradient_clipping_by_value(tmpdir, precision):
default_root_dir=tmpdir, default_root_dir=tmpdir,
) )
old_backward = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.backward old_backward = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._backward
def backward(*args, **kwargs): def backward(*args, **kwargs):
# test that gradient is clipped correctly # test that gradient is clipped correctly
@ -1009,7 +1009,7 @@ def test_gradient_clipping_by_value(tmpdir, precision):
), f"Gradient max value {grad_max} != grad_clip_val {grad_clip_val} ." ), f"Gradient max value {grad_max} != grad_clip_val {grad_clip_val} ."
return ret_val return ret_val
trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.backward = backward trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._backward = backward
trainer.fit(model) trainer.fit(model)