From 3d1927e6bc542f11a395c12a17f39f0e44b7c653 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Tue, 28 Feb 2023 00:44:13 +0100 Subject: [PATCH] Adds Gradient Clipping to Fabric (#16715) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli Co-authored-by: Carlos Mocholí Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- src/lightning/fabric/fabric.py | 27 +++++ src/lightning/fabric/plugins/precision/amp.py | 19 +++- .../fabric/plugins/precision/precision.py | 3 + src/lightning/fabric/strategies/deepspeed.py | 21 ++++ src/lightning/fabric/strategies/fsdp.py | 25 ++++- src/lightning/fabric/strategies/strategy.py | 21 ++++ .../pytorch/plugins/precision/amp.py | 11 +-- tests/tests_fabric/helpers/models.py | 4 +- .../plugins/precision/test_amp_integration.py | 2 +- .../precision/test_double_integration.py | 2 +- tests/tests_fabric/strategies/test_ddp.py | 27 +++++ .../tests_fabric/strategies/test_deepspeed.py | 22 +++++ tests/tests_fabric/strategies/test_dp.py | 22 +++++ tests/tests_fabric/strategies/test_fsdp.py | 38 +++++++ .../strategies/test_single_device.py | 98 +++++++++++++++++++ tests/tests_fabric/test_fabric.py | 30 ++++++ 16 files changed, 356 insertions(+), 16 deletions(-) diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py index 60ce9e84b1..c6b9423b40 100644 --- a/src/lightning/fabric/fabric.py +++ b/src/lightning/fabric/fabric.py @@ -363,6 +363,33 @@ class Fabric: self._precision.backward(tensor, module, *args, **kwargs) + def clip_gradients( + self, + module: Union[torch.nn.Module, _FabricModule], + optimizer: Union[Optimizer, _FabricOptimizer], + clip_val: Optional[Union[float, int]] = None, + max_norm: Optional[Union[float, int]] = None, + norm_type: Union[float, int] = 2.0, + error_if_nonfinite: bool = True, + ) -> Optional[torch.Tensor]: + if clip_val is not None and max_norm is not None: + raise ValueError( + "Only one of `clip_val` or `max_norm` can be set as this specifies the underlying clipping algorithm!" + ) + + if clip_val is not None: + self.strategy.clip_gradients_value(_unwrap_objects(module), _unwrap_objects(optimizer), clip_val=clip_val) + return None + elif max_norm is not None: + return self.strategy.clip_gradients_norm( + _unwrap_objects(module), + _unwrap_objects(optimizer), + max_norm=max_norm, + norm_type=norm_type, + error_if_nonfinite=error_if_nonfinite, + ) + raise ValueError("You have to specify either `clip_val` or `max_norm` to do gradient clipping!") + @contextmanager def autocast(self) -> Generator[None, None, None]: """A context manager to automatically convert operations for the chosen precision. 
diff --git a/src/lightning/fabric/plugins/precision/amp.py b/src/lightning/fabric/plugins/precision/amp.py index e7ff485822..3e35fb96f8 100644 --- a/src/lightning/fabric/plugins/precision/amp.py +++ b/src/lightning/fabric/plugins/precision/amp.py @@ -17,7 +17,7 @@ from typing import Any, cast, Dict, Generator, Literal, Optional import torch from torch import Tensor from torch.nn import Module -from torch.optim import LBFGS +from torch.optim import LBFGS, Optimizer from lightning.fabric.accelerators.cuda import _patch_cuda_is_available from lightning.fabric.plugins.precision.precision import Precision @@ -93,3 +93,20 @@ class MixedPrecision(Precision): # the dtype could be automatically inferred but we need to manually set it due to a bug upstream # https://github.com/pytorch/pytorch/issues/67233 return torch.autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16-mixed" else torch.half) + + def unscale_gradients(self, optimizer: Optimizer) -> None: + scaler = self.scaler + if scaler is not None: + if _optimizer_handles_unscaling(optimizer): + raise NotImplementedError("Gradient clipping is not implemented for optimizers handling the unscaling.") + scaler.unscale_(optimizer) + + +def _optimizer_handles_unscaling(optimizer: Any) -> bool: + """Determines whether a PyTorch optimizer handles unscaling gradients in the step method rather than through the + :class:`torch.cuda.amp.GradScaler`. + + Since the current implementation of this function checks a PyTorch internal variable on the optimizer, the return + value will only be reliable for built-in PyTorch optimizers. + """ + return getattr(optimizer, "_step_supports_amp_scaling", False) diff --git a/src/lightning/fabric/plugins/precision/precision.py b/src/lightning/fabric/plugins/precision/precision.py index e1add04366..65c77f0565 100644 --- a/src/lightning/fabric/plugins/precision/precision.py +++ b/src/lightning/fabric/plugins/precision/precision.py @@ -97,6 +97,9 @@ class Precision: for group in optimizer.param_groups: yield from group["params"] + def unscale_gradients(self, optimizer: Optimizer) -> None: + return + def state_dict(self) -> Dict[str, Any]: """Called when saving a checkpoint, implement to generate precision plugin state_dict. diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index 290b2d05c6..b355c08fdb 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -486,6 +486,27 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded): state[k] = client_state.pop(k) return client_state + def clip_gradients_norm( + self, + module: "deepspeed.DeepSpeedEngine", + optimizer: Optimizer, + max_norm: Union[float, int], + norm_type: Union[float, int] = 2.0, + error_if_nonfinite: bool = True, + ) -> torch.Tensor: + raise NotImplementedError( + "DeepSpeed handles gradient clipping automatically within the optimizer. " + "Make sure to set the `gradient_clipping` value in your Config." + ) + + def clip_gradients_value( + self, module: "deepspeed.DeepSpeedEngine", optimizer: Optimizer, clip_val: Union[float, int] + ) -> None: + raise NotImplementedError( + "DeepSpeed handles gradient clipping automatically within the optimizer. " + "Make sure to set the `gradient_clipping` value in your Config." 
+ ) + @classmethod def register_strategies(cls, strategy_registry: Dict) -> None: strategy_registry.register("deepspeed", cls, description="Default DeepSpeed Strategy") diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py index b465bc9a94..1e2133073f 100644 --- a/src/lightning/fabric/strategies/fsdp.py +++ b/src/lightning/fabric/strategies/fsdp.py @@ -37,7 +37,7 @@ from lightning.fabric.utilities.distributed import ( from lightning.fabric.utilities.distributed import group as _group from lightning.fabric.utilities.distributed import ReduceOp from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_1_13 -from lightning.fabric.utilities.rank_zero import rank_zero_only +from lightning.fabric.utilities.rank_zero import rank_zero_only, rank_zero_warn from lightning.fabric.utilities.seed import reset_seed if TYPE_CHECKING: @@ -268,6 +268,29 @@ class FSDPStrategy(ParallelStrategy, _Sharded): torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD) return obj[0] + def clip_gradients_norm( # type: ignore[override] + self, + module: "FullyShardedDataParallel", + optimizer: Optimizer, + max_norm: Union[float, int], + norm_type: Union[float, int] = 2.0, + error_if_nonfinite: bool = True, + ) -> Tensor: + """Clip gradients by norm.""" + rank_zero_warn("Gradient clipping by norm is currently experimental for FSDP. Proceed with caution!") + self.precision.unscale_gradients(optimizer) + return module.clip_grad_norm_(max_norm=max_norm, norm_type=norm_type) # type: ignore[return-value] + + def clip_gradients_value( # type: ignore[override] + self, module: "FullyShardedDataParallel", optimizer: Optimizer, clip_val: Union[float, int] + ) -> None: + """Clip gradients by value.""" + + raise NotImplementedError( + "FSDP currently does not support clipping gradients by value. " + "Consider clipping by norm instead or choose another strategy!" 
+ ) + @classmethod def register_strategies(cls, strategy_registry: Dict) -> None: if not _TORCH_GREATER_EQUAL_1_12 or not torch.distributed.is_available(): diff --git a/src/lightning/fabric/strategies/strategy.py b/src/lightning/fabric/strategies/strategy.py index 53f3d58f60..afffaf21d0 100644 --- a/src/lightning/fabric/strategies/strategy.py +++ b/src/lightning/fabric/strategies/strategy.py @@ -305,6 +305,27 @@ class Strategy(ABC): self.accelerator.teardown() self.checkpoint_io.teardown() + def clip_gradients_norm( + self, + module: torch.nn.Module, + optimizer: Optimizer, + max_norm: Union[float, int], + norm_type: Union[float, int] = 2.0, + error_if_nonfinite: bool = True, + ) -> torch.Tensor: + """Clip gradients by norm.""" + self.precision.unscale_gradients(optimizer) + parameters = self.precision.main_params(optimizer) + return torch.nn.utils.clip_grad_norm_( + parameters, max_norm=max_norm, norm_type=norm_type, error_if_nonfinite=error_if_nonfinite + ) + + def clip_gradients_value(self, module: torch.nn.Module, optimizer: Optimizer, clip_val: Union[float, int]) -> None: + """Clip gradients by value.""" + self.precision.unscale_gradients(optimizer) + parameters = self.precision.main_params(optimizer) + return torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val) + @classmethod def register_strategies(cls, strategy_registry: Dict[str, Any]) -> None: pass diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py index 1ac94415e3..4b22981375 100644 --- a/src/lightning/pytorch/plugins/precision/amp.py +++ b/src/lightning/pytorch/plugins/precision/amp.py @@ -18,6 +18,7 @@ from torch.optim import LBFGS, Optimizer import lightning.pytorch as pl from lightning.fabric.accelerators.cuda import _patch_cuda_is_available +from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling from lightning.fabric.utilities.types import Optimizable from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin from lightning.pytorch.utilities import GradClipAlgorithmType @@ -116,13 +117,3 @@ class MixedPrecisionPlugin(PrecisionPlugin): def load_state_dict(self, state_dict: Dict[str, Any]) -> None: if self.scaler is not None: self.scaler.load_state_dict(state_dict) - - -def _optimizer_handles_unscaling(optimizer: Any) -> bool: - """Determines whether a PyTorch optimizer handles unscaling gradients in the step method rather than through the - :class:`torch.cuda.amp.GradScaler`. - - Since, the current implementation of this function checks a PyTorch internal variable on the optimizer, the return - value will only be reliable for built-in PyTorch optimizers. 
- """ - return getattr(optimizer, "_step_supports_amp_scaling", False) diff --git a/tests/tests_fabric/helpers/models.py b/tests/tests_fabric/helpers/models.py index 5122e0c967..bc40401d09 100644 --- a/tests/tests_fabric/helpers/models.py +++ b/tests/tests_fabric/helpers/models.py @@ -48,7 +48,7 @@ class BoringFabric(Fabric): loss = torch.nn.functional.mse_loss(output, torch.ones_like(output)) return loss - def after_backward(self, model: Module) -> None: + def after_backward(self, model: Module, optimizer: Optimizer) -> None: pass def after_optimizer_step(self, model: Module, optimizer: Optimizer) -> None: @@ -77,7 +77,7 @@ class BoringFabric(Fabric): batch = next(data_iter) loss = self.step(model, batch) self.backward(loss) - self.after_backward(model) + self.after_backward(model, optimizer) optimizer.step() self.after_optimizer_step(model, optimizer) optimizer.zero_grad() diff --git a/tests/tests_fabric/plugins/precision/test_amp_integration.py b/tests/tests_fabric/plugins/precision/test_amp_integration.py index b462dbb837..0b86fff7f4 100644 --- a/tests/tests_fabric/plugins/precision/test_amp_integration.py +++ b/tests/tests_fabric/plugins/precision/test_amp_integration.py @@ -54,7 +54,7 @@ class MixedPrecisionBoringFabric(BoringFabric): loss = torch.nn.functional.mse_loss(output, torch.ones_like(output)) return loss - def after_backward(self, model): + def after_backward(self, model, optimizer): assert model.layer.weight.grad.dtype == torch.float32 diff --git a/tests/tests_fabric/plugins/precision/test_double_integration.py b/tests/tests_fabric/plugins/precision/test_double_integration.py index 012d1f3962..0436ae6763 100644 --- a/tests/tests_fabric/plugins/precision/test_double_integration.py +++ b/tests/tests_fabric/plugins/precision/test_double_integration.py @@ -45,7 +45,7 @@ class DoublePrecisionBoringFabric(BoringFabric): loss = torch.nn.functional.mse_loss(output, torch.ones_like(output)) return loss - def after_backward(self, model): + def after_backward(self, model, optimizer): assert model.layer.weight.grad.dtype == torch.float64 diff --git a/tests/tests_fabric/strategies/test_ddp.py b/tests/tests_fabric/strategies/test_ddp.py index df9b302f13..4103e1504f 100644 --- a/tests/tests_fabric/strategies/test_ddp.py +++ b/tests/tests_fabric/strategies/test_ddp.py @@ -16,6 +16,8 @@ from unittest.mock import MagicMock, Mock import pytest import torch +from tests_fabric.helpers.runif import RunIf +from tests_fabric.strategies.test_single_device import _MyFabricGradNorm, _MyFabricGradVal from torch.nn.parallel import DistributedDataParallel from lightning.fabric.strategies import DDPStrategy @@ -99,3 +101,28 @@ def test_ddp_module_state_dict(): with mock.patch("lightning.fabric.strategies.ddp.DistributedDataParallel", DistributedDataParallelMock): wrapped_module = strategy.setup_module(original_module) assert strategy.get_module_state_dict(wrapped_module).keys() == original_module.state_dict().keys() + + +@pytest.mark.parametrize( + "clip_type,accelerator,precision", + [ + ("norm", "cpu", "32-true"), + ("val", "cpu", "32-true"), + ("norm", "cpu", "bf16-mixed"), + ("val", "cpu", "bf16-mixed"), + pytest.param("norm", "cuda", "32-true", marks=RunIf(min_cuda_gpus=2)), + pytest.param("val", "cuda", "32-true", marks=RunIf(min_cuda_gpus=2)), + pytest.param("norm", "cuda", "16-mixed", marks=RunIf(min_cuda_gpus=2)), + pytest.param("val", "cuda", "16-mixed", marks=RunIf(min_cuda_gpus=2)), + pytest.param("norm", "cuda", "bf16-mixed", marks=RunIf(min_cuda_gpus=2, bf16_cuda=True)), + 
pytest.param("val", "cuda", "bf16-mixed", marks=RunIf(min_cuda_gpus=2, bf16_cuda=True)), + ], +) +@RunIf(standalone=True) +def test_ddp_grad_clipping(clip_type, accelerator, precision): + if clip_type == "norm": + clipping_test_cls = _MyFabricGradNorm + else: + clipping_test_cls = _MyFabricGradVal + fabric = clipping_test_cls(accelerator=accelerator, devices=2, precision=precision, strategy="ddp") + fabric.run() diff --git a/tests/tests_fabric/strategies/test_deepspeed.py b/tests/tests_fabric/strategies/test_deepspeed.py index 471d4903a5..185e206aed 100644 --- a/tests/tests_fabric/strategies/test_deepspeed.py +++ b/tests/tests_fabric/strategies/test_deepspeed.py @@ -321,3 +321,25 @@ def test_deepspeed_load_checkpoint_optimzer_state_requested(optimzer_state_reque load_lr_scheduler_states=False, load_module_strict=True, ) + + +@RunIf(deepspeed=True) +def test_errors_grad_clipping(): + strategy = DeepSpeedStrategy() + with pytest.raises( + NotImplementedError, + match=( + "DeepSpeed handles gradient clipping automatically within the optimizer. " + "Make sure to set the `gradient_clipping` value in your Config." + ), + ): + strategy.clip_gradients_norm(Mock(), Mock(), Mock(), Mock(), Mock()) + + with pytest.raises( + NotImplementedError, + match=( + "DeepSpeed handles gradient clipping automatically within the optimizer. " + "Make sure to set the `gradient_clipping` value in your Config." + ), + ): + strategy.clip_gradients_value(Mock(), Mock(), Mock()) diff --git a/tests/tests_fabric/strategies/test_dp.py b/tests/tests_fabric/strategies/test_dp.py index 1295164534..dfc336f735 100644 --- a/tests/tests_fabric/strategies/test_dp.py +++ b/tests/tests_fabric/strategies/test_dp.py @@ -14,7 +14,10 @@ from unittest import mock from unittest.mock import MagicMock, Mock +import pytest import torch +from tests_fabric.helpers.runif import RunIf +from tests_fabric.strategies.test_single_device import _MyFabricGradNorm, _MyFabricGradVal from lightning.fabric.strategies import DataParallelStrategy @@ -68,3 +71,22 @@ def test_dp_module_state_dict(): with mock.patch("lightning.fabric.strategies.dp.DataParallel", DataParallelMock): wrapped_module = strategy.setup_module(original_module) assert strategy.get_module_state_dict(wrapped_module).keys() == original_module.state_dict().keys() + + +@pytest.mark.parametrize( + "precision", + [ + "32-true", + "16-mixed", + pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True)), + ], +) +@pytest.mark.parametrize("clip_type", ["norm", "val"]) +@RunIf(min_cuda_gpus=2) +def test_dp_grad_clipping(clip_type, precision): + if clip_type == "norm": + clipping_test_cls = _MyFabricGradNorm + else: + clipping_test_cls = _MyFabricGradVal + fabric = clipping_test_cls(accelerator="cuda", devices=2, precision=precision, strategy="dp") + fabric.run() diff --git a/tests/tests_fabric/strategies/test_fsdp.py b/tests/tests_fabric/strategies/test_fsdp.py index 0f228d8632..34e2cc73ad 100644 --- a/tests/tests_fabric/strategies/test_fsdp.py +++ b/tests/tests_fabric/strategies/test_fsdp.py @@ -27,6 +27,7 @@ from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12 if _TORCH_GREATER_EQUAL_1_12: from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision +from tests_fabric.strategies.test_single_device import _MyFabricGradNorm @mock.patch("lightning.fabric.strategies.fsdp._TORCH_GREATER_EQUAL_1_12", False) @@ -131,3 +132,40 @@ def test_fsdp_activation_checkpointing(): ) as ckpt_mock: strategy.setup_module(Model()) 
ckpt_mock.assert_called_with(fsdp_mock(), checkpoint_wrapper_fn=ANY, check_fn=ANY) + + +@RunIf(min_torch="1.13") +def test_fsdp_grad_clipping_value_error(): + strategy = FSDPStrategy() + with pytest.raises( + NotImplementedError, + match=( + "FSDP currently does not support clipping gradients by value. " + "Consider clipping by norm instead or choose another strategy!" + ), + ): + strategy.clip_gradients_value(Mock(), Mock(), Mock()) + + +class _MyFSDPFabricGradientNorm(_MyFabricGradNorm): + def after_backward(self, model, optimizer): + self.clip_gradients(model, optimizer, max_norm=0.05, error_if_nonfinite=True) + + with model._forward_module.summon_full_params(model._forward_module): + parameters = model.parameters() + grad_norm = torch.linalg.vector_norm( + torch.stack([torch.linalg.vector_norm(p.grad.detach(), 2, dtype=torch.float32) for p in parameters]), + 2, + ) + torch.testing.assert_close(grad_norm, torch.tensor(0.05, device=self.device)) + + +@pytest.mark.parametrize( + "precision", + ["32-true", "16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))], +) +@RunIf(min_cuda_gpus=2, standalone=True) +@pytest.mark.xfail(reason="Testing with FSDP is not yet correct") # TODO: Investigate testing with fsdp +def test_fsdp_grad_clipping_norm(precision): + fabric = _MyFSDPFabricGradientNorm(accelerator="cuda", devices=2, precision=precision, strategy="fsdp") + fabric.run() diff --git a/tests/tests_fabric/strategies/test_single_device.py b/tests/tests_fabric/strategies/test_single_device.py index 4353678957..2572dd50c4 100644 --- a/tests/tests_fabric/strategies/test_single_device.py +++ b/tests/tests_fabric/strategies/test_single_device.py @@ -15,8 +15,11 @@ from unittest.mock import Mock import pytest import torch +from tests_fabric.helpers.models import BoringFabric +from tests_fabric.helpers.runif import RunIf from lightning.fabric.strategies import SingleDeviceStrategy +from lightning.fabric.wrappers import _FabricModule, _FabricOptimizer def test_single_device_default_device(): @@ -52,3 +55,98 @@ def test_single_device_module_to_device(): module = Mock(spec=torch.nn.Module) strategy.module_to_device(module) module.to.assert_called_with(strategy.root_device) + + +class _MyFabricGradNorm(BoringFabric): + def after_backward(self, model: _FabricModule, optimizer: _FabricOptimizer): + self.clip_gradients(model, optimizer, max_norm=0.05, error_if_nonfinite=True) + + parameters = model.parameters() + grad_norm = torch.linalg.vector_norm( + torch.stack([torch.linalg.vector_norm(p.grad.detach(), 2, dtype=torch.float32) for p in parameters]), + 2, + ) + torch.testing.assert_close(grad_norm, torch.tensor(0.05, device=self.device)) + + def run(self): + # 10 retries + i = 0 + while True: + try: + super().run() + break + except RuntimeError as e: + # nonfinite grads -> skip and continue + # this may repeat until the scaler finds a factor where overflow is avoided, + # so the while loop should eventually break + # stop after a max of 10 tries + if i > 10 or not str(e).startswith("The total norm"): + raise e + + # unscale was already called by last attempt, + # but no update afterwards since optimizer step was missing. + # Manually update here -> Need to update inf stats first. 
+ scaler = getattr(self._precision, "scaler", None) + if scaler is not None: + scaler._check_inf_per_device(self.optimizer) + scaler.update() + finally: + i += 1 + + +class _MyFabricGradVal(BoringFabric): + def after_backward(self, model, optimizer): + for p in model.parameters(): + if p.grad is not None and (torch.isnan(p.grad).any().item() or torch.isinf(p.grad).any().item()): + raise RuntimeError("Nonfinite grads") + + self.clip_gradients(model, optimizer, clip_val=1e-10) + + parameters = model.parameters() + grad_max_list = [torch.max(p.grad.detach().abs()) for p in parameters] + grad_max = torch.max(torch.stack(grad_max_list)) + torch.testing.assert_close(grad_max.abs(), torch.tensor(1e-10, device=self.device)) + print("done") + + def run(self): + # 10 retries + i = 0 + while True: + try: + super().run() + break + except RuntimeError as e: + # nonfinite grads -> skip and continue + # this may repeat until the scaler finds a factor where overflow is avoided, + # so the while loop should eventually break + # stop after a max of 10 tries + if i > 10 or not str(e).startswith("Nonfinite grads"): + raise e + + # unscale was already called by last attempt, + # but no update afterwards since optimizer step was missing. + # Manually update here -> Need to update inf stats first. + scaler = getattr(self._precision, "scaler", None) + if scaler is not None: + scaler._check_inf_per_device(self.optimizer) + scaler.update() + finally: + i += 1 + + +@pytest.mark.parametrize( + "precision", + [ + "32-true", + pytest.param("16-mixed", marks=RunIf(min_cuda_gpus=1)), + pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True)), + ], +) +@pytest.mark.parametrize("clip_type", ["norm", "val"]) +def test_single_device_grad_clipping(clip_type, precision): + if clip_type == "norm": + clipping_test_cls = _MyFabricGradNorm + else: + clipping_test_cls = _MyFabricGradVal + fabric = clipping_test_cls(accelerator="auto", devices=1, precision=precision) + fabric.run() diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py index a57d9907a8..993183b3f1 100644 --- a/tests/tests_fabric/test_fabric.py +++ b/tests/tests_fabric/test_fabric.py @@ -893,3 +893,33 @@ def test_all_reduce(): # dict fabric.all_reduce({"a": torch.tensor(4), "b": [torch.tensor(5)], "c": "string"}) fabric._strategy.all_reduce.assert_has_calls([call(torch.tensor(4), **defaults), call(torch.tensor(5), **defaults)]) + + +@pytest.mark.parametrize("clip_val,max_norm", [(1e-3, None), (None, 1)]) +def test_grad_clipping(clip_val, max_norm): + fabric = Fabric() + + fabric.strategy.clip_gradients_norm = Mock() + fabric.strategy.clip_gradients_value = Mock() + + torch_model = nn.Linear(1, 1) + torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1e-3) + + model, optimizer = fabric.setup(torch_model, torch_optimizer) + + loss = model(torch.rand(1, 1).to(fabric.device)) + fabric.backward(loss) + + fabric.strategy.clip_gradients_value.assert_not_called() + fabric.strategy.clip_gradients_norm.assert_not_called() + + fabric.clip_gradients(model, optimizer, max_norm=max_norm, clip_val=clip_val) + + if clip_val is not None: + fabric.strategy.clip_gradients_value.assert_called_once_with(torch_model, torch_optimizer, clip_val=clip_val) + fabric.strategy.clip_gradients_norm.assert_not_called() + else: + fabric.strategy.clip_gradients_value.assert_not_called() + fabric.strategy.clip_gradients_norm.assert_called_once_with( + torch_model, torch_optimizer, max_norm=max_norm, norm_type=2.0, error_if_nonfinite=True + )
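
Usage sketch: the `Fabric.clip_gradients` API added in this patch is meant to be called between `fabric.backward()` and `optimizer.step()`, with exactly one of `clip_val` (clip by value) or `max_norm` (clip by norm) set. The snippet below is a minimal, illustrative sketch; the model, optimizer, data, and Fabric settings are placeholders, not part of the patch.

import torch
from lightning.fabric import Fabric

# Placeholder setup; any accelerator/strategy that implements the new clipping hooks behaves the same way.
fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()

model = torch.nn.Linear(32, 2)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)

for _ in range(3):
    x = torch.randn(8, 32, device=fabric.device)  # placeholder data
    loss = model(x).sum()
    fabric.backward(loss)
    # Exactly one of `clip_val` or `max_norm` may be given; passing both, or neither, raises a ValueError.
    fabric.clip_gradients(model, optimizer, max_norm=2.0)
    optimizer.step()
    optimizer.zero_grad()

Under mixed precision, the strategy unscales the gradients through the precision plugin before clipping (see `Strategy.clip_gradients_norm` / `clip_gradients_value` above), so no manual `scaler.unscale_()` call should be needed.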