Update docstrings for backward methods (#13886)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: thomas chaton <thomas@grid.ai>
This commit is contained in:
parent
8af85eeaaf
commit
dcb4dd55d9
|
@ -69,6 +69,7 @@ class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
|
|||
model: the model to be optimized
|
||||
closure_loss: the loss value obtained from the closure
|
||||
optimizer: current optimizer being used. ``None`` if using manual optimization
|
||||
optimizer_idx: the index of the current optimizer. ``None`` if using manual optimization
|
||||
"""
|
||||
opt = optimizer or model.trainer.optimizers
|
||||
with amp.scale_loss(closure_loss, opt) as closure_loss:
|
||||
|
|
|
@ -81,6 +81,16 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin):
|
|||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
r"""Performs back-propagation using DeepSpeed's engine.
|
||||
|
||||
Args:
|
||||
model: the model to be optimized
|
||||
closure_loss: the loss tensor
|
||||
optimizer: ignored for DeepSpeed
|
||||
optimizer_idx: ignored for DeepSpeed
|
||||
\*args: additional positional arguments for the :meth:`deepspeed.DeepSpeedEngine.backward` call
|
||||
\**kwargs: additional keyword arguments for the :meth:`deepspeed.DeepSpeedEngine.backward` call
|
||||
"""
|
||||
if is_overridden("backward", model):
|
||||
warning_cache.warn(
|
||||
"You have overridden the `LightningModule.backward` hook but it will be ignored since DeepSpeed handles"
|
||||
|
|
|
@ -68,12 +68,16 @@ class PrecisionPlugin(CheckpointHooks):
|
|||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Performs the actual backpropagation.
|
||||
r"""Performs the actual backpropagation.
|
||||
|
||||
Args:
|
||||
model: the model to be optimized
|
||||
closure_loss: the loss value obtained from the closure
|
||||
optimizer: current optimizer being used. ``None`` if using manual optimization
|
||||
optimizer_idx: the index of the current optimizer. ``None`` if using manual optimization
|
||||
\*args: Positional arguments intended for the actual function that performs the backward, like
|
||||
:meth:`~torch.Tensor.backward`.
|
||||
\**kwargs: Keyword arguments for the same purpose as ``*args``.
|
||||
"""
|
||||
# do backward pass
|
||||
if model is not None and isinstance(model, pl.LightningModule):
|
||||
|
|
|
@ -179,10 +179,15 @@ class Strategy(ABC):
|
|||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Tensor:
|
||||
"""Forwards backward-calls to the precision plugin.
|
||||
r"""Forwards backward-calls to the precision plugin.
|
||||
|
||||
Args:
|
||||
closure_loss: a tensor holding the loss value to backpropagate
|
||||
optimizer: An optional optimizer that gets passed down to the precision plugin's backward
|
||||
optimizer_idx: An optional optimizer index that gets passed down to the precision plugin's backward
|
||||
\*args: Positional arguments that get passed down to the precision plugin's backward, intended as arguments
|
||||
for the actual function that performs the backward, like :meth:`~torch.Tensor.backward`.
|
||||
\**kwargs: Keyword arguments for the same purpose as ``*args``.
|
||||
"""
|
||||
self.pre_backward(closure_loss)
|
||||
assert self.lightning_module is not None
|
||||
|
|
Loading…
Reference in New Issue