Document gradient clipping in Fabric (#16943)

Adrian Wälchli 2023-03-05 18:03:57 +01:00 committed by GitHub
parent ac4180fc2f
commit f2caa01bb3
2 changed files with 48 additions and 0 deletions


@@ -61,6 +61,41 @@ This replaces any occurrences of ``loss.backward()`` and makes your code acceler
fabric.backward(loss)

clip_gradients
==============

Clip the gradients of the model to a given max value or max norm.
This is useful if your model experiences *exploding gradients* during training.

.. code-block:: python

    # Clip gradients to a max value of +/- 0.5
    fabric.clip_gradients(model, clip_val=0.5)

    # Clip gradients such that their total norm is no bigger than 2.0
    fabric.clip_gradients(model, max_norm=2.0)

    # By default, clipping by norm uses the 2-norm
    fabric.clip_gradients(model, max_norm=2.0, norm_type=2)

    # You can also choose the infinity-norm, which clips the largest
    # absolute gradient element among all parameters
    fabric.clip_gradients(model, max_norm=2.0, norm_type="inf")

You can also restrict clipping to a single layer, or to only the parameters a particular optimizer references (when training with multiple optimizers):

.. code-block:: python

    # Clip gradients on a specific layer of your model
    fabric.clip_gradients(model.fc3, clip_val=1.0)

    # Clip gradients for a specific optimizer if using multiple optimizers
    fabric.clip_gradients(model, optimizer1, clip_val=1.0)

The :meth:`~lightning.fabric.fabric.Fabric.clip_gradients` method is agnostic to the precision and strategy being used.

.. note:: Gradient clipping with FSDP is not yet fully supported.
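
For orientation, here is a minimal sketch of where gradient clipping typically sits in a Fabric training loop. The ``model``, ``optimizer``, and ``dataloader`` names are placeholders (set up via ``fabric.setup`` and ``fabric.setup_dataloaders``), and the forward pass is assumed to return the loss directly:

.. code-block:: python

    for batch in dataloader:
        optimizer.zero_grad()
        loss = model(batch)
        # Run the backward pass through Fabric so the active strategy and precision plugin are respected
        fabric.backward(loss)
        # Clip after the backward pass and before the optimizer step
        fabric.clip_gradients(model, optimizer, max_norm=2.0)
        optimizer.step()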
to_device
=========


@@ -372,6 +372,19 @@ class Fabric:
        norm_type: Union[float, int] = 2.0,
        error_if_nonfinite: bool = True,
    ) -> Optional[torch.Tensor]:
"""Clip the gradients of the model to a given max value or max norm.
Args:
module: The module whose parameters should be clipped. This can also be just one submodule of your model.
optimizer: Optional optimizer. If passed, clipping will be applied to only the parameters that the
optimizer is referencing.
clip_val: If passed, gradients will be clipped to this value.
max_norm: If passed, clips the gradients in such a way that the p-norm of the resulting parameters is
no larger than the given value.
norm_type: The type of norm if `max_norm` was passed. Can be ``'inf'`` for infinity norm.
Default is the 2-norm.
error_if_nonfinite: An error is raised if the total norm of the gradients is NaN or infinite.
"""
        if clip_val is not None and max_norm is not None:
            raise ValueError(
                "Only one of `clip_val` or `max_norm` can be set as this specifies the underlying clipping algorithm!"