Update docstrings for backward methods (#13886)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: thomas chaton <thomas@grid.ai>
Adrian Wälchli 2022-08-03 15:06:07 +02:00 committed by GitHub
parent 8af85eeaaf
commit dcb4dd55d9
4 changed files with 22 additions and 2 deletions


@@ -69,6 +69,7 @@ class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
model: the model to be optimized
closure_loss: the loss value obtained from the closure
optimizer: current optimizer being used. ``None`` if using manual optimization
optimizer_idx: the index of the current optimizer. ``None`` if using manual optimization
"""
opt = optimizer or model.trainer.optimizers
with amp.scale_loss(closure_loss, opt) as closure_loss:
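
A minimal sketch of the Apex AMP backward pattern this hunk documents, assuming the standard `amp.scale_loss` usage; the trailing `scaled_loss.backward()` call is not shown in the hunk and is an assumption here.

from apex import amp

def apex_backward_sketch(model, closure_loss, optimizer):
    # Fall back to the trainer's optimizers in manual optimization,
    # mirroring the line shown in the hunk above.
    opt = optimizer or model.trainer.optimizers
    # amp.scale_loss yields a loss scaled for mixed precision; calling
    # backward() on it is the usual Apex pattern (assumed, not shown above).
    with amp.scale_loss(closure_loss, opt) as scaled_loss:
        scaled_loss.backward()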


@@ -81,6 +81,16 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin):
*args: Any,
**kwargs: Any,
) -> None:
r"""Performs back-propagation using DeepSpeed's engine.
Args:
model: the model to be optimized
closure_loss: the loss tensor
optimizer: ignored for DeepSpeed
optimizer_idx: ignored for DeepSpeed
\*args: additional positional arguments for the :meth:`deepspeed.DeepSpeedEngine.backward` call
\**kwargs: additional keyword arguments for the :meth:`deepspeed.DeepSpeedEngine.backward` call
"""
if is_overridden("backward", model):
warning_cache.warn(
"You have overridden the `LightningModule.backward` hook but it will be ignored since DeepSpeed handles"


@@ -68,12 +68,16 @@ class PrecisionPlugin(CheckpointHooks):
*args: Any,
**kwargs: Any,
) -> None:
"""Performs the actual backpropagation.
r"""Performs the actual backpropagation.
Args:
model: the model to be optimized
closure_loss: the loss value obtained from the closure
optimizer: current optimizer being used. ``None`` if using manual optimization
optimizer_idx: the index of the current optimizer. ``None`` if using manual optimization
\*args: Positional arguments intended for the actual function that performs the backward, like
:meth:`~torch.Tensor.backward`.
\**kwargs: Keyword arguments for the same purpose as ``*args``.
"""
# do backward pass
if model is not None and isinstance(model, pl.LightningModule):
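
A minimal sketch of the dispatch this docstring covers; routing through the `LightningModule.backward` hook matches the branch shown in the hunk, while the plain `torch.Tensor.backward` fallback for non-LightningModule models is an assumption based on the docstring rather than code shown here.

import pytorch_lightning as pl

def precision_backward_sketch(model, closure_loss, optimizer, optimizer_idx, *args, **kwargs):
    if model is not None and isinstance(model, pl.LightningModule):
        # Route through the LightningModule.backward hook so user overrides run.
        model.backward(closure_loss, optimizer, optimizer_idx, *args, **kwargs)
    else:
        # Otherwise fall back to a plain torch.Tensor.backward call.
        closure_loss.backward(*args, **kwargs)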


@@ -179,10 +179,15 @@ class Strategy(ABC):
*args: Any,
**kwargs: Any,
) -> Tensor:
"""Forwards backward-calls to the precision plugin.
r"""Forwards backward-calls to the precision plugin.
Args:
closure_loss: a tensor holding the loss value to backpropagate
optimizer: An optional optimizer that gets passed down to the precision plugin's backward
optimizer_idx: An optional optimizer index that gets passed down to the precision plugin's backward
\*args: Positional arguments that get passed down to the precision plugin's backward, intended as arguments
for the actual function that performs the backward, like :meth:`~torch.Tensor.backward`.
\**kwargs: Keyword arguments for the same purpose as ``*args``.
"""
self.pre_backward(closure_loss)
assert self.lightning_module is not None
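
A minimal sketch of the forwarding flow described above; only the `pre_backward` call and the `lightning_module` assertion appear in the hunk, so the exact call into the precision plugin and the closing `post_backward` step are assumptions for illustration.

from torch import Tensor

def strategy_backward_sketch(strategy, closure_loss: Tensor, optimizer, optimizer_idx, *args, **kwargs) -> Tensor:
    strategy.pre_backward(closure_loss)
    assert strategy.lightning_module is not None
    # Hand the actual backward off to the precision plugin (assumed call shape).
    closure_loss = strategy.precision_plugin.backward(
        strategy.lightning_module, closure_loss, optimizer, optimizer_idx, *args, **kwargs
    )
    strategy.post_backward(closure_loss)  # assumed symmetric closing hook
    return closure_loss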