Adds Gradient Clipping to Fabric (#16715)

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Justus Schock 2023-02-28 00:44:13 +01:00 committed by GitHub
parent 8884c8970c
commit 3d1927e6bc
16 changed files with 356 additions and 16 deletions
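For context, a minimal usage sketch of the API this commit adds (illustrative only, not part of the diff): `Fabric.clip_gradients` takes either `clip_val` to clip by value or `max_norm` to clip by norm, and is meant to be called after `fabric.backward()` and before `optimizer.step()`. Model, data, and hyperparameters below are placeholders.

import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu", devices=1)  # any accelerator/strategy works the same way
model = torch.nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = fabric.setup(model, optimizer)

output = model(torch.rand(1, 1))
loss = torch.nn.functional.mse_loss(output, torch.ones_like(output))
fabric.backward(loss)

# Clip by norm (returns the total norm) ...
fabric.clip_gradients(model, optimizer, max_norm=1.0)
# ... or clip by value; passing both `clip_val` and `max_norm` raises a ValueError.
# fabric.clip_gradients(model, optimizer, clip_val=0.5)

optimizer.step()
optimizer.zero_grad()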

View File

@@ -363,6 +363,33 @@ class Fabric:
self._precision.backward(tensor, module, *args, **kwargs)
def clip_gradients(
self,
module: Union[torch.nn.Module, _FabricModule],
optimizer: Union[Optimizer, _FabricOptimizer],
clip_val: Optional[Union[float, int]] = None,
max_norm: Optional[Union[float, int]] = None,
norm_type: Union[float, int] = 2.0,
error_if_nonfinite: bool = True,
) -> Optional[torch.Tensor]:
if clip_val is not None and max_norm is not None:
raise ValueError(
"Only one of `clip_val` or `max_norm` can be set as this specifies the underlying clipping algorithm!"
)
if clip_val is not None:
self.strategy.clip_gradients_value(_unwrap_objects(module), _unwrap_objects(optimizer), clip_val=clip_val)
return None
elif max_norm is not None:
return self.strategy.clip_gradients_norm(
_unwrap_objects(module),
_unwrap_objects(optimizer),
max_norm=max_norm,
norm_type=norm_type,
error_if_nonfinite=error_if_nonfinite,
)
raise ValueError("You have to specify either `clip_val` or `max_norm` to do gradient clipping!")
@contextmanager
def autocast(self) -> Generator[None, None, None]:
"""A context manager to automatically convert operations for the chosen precision.

View File

@@ -17,7 +17,7 @@ from typing import Any, cast, Dict, Generator, Literal, Optional
import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import LBFGS
from torch.optim import LBFGS, Optimizer
from lightning.fabric.accelerators.cuda import _patch_cuda_is_available
from lightning.fabric.plugins.precision.precision import Precision
@@ -93,3 +93,20 @@ class MixedPrecision(Precision):
# the dtype could be automatically inferred but we need to manually set it due to a bug upstream
# https://github.com/pytorch/pytorch/issues/67233
return torch.autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16-mixed" else torch.half)
def unscale_gradients(self, optimizer: Optimizer) -> None:
scaler = self.scaler
if scaler is not None:
if _optimizer_handles_unscaling(optimizer):
raise NotImplementedError("Gradient clipping is not implemented for optimizers handling the unscaling.")
scaler.unscale_(optimizer)
def _optimizer_handles_unscaling(optimizer: Any) -> bool:
"""Determines whether a PyTorch optimizer handles unscaling gradients in the step method rather than through the
:class:`torch.cuda.amp.GradScaler`.
Since the current implementation of this function checks a PyTorch internal variable on the optimizer, the return
value will only be reliable for built-in PyTorch optimizers.
"""
return getattr(optimizer, "_step_supports_amp_scaling", False)
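For intuition on the flag this helper reads: some optimizers perform the unscaling themselves inside `step()` and advertise it via the internal `_step_supports_amp_scaling` attribute (to my understanding this includes PyTorch's fused Adam on recent releases). A rough sketch, assuming PyTorch >= 1.13 and a CUDA device:

import torch

if torch.cuda.is_available():
    # fused=True generally requires floating-point CUDA parameters
    params = torch.nn.Linear(1, 1).cuda().parameters()
    fused_adam = torch.optim.Adam(params, lr=1e-3, fused=True)
    # Expected to print True, in which case `unscale_gradients` above raises
    # NotImplementedError when a GradScaler is in use.
    print(getattr(fused_adam, "_step_supports_amp_scaling", False))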

View File

@@ -97,6 +97,9 @@ class Precision:
for group in optimizer.param_groups:
yield from group["params"]
def unscale_gradients(self, optimizer: Optimizer) -> None:
return
def state_dict(self) -> Dict[str, Any]:
"""Called when saving a checkpoint, implement to generate precision plugin state_dict.

View File

@@ -486,6 +486,27 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
state[k] = client_state.pop(k)
return client_state
def clip_gradients_norm(
self,
module: "deepspeed.DeepSpeedEngine",
optimizer: Optimizer,
max_norm: Union[float, int],
norm_type: Union[float, int] = 2.0,
error_if_nonfinite: bool = True,
) -> torch.Tensor:
raise NotImplementedError(
"DeepSpeed handles gradient clipping automatically within the optimizer. "
"Make sure to set the `gradient_clipping` value in your Config."
)
def clip_gradients_value(
self, module: "deepspeed.DeepSpeedEngine", optimizer: Optimizer, clip_val: Union[float, int]
) -> None:
raise NotImplementedError(
"DeepSpeed handles gradient clipping automatically within the optimizer. "
"Make sure to set the `gradient_clipping` value in your Config."
)
@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
strategy_registry.register("deepspeed", cls, description="Default DeepSpeed Strategy")

View File

@@ -37,7 +37,7 @@ from lightning.fabric.utilities.distributed import (
from lightning.fabric.utilities.distributed import group as _group
from lightning.fabric.utilities.distributed import ReduceOp
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_1_13
from lightning.fabric.utilities.rank_zero import rank_zero_only
from lightning.fabric.utilities.rank_zero import rank_zero_only, rank_zero_warn
from lightning.fabric.utilities.seed import reset_seed
if TYPE_CHECKING:
@@ -268,6 +268,29 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
return obj[0]
def clip_gradients_norm( # type: ignore[override]
self,
module: "FullyShardedDataParallel",
optimizer: Optimizer,
max_norm: Union[float, int],
norm_type: Union[float, int] = 2.0,
error_if_nonfinite: bool = True,
) -> Tensor:
"""Clip gradients by norm."""
rank_zero_warn("Gradient Clipping by Norm is currently experimental for FSDP. Proceed with Caution!")
self.precision.unscale_gradients(optimizer)
return module.clip_grad_norm_(max_norm=max_norm, norm_type=norm_type) # type: ignore[return-value]
def clip_gradients_value( # type: ignore[override]
self, module: "FullyShardedDataParallel", optimizer: Optimizer, clip_val: Union[float, int]
) -> None:
"""Clip gradients by value."""
raise NotImplementedError(
"FSDP currently does not support to clip gradients by value. "
"Consider clipping by norm instead or choose another strategy!"
)
@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
if not _TORCH_GREATER_EQUAL_1_12 or not torch.distributed.is_available():

View File

@@ -305,6 +305,27 @@ class Strategy(ABC):
self.accelerator.teardown()
self.checkpoint_io.teardown()
def clip_gradients_norm(
self,
module: torch.nn.Module,
optimizer: Optimizer,
max_norm: Union[float, int],
norm_type: Union[float, int] = 2.0,
error_if_nonfinite: bool = True,
) -> torch.Tensor:
"""Clip gradients by norm."""
self.precision.unscale_gradients(optimizer)
parameters = self.precision.main_params(optimizer)
return torch.nn.utils.clip_grad_norm_(
parameters, max_norm=max_norm, norm_type=norm_type, error_if_nonfinite=error_if_nonfinite
)
def clip_gradients_value(self, module: torch.nn.Module, optimizer: Optimizer, clip_val: Union[float, int]) -> None:
"""Clip gradients by value."""
self.precision.unscale_gradients(optimizer)
parameters = self.precision.main_params(optimizer)
return torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val)
@classmethod
def register_strategies(cls, strategy_registry: Dict[str, Any]) -> None:
pass

View File

@@ -18,6 +18,7 @@ from torch.optim import LBFGS, Optimizer
import lightning.pytorch as pl
from lightning.fabric.accelerators.cuda import _patch_cuda_is_available
from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling
from lightning.fabric.utilities.types import Optimizable
from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
from lightning.pytorch.utilities import GradClipAlgorithmType
@@ -116,13 +117,3 @@ class MixedPrecisionPlugin(PrecisionPlugin):
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
if self.scaler is not None:
self.scaler.load_state_dict(state_dict)
def _optimizer_handles_unscaling(optimizer: Any) -> bool:
"""Determines whether a PyTorch optimizer handles unscaling gradients in the step method rather than through the
:class:`torch.cuda.amp.GradScaler`.
Since the current implementation of this function checks a PyTorch internal variable on the optimizer, the return
value will only be reliable for built-in PyTorch optimizers.
"""
return getattr(optimizer, "_step_supports_amp_scaling", False)

View File

@@ -48,7 +48,7 @@ class BoringFabric(Fabric):
loss = torch.nn.functional.mse_loss(output, torch.ones_like(output))
return loss
def after_backward(self, model: Module) -> None:
def after_backward(self, model: Module, optimizer: Optimizer) -> None:
pass
def after_optimizer_step(self, model: Module, optimizer: Optimizer) -> None:
@@ -77,7 +77,7 @@ class BoringFabric(Fabric):
batch = next(data_iter)
loss = self.step(model, batch)
self.backward(loss)
self.after_backward(model)
self.after_backward(model, optimizer)
optimizer.step()
self.after_optimizer_step(model, optimizer)
optimizer.zero_grad()

View File

@@ -54,7 +54,7 @@ class MixedPrecisionBoringFabric(BoringFabric):
loss = torch.nn.functional.mse_loss(output, torch.ones_like(output))
return loss
def after_backward(self, model):
def after_backward(self, model, optimizer):
assert model.layer.weight.grad.dtype == torch.float32

View File

@@ -45,7 +45,7 @@ class DoublePrecisionBoringFabric(BoringFabric):
loss = torch.nn.functional.mse_loss(output, torch.ones_like(output))
return loss
def after_backward(self, model):
def after_backward(self, model, optimizer):
assert model.layer.weight.grad.dtype == torch.float64

View File

@@ -16,6 +16,8 @@ from unittest.mock import MagicMock, Mock
import pytest
import torch
from tests_fabric.helpers.runif import RunIf
from tests_fabric.strategies.test_single_device import _MyFabricGradNorm, _MyFabricGradVal
from torch.nn.parallel import DistributedDataParallel
from lightning.fabric.strategies import DDPStrategy
@@ -99,3 +101,28 @@ def test_ddp_module_state_dict():
with mock.patch("lightning.fabric.strategies.ddp.DistributedDataParallel", DistributedDataParallelMock):
wrapped_module = strategy.setup_module(original_module)
assert strategy.get_module_state_dict(wrapped_module).keys() == original_module.state_dict().keys()
@pytest.mark.parametrize(
"clip_type,accelerator,precision",
[
("norm", "cpu", "32-true"),
("val", "cpu", "32-true"),
("norm", "cpu", "bf16-mixed"),
("val", "cpu", "bf16-mixed"),
pytest.param("norm", "cuda", "32-true", marks=RunIf(min_cuda_gpus=2)),
pytest.param("val", "cuda", "32-true", marks=RunIf(min_cuda_gpus=2)),
pytest.param("norm", "cuda", "16-mixed", marks=RunIf(min_cuda_gpus=2)),
pytest.param("val", "cuda", "16-mixed", marks=RunIf(min_cuda_gpus=2)),
pytest.param("norm", "cuda", "bf16-mixed", marks=RunIf(min_cuda_gpus=2, bf16_cuda=True)),
pytest.param("val", "cuda", "bf16-mixed", marks=RunIf(min_cuda_gpus=2, bf16_cuda=True)),
],
)
@RunIf(standalone=True)
def test_ddp_grad_clipping(clip_type, accelerator, precision):
if clip_type == "norm":
clipping_test_cls = _MyFabricGradNorm
else:
clipping_test_cls = _MyFabricGradVal
fabric = clipping_test_cls(accelerator=accelerator, devices=2, precision=precision, strategy="ddp")
fabric.run()

View File

@ -321,3 +321,25 @@ def test_deepspeed_load_checkpoint_optimzer_state_requested(optimzer_state_reque
load_lr_scheduler_states=False,
load_module_strict=True,
)
@RunIf(deepspeed=True)
def test_errors_grad_clipping():
strategy = DeepSpeedStrategy()
with pytest.raises(
NotImplementedError,
match=(
"DeepSpeed handles gradient clipping automatically within the optimizer. "
"Make sure to set the `gradient_clipping` value in your Config."
),
):
strategy.clip_gradients_norm(Mock(), Mock(), Mock(), Mock(), Mock())
with pytest.raises(
NotImplementedError,
match=(
"DeepSpeed handles gradient clipping automatically within the optimizer. "
"Make sure to set the `gradient_clipping` value in your Config."
),
):
strategy.clip_gradients_value(Mock(), Mock(), Mock())

View File

@@ -14,7 +14,10 @@
from unittest import mock
from unittest.mock import MagicMock, Mock
import pytest
import torch
from tests_fabric.helpers.runif import RunIf
from tests_fabric.strategies.test_single_device import _MyFabricGradNorm, _MyFabricGradVal
from lightning.fabric.strategies import DataParallelStrategy
@@ -68,3 +71,22 @@ def test_dp_module_state_dict():
with mock.patch("lightning.fabric.strategies.dp.DataParallel", DataParallelMock):
wrapped_module = strategy.setup_module(original_module)
assert strategy.get_module_state_dict(wrapped_module).keys() == original_module.state_dict().keys()
@pytest.mark.parametrize(
"precision",
[
"32-true",
"16-mixed",
pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True)),
],
)
@pytest.mark.parametrize("clip_type", ["norm", "val"])
@RunIf(min_cuda_gpus=2)
def test_dp_grad_clipping(clip_type, precision):
if clip_type == "norm":
clipping_test_cls = _MyFabricGradNorm
else:
clipping_test_cls = _MyFabricGradVal
fabric = clipping_test_cls(accelerator="cuda", devices=2, precision=precision, strategy="dp")
fabric.run()

View File

@@ -27,6 +27,7 @@ from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12
if _TORCH_GREATER_EQUAL_1_12:
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
from tests_fabric.strategies.test_single_device import _MyFabricGradNorm
@mock.patch("lightning.fabric.strategies.fsdp._TORCH_GREATER_EQUAL_1_12", False)
@@ -131,3 +132,40 @@ def test_fsdp_activation_checkpointing():
) as ckpt_mock:
strategy.setup_module(Model())
ckpt_mock.assert_called_with(fsdp_mock(), checkpoint_wrapper_fn=ANY, check_fn=ANY)
@RunIf(min_torch="1.13")
def test_fsdp_grad_clipping_value_error():
strategy = FSDPStrategy()
with pytest.raises(
NotImplementedError,
match=(
"FSDP currently does not support to clip gradients by value. "
"Consider clipping by norm instead or choose another strategy!"
),
):
strategy.clip_gradients_value(Mock(), Mock(), Mock())
class _MyFSDPFabricGradientNorm(_MyFabricGradNorm):
def after_backward(self, model, optimizer):
self.clip_gradients(model, optimizer, max_norm=0.05, error_if_nonfinite=True)
with model._forward_module.summon_full_params(model._forward_module):
parameters = model.parameters()
grad_norm = torch.linalg.vector_norm(
torch.stack([torch.linalg.vector_norm(p.grad.detach(), 2, dtype=torch.float32) for p in parameters]),
2,
)
torch.testing.assert_close(grad_norm, torch.tensor(0.05, device=self.device))
@pytest.mark.parametrize(
"precision",
["32-true", "16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))],
)
@RunIf(min_cuda_gpus=2, standalone=True)
@pytest.mark.xfail(reason="Testing with FSDP is not yet correct") # TODO: Investigate testing with fsdp
def test_fsdp_grad_clipping_norm(precision):
fabric = _MyFSDPFabricGradientNorm(accelerator="cuda", devices=2, precision=precision, strategy="fsdp")
fabric.run()

View File

@@ -15,8 +15,11 @@ from unittest.mock import Mock
import pytest
import torch
from tests_fabric.helpers.models import BoringFabric
from tests_fabric.helpers.runif import RunIf
from lightning.fabric.strategies import SingleDeviceStrategy
from lightning.fabric.wrappers import _FabricModule, _FabricOptimizer
def test_single_device_default_device():
@@ -52,3 +55,98 @@ def test_single_device_module_to_device():
module = Mock(spec=torch.nn.Module)
strategy.module_to_device(module)
module.to.assert_called_with(strategy.root_device)
class _MyFabricGradNorm(BoringFabric):
def after_backward(self, model: _FabricModule, optimizer: _FabricOptimizer):
self.clip_gradients(model, optimizer, max_norm=0.05, error_if_nonfinite=True)
parameters = model.parameters()
grad_norm = torch.linalg.vector_norm(
torch.stack([torch.linalg.vector_norm(p.grad.detach(), 2, dtype=torch.float32) for p in parameters]),
2,
)
torch.testing.assert_close(grad_norm, torch.tensor(0.05, device=self.device))
def run(self):
# 10 retries
i = 0
while True:
try:
super().run()
break
except RuntimeError as e:
# nonfinite grads -> skip and continue
# this may repeat until the scaler finds a factor where overflow is avoided,
# so the while loop should eventually break
# stop after a max of 10 tries
if i > 10 or not str(e).startswith("The total norm"):
raise e
# unscale was already called by last attempt,
# but no update afterwards since optimizer step was missing.
# Manually update here -> Need to update inf stats first.
scaler = getattr(self._precision, "scaler", None)
if scaler is not None:
scaler._check_inf_per_device(self.optimizer)
scaler.update()
finally:
i += 1
class _MyFabricGradVal(BoringFabric):
def after_backward(self, model, optimizer):
for p in model.parameters():
if p.grad is not None and (torch.isnan(p.grad).any().item() or torch.isinf(p.grad).any().item()):
raise RuntimeError("Nonfinite grads")
self.clip_gradients(model, optimizer, clip_val=1e-10)
parameters = model.parameters()
grad_max_list = [torch.max(p.grad.detach().abs()) for p in parameters]
grad_max = torch.max(torch.stack(grad_max_list))
torch.testing.assert_close(grad_max.abs(), torch.tensor(1e-10, device=self.device))
print("done")
def run(self):
# 10 retries
i = 0
while True:
try:
super().run()
break
except RuntimeError as e:
# nonfinite grads -> skip and continue
# this may repeat until the scaler finds a factor where overflow is avoided,
# so the while loop should eventually break
# stop after a max of 10 tries
if i > 10 or not str(e).startswith("Nonfinite grads"):
raise e
# unscale was already called by last attempt,
# but no update afterwards since optimizer step was missing.
# Manually update here -> Need to update inf stats first.
scaler = getattr(self._precision, "scaler", None)
if scaler is not None:
scaler._check_inf_per_device(self.optimizer)
scaler.update()
finally:
i += 1
@pytest.mark.parametrize(
"precision",
[
"32-true",
pytest.param("16-mixed", marks=RunIf(min_cuda_gpus=1)),
pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True)),
],
)
@pytest.mark.parametrize("clip_type", ["norm", "val"])
def test_single_device_grad_clipping(clip_type, precision):
if clip_type == "norm":
clipping_test_cls = _MyFabricGradNorm
else:
clipping_test_cls = _MyFabricGradVal
fabric = clipping_test_cls(accelerator="auto", devices=1, precision=precision)
fabric.run()

View File

@@ -893,3 +893,33 @@ def test_all_reduce():
# dict
fabric.all_reduce({"a": torch.tensor(4), "b": [torch.tensor(5)], "c": "string"})
fabric._strategy.all_reduce.assert_has_calls([call(torch.tensor(4), **defaults), call(torch.tensor(5), **defaults)])
@pytest.mark.parametrize("clip_val,max_norm", [(1e-3, None), (None, 1)])
def test_grad_clipping(clip_val, max_norm):
fabric = Fabric()
fabric.strategy.clip_gradients_norm = Mock()
fabric.strategy.clip_gradients_value = Mock()
torch_model = nn.Linear(1, 1)
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1e-3)
model, optimizer = fabric.setup(torch_model, torch_optimizer)
loss = model(torch.rand(1, 1).to(fabric.device))
fabric.backward(loss)
fabric.strategy.clip_gradients_value.assert_not_called()
fabric.strategy.clip_gradients_norm.assert_not_called()
fabric.clip_gradients(model, optimizer, max_norm=max_norm, clip_val=clip_val)
if clip_val is not None:
fabric.strategy.clip_gradients_value.assert_called_once_with(torch_model, torch_optimizer, clip_val=clip_val)
fabric.strategy.clip_gradients_norm.assert_not_called()
else:
fabric.strategy.clip_gradients_value.assert_not_called()
fabric.strategy.clip_gradients_norm.assert_called_once_with(
torch_model, torch_optimizer, max_norm=max_norm, norm_type=2.0, error_if_nonfinite=True
)