Adds Gradient Clipping to Fabric (#16715)

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>

This commit is contained in:
parent 8884c8970c
commit 3d1927e6bc
@@ -363,6 +363,33 @@ class Fabric:
        self._precision.backward(tensor, module, *args, **kwargs)

    def clip_gradients(
        self,
        module: Union[torch.nn.Module, _FabricModule],
        optimizer: Union[Optimizer, _FabricOptimizer],
        clip_val: Optional[Union[float, int]] = None,
        max_norm: Optional[Union[float, int]] = None,
        norm_type: Union[float, int] = 2.0,
        error_if_nonfinite: bool = True,
    ) -> Optional[torch.Tensor]:
        if clip_val is not None and max_norm is not None:
            raise ValueError(
                "Only one of `clip_val` or `max_norm` can be set as this specifies the underlying clipping algorithm!"
            )

        if clip_val is not None:
            self.strategy.clip_gradients_value(_unwrap_objects(module), _unwrap_objects(optimizer), clip_val=clip_val)
            return None
        elif max_norm is not None:
            return self.strategy.clip_gradients_norm(
                _unwrap_objects(module),
                _unwrap_objects(optimizer),
                max_norm=max_norm,
                norm_type=norm_type,
                error_if_nonfinite=error_if_nonfinite,
            )
        raise ValueError("You have to specify either `clip_val` or `max_norm` to do gradient clipping!")

    @contextmanager
    def autocast(self) -> Generator[None, None, None]:
        """A context manager to automatically convert operations for the chosen precision.
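For orientation, a minimal usage sketch of the `Fabric.clip_gradients` API added in the hunk above. The toy model, optimizer, and loop scaffolding are illustrative only, not part of this diff:

import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu", devices=1)
module = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(module.parameters(), lr=0.1)
model, optimizer = fabric.setup(module, optimizer)

loss = model(torch.randn(8, 4)).sum()
fabric.backward(loss)

# clip by norm (returns the total gradient norm) ...
fabric.clip_gradients(model, optimizer, max_norm=1.0)
# ... or clip by value; passing both `clip_val` and `max_norm` raises a ValueError
# fabric.clip_gradients(model, optimizer, clip_val=0.5)

optimizer.step()
optimizer.zero_grad()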
@@ -17,7 +17,7 @@ from typing import Any, cast, Dict, Generator, Literal, Optional
import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import LBFGS
from torch.optim import LBFGS, Optimizer

from lightning.fabric.accelerators.cuda import _patch_cuda_is_available
from lightning.fabric.plugins.precision.precision import Precision
@@ -93,3 +93,20 @@ class MixedPrecision(Precision):
        # the dtype could be automatically inferred but we need to manually set it due to a bug upstream
        # https://github.com/pytorch/pytorch/issues/67233
        return torch.autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16-mixed" else torch.half)

    def unscale_gradients(self, optimizer: Optimizer) -> None:
        scaler = self.scaler
        if scaler is not None:
            if _optimizer_handles_unscaling(optimizer):
                raise NotImplementedError("Gradient clipping is not implemented for optimizers handling the unscaling.")
            scaler.unscale_(optimizer)


def _optimizer_handles_unscaling(optimizer: Any) -> bool:
    """Determines whether a PyTorch optimizer handles unscaling gradients in the step method rather than through the
    :class:`torch.cuda.amp.GradScaler`.

    Since the current implementation of this function checks a PyTorch internal variable on the optimizer, the return
    value will only be reliable for built-in PyTorch optimizers.
    """
    return getattr(optimizer, "_step_supports_amp_scaling", False)
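For background on why `unscale_gradients` exists: with `torch.cuda.amp`, gradients come out of backward in scaled form, so any clipping threshold is only meaningful after unscaling. This is the standard upstream PyTorch recipe (not code from this PR; assumes a CUDA device is available):

import torch

model = torch.nn.Linear(4, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

with torch.autocast("cuda", dtype=torch.float16):
    loss = model(torch.randn(8, 4, device="cuda")).sum()
scaler.scale(loss).backward()

scaler.unscale_(optimizer)                               # bring gradients back to real units
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # the clipping threshold now applies to unscaled grads
scaler.step(optimizer)                                   # skips the step if gradients turned non-finite
scaler.update()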
@@ -97,6 +97,9 @@ class Precision:
        for group in optimizer.param_groups:
            yield from group["params"]

    def unscale_gradients(self, optimizer: Optimizer) -> None:
        return

    def state_dict(self) -> Dict[str, Any]:
        """Called when saving a checkpoint, implement to generate precision plugin state_dict.
@@ -486,6 +486,27 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
            state[k] = client_state.pop(k)
        return client_state

    def clip_gradients_norm(
        self,
        module: "deepspeed.DeepSpeedEngine",
        optimizer: Optimizer,
        max_norm: Union[float, int],
        norm_type: Union[float, int] = 2.0,
        error_if_nonfinite: bool = True,
    ) -> torch.Tensor:
        raise NotImplementedError(
            "DeepSpeed handles gradient clipping automatically within the optimizer. "
            "Make sure to set the `gradient_clipping` value in your Config."
        )

    def clip_gradients_value(
        self, module: "deepspeed.DeepSpeedEngine", optimizer: Optimizer, clip_val: Union[float, int]
    ) -> None:
        raise NotImplementedError(
            "DeepSpeed handles gradient clipping automatically within the optimizer. "
            "Make sure to set the `gradient_clipping` value in your Config."
        )

    @classmethod
    def register_strategies(cls, strategy_registry: Dict) -> None:
        strategy_registry.register("deepspeed", cls, description="Default DeepSpeed Strategy")
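The errors above redirect users to DeepSpeed's own clipping. A hypothetical configuration sketch: the `config` keyword and the surrounding values are assumptions, while `gradient_clipping` is the DeepSpeed config field the message refers to:

from lightning.fabric import Fabric
from lightning.fabric.strategies import DeepSpeedStrategy

ds_config = {
    "train_micro_batch_size_per_gpu": 8,
    "gradient_clipping": 1.0,  # DeepSpeed applies the clip inside its optimizer step
}
fabric = Fabric(strategy=DeepSpeedStrategy(config=ds_config), accelerator="cuda", devices=2)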
@@ -37,7 +37,7 @@ from lightning.fabric.utilities.distributed import (
from lightning.fabric.utilities.distributed import group as _group
from lightning.fabric.utilities.distributed import ReduceOp
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_1_13
from lightning.fabric.utilities.rank_zero import rank_zero_only
from lightning.fabric.utilities.rank_zero import rank_zero_only, rank_zero_warn
from lightning.fabric.utilities.seed import reset_seed

if TYPE_CHECKING:
@@ -268,6 +268,29 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
        torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
        return obj[0]

    def clip_gradients_norm(  # type: ignore[override]
        self,
        module: "FullyShardedDataParallel",
        optimizer: Optimizer,
        max_norm: Union[float, int],
        norm_type: Union[float, int] = 2.0,
        error_if_nonfinite: bool = True,
    ) -> Tensor:
        """Clip gradients by norm."""
        rank_zero_warn("Gradient Clipping by Norm is currently experimental for FSDP. Proceed with Caution!")
        self.precision.unscale_gradients(optimizer)
        return module.clip_grad_norm_(max_norm=max_norm, norm_type=norm_type)  # type: ignore[return-value]

    def clip_gradients_value(  # type: ignore[override]
        self, module: "FullyShardedDataParallel", optimizer: Optimizer, clip_val: Union[float, int]
    ) -> None:
        """Clip gradients by value."""
        raise NotImplementedError(
            "FSDP currently does not support to clip gradients by value. "
            "Consider clipping by norm instead or choose another strategy!"
        )

    @classmethod
    def register_strategies(cls, strategy_registry: Dict) -> None:
        if not _TORCH_GREATER_EQUAL_1_12 or not torch.distributed.is_available():
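Read together with the Fabric-level API, the FSDP overrides above mean only norm clipping is routed through, and it is delegated to FSDP itself because gradients are sharded across ranks. A brief illustrative sketch (the setup shown is context only; the commented calls assume a model and optimizer set up via `fabric.setup`):

from lightning.fabric import Fabric

fabric = Fabric(strategy="fsdp", accelerator="cuda", devices=2)
# model, optimizer = fabric.setup(...); fabric.backward(loss)
# fabric.clip_gradients(model, optimizer, max_norm=1.0)  # delegates to FullyShardedDataParallel.clip_grad_norm_ (experimental)
# fabric.clip_gradients(model, optimizer, clip_val=0.5)  # raises NotImplementedError for this strategy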
@@ -305,6 +305,27 @@ class Strategy(ABC):
        self.accelerator.teardown()
        self.checkpoint_io.teardown()

    def clip_gradients_norm(
        self,
        module: torch.nn.Module,
        optimizer: Optimizer,
        max_norm: Union[float, int],
        norm_type: Union[float, int] = 2.0,
        error_if_nonfinite: bool = True,
    ) -> torch.Tensor:
        """Clip gradients by norm."""
        self.precision.unscale_gradients(optimizer)
        parameters = self.precision.main_params(optimizer)
        return torch.nn.utils.clip_grad_norm_(
            parameters, max_norm=max_norm, norm_type=norm_type, error_if_nonfinite=error_if_nonfinite
        )

    def clip_gradients_value(self, module: torch.nn.Module, optimizer: Optimizer, clip_val: Union[float, int]) -> None:
        """Clip gradients by value."""
        self.precision.unscale_gradients(optimizer)
        parameters = self.precision.main_params(optimizer)
        return torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val)

    @classmethod
    def register_strategies(cls, strategy_registry: Dict[str, Any]) -> None:
        pass
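These base implementations are also the hook that strategies can override, as the DeepSpeed and FSDP hunks above do. A hypothetical sketch of a custom strategy layering on the default unscale-then-clip behaviour; the subclass name and the logging are illustrative only, not part of this PR:

from typing import Union

import torch
from torch.optim import Optimizer

from lightning.fabric.strategies import SingleDeviceStrategy


class VerboseClippingStrategy(SingleDeviceStrategy):  # hypothetical subclass for illustration
    def clip_gradients_norm(
        self,
        module: torch.nn.Module,
        optimizer: Optimizer,
        max_norm: Union[float, int],
        norm_type: Union[float, int] = 2.0,
        error_if_nonfinite: bool = True,
    ) -> torch.Tensor:
        # delegate to the default behaviour and log the total norm returned by clip_grad_norm_
        total_norm = super().clip_gradients_norm(module, optimizer, max_norm, norm_type, error_if_nonfinite)
        print(f"total gradient norm before clipping: {total_norm.item():.4f}")
        return total_norm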
@@ -18,6 +18,7 @@ from torch.optim import LBFGS, Optimizer

import lightning.pytorch as pl
from lightning.fabric.accelerators.cuda import _patch_cuda_is_available
from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling
from lightning.fabric.utilities.types import Optimizable
from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
from lightning.pytorch.utilities import GradClipAlgorithmType
@@ -116,13 +117,3 @@ class MixedPrecisionPlugin(PrecisionPlugin):
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        if self.scaler is not None:
            self.scaler.load_state_dict(state_dict)


def _optimizer_handles_unscaling(optimizer: Any) -> bool:
    """Determines whether a PyTorch optimizer handles unscaling gradients in the step method rather than through the
    :class:`torch.cuda.amp.GradScaler`.

    Since the current implementation of this function checks a PyTorch internal variable on the optimizer, the return
    value will only be reliable for built-in PyTorch optimizers.
    """
    return getattr(optimizer, "_step_supports_amp_scaling", False)
@@ -48,7 +48,7 @@ class BoringFabric(Fabric):
        loss = torch.nn.functional.mse_loss(output, torch.ones_like(output))
        return loss

    def after_backward(self, model: Module) -> None:
    def after_backward(self, model: Module, optimizer: Optimizer) -> None:
        pass

    def after_optimizer_step(self, model: Module, optimizer: Optimizer) -> None:
@@ -77,7 +77,7 @@ class BoringFabric(Fabric):
            batch = next(data_iter)
            loss = self.step(model, batch)
            self.backward(loss)
            self.after_backward(model)
            self.after_backward(model, optimizer)
            optimizer.step()
            self.after_optimizer_step(model, optimizer)
            optimizer.zero_grad()
@@ -54,7 +54,7 @@ class MixedPrecisionBoringFabric(BoringFabric):
        loss = torch.nn.functional.mse_loss(output, torch.ones_like(output))
        return loss

    def after_backward(self, model):
    def after_backward(self, model, optimizer):
        assert model.layer.weight.grad.dtype == torch.float32
@@ -45,7 +45,7 @@ class DoublePrecisionBoringFabric(BoringFabric):
        loss = torch.nn.functional.mse_loss(output, torch.ones_like(output))
        return loss

    def after_backward(self, model):
    def after_backward(self, model, optimizer):
        assert model.layer.weight.grad.dtype == torch.float64
@@ -16,6 +16,8 @@ from unittest.mock import MagicMock, Mock

import pytest
import torch
from tests_fabric.helpers.runif import RunIf
from tests_fabric.strategies.test_single_device import _MyFabricGradNorm, _MyFabricGradVal
from torch.nn.parallel import DistributedDataParallel

from lightning.fabric.strategies import DDPStrategy
@@ -99,3 +101,28 @@ def test_ddp_module_state_dict():
    with mock.patch("lightning.fabric.strategies.ddp.DistributedDataParallel", DistributedDataParallelMock):
        wrapped_module = strategy.setup_module(original_module)
    assert strategy.get_module_state_dict(wrapped_module).keys() == original_module.state_dict().keys()


@pytest.mark.parametrize(
    "clip_type,accelerator,precision",
    [
        ("norm", "cpu", "32-true"),
        ("val", "cpu", "32-true"),
        ("norm", "cpu", "bf16-mixed"),
        ("val", "cpu", "bf16-mixed"),
        pytest.param("norm", "cuda", "32-true", marks=RunIf(min_cuda_gpus=2)),
        pytest.param("val", "cuda", "32-true", marks=RunIf(min_cuda_gpus=2)),
        pytest.param("norm", "cuda", "16-mixed", marks=RunIf(min_cuda_gpus=2)),
        pytest.param("val", "cuda", "16-mixed", marks=RunIf(min_cuda_gpus=2)),
        pytest.param("norm", "cuda", "bf16-mixed", marks=RunIf(min_cuda_gpus=2, bf16_cuda=True)),
        pytest.param("val", "cuda", "bf16-mixed", marks=RunIf(min_cuda_gpus=2, bf16_cuda=True)),
    ],
)
@RunIf(standalone=True)
def test_ddp_grad_clipping(clip_type, accelerator, precision):
    if clip_type == "norm":
        clipping_test_cls = _MyFabricGradNorm
    else:
        clipping_test_cls = _MyFabricGradVal
    fabric = clipping_test_cls(accelerator=accelerator, devices=2, precision=precision, strategy="ddp")
    fabric.run()
@@ -321,3 +321,25 @@ def test_deepspeed_load_checkpoint_optimzer_state_requested(optimzer_state_requested
        load_lr_scheduler_states=False,
        load_module_strict=True,
    )


@RunIf(deepspeed=True)
def test_errors_grad_clipping():
    strategy = DeepSpeedStrategy()
    with pytest.raises(
        NotImplementedError,
        match=(
            "DeepSpeed handles gradient clipping automatically within the optimizer. "
            "Make sure to set the `gradient_clipping` value in your Config."
        ),
    ):
        strategy.clip_gradients_norm(Mock(), Mock(), Mock(), Mock(), Mock())

    with pytest.raises(
        NotImplementedError,
        match=(
            "DeepSpeed handles gradient clipping automatically within the optimizer. "
            "Make sure to set the `gradient_clipping` value in your Config."
        ),
    ):
        strategy.clip_gradients_value(Mock(), Mock(), Mock())
@@ -14,7 +14,10 @@
from unittest import mock
from unittest.mock import MagicMock, Mock

import pytest
import torch
from tests_fabric.helpers.runif import RunIf
from tests_fabric.strategies.test_single_device import _MyFabricGradNorm, _MyFabricGradVal

from lightning.fabric.strategies import DataParallelStrategy
@@ -68,3 +71,22 @@ def test_dp_module_state_dict():
    with mock.patch("lightning.fabric.strategies.dp.DataParallel", DataParallelMock):
        wrapped_module = strategy.setup_module(original_module)
    assert strategy.get_module_state_dict(wrapped_module).keys() == original_module.state_dict().keys()


@pytest.mark.parametrize(
    "precision",
    [
        "32-true",
        "16-mixed",
        pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True)),
    ],
)
@pytest.mark.parametrize("clip_type", ["norm", "val"])
@RunIf(min_cuda_gpus=2)
def test_dp_grad_clipping(clip_type, precision):
    if clip_type == "norm":
        clipping_test_cls = _MyFabricGradNorm
    else:
        clipping_test_cls = _MyFabricGradVal
    fabric = clipping_test_cls(accelerator="cuda", devices=2, precision=precision, strategy="dp")
    fabric.run()
@@ -27,6 +27,7 @@ from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12

if _TORCH_GREATER_EQUAL_1_12:
    from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
    from tests_fabric.strategies.test_single_device import _MyFabricGradNorm


@mock.patch("lightning.fabric.strategies.fsdp._TORCH_GREATER_EQUAL_1_12", False)
@@ -131,3 +132,40 @@ def test_fsdp_activation_checkpointing():
    ) as ckpt_mock:
        strategy.setup_module(Model())
    ckpt_mock.assert_called_with(fsdp_mock(), checkpoint_wrapper_fn=ANY, check_fn=ANY)


@RunIf(min_torch="1.13")
def test_fsdp_grad_clipping_value_error():
    strategy = FSDPStrategy()
    with pytest.raises(
        NotImplementedError,
        match=(
            "FSDP currently does not support to clip gradients by value. "
            "Consider clipping by norm instead or choose another strategy!"
        ),
    ):
        strategy.clip_gradients_value(Mock(), Mock(), Mock())


class _MyFSDPFabricGradientNorm(_MyFabricGradNorm):
    def after_backward(self, model, optimizer):
        self.clip_gradients(model, optimizer, max_norm=0.05, error_if_nonfinite=True)

        with model._forward_module.summon_full_params(model._forward_module):
            parameters = model.parameters()
            grad_norm = torch.linalg.vector_norm(
                torch.stack([torch.linalg.vector_norm(p.grad.detach(), 2, dtype=torch.float32) for p in parameters]),
                2,
            )
            torch.testing.assert_close(grad_norm, torch.tensor(0.05, device=self.device))


@pytest.mark.parametrize(
    "precision",
    ["32-true", "16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))],
)
@RunIf(min_cuda_gpus=2, standalone=True)
@pytest.mark.xfail(reason="Testing with FSDP is not yet correct")  # TODO: Investigate testing with fsdp
def test_fsdp_grad_clipping_norm(precision):
    fabric = _MyFSDPFabricGradientNorm(accelerator="cuda", devices=2, precision=precision, strategy="fsdp")
    fabric.run()
@@ -15,8 +15,11 @@ from unittest.mock import Mock

import pytest
import torch
from tests_fabric.helpers.models import BoringFabric
from tests_fabric.helpers.runif import RunIf

from lightning.fabric.strategies import SingleDeviceStrategy
from lightning.fabric.wrappers import _FabricModule, _FabricOptimizer


def test_single_device_default_device():
@@ -52,3 +55,98 @@ def test_single_device_module_to_device():
    module = Mock(spec=torch.nn.Module)
    strategy.module_to_device(module)
    module.to.assert_called_with(strategy.root_device)


class _MyFabricGradNorm(BoringFabric):
    def after_backward(self, model: _FabricModule, optimizer: _FabricOptimizer):
        self.clip_gradients(model, optimizer, max_norm=0.05, error_if_nonfinite=True)

        parameters = model.parameters()
        grad_norm = torch.linalg.vector_norm(
            torch.stack([torch.linalg.vector_norm(p.grad.detach(), 2, dtype=torch.float32) for p in parameters]),
            2,
        )
        torch.testing.assert_close(grad_norm, torch.tensor(0.05, device=self.device))

    def run(self):
        # 10 retries
        i = 0
        while True:
            try:
                super().run()
                break
            except RuntimeError as e:
                # nonfinite grads -> skip and continue
                # this may repeat until the scaler finds a factor where overflow is avoided,
                # so the while loop should eventually break
                # stop after a max of 10 tries
                if i > 10 or not str(e).startswith("The total norm"):
                    raise e

                # unscale was already called by last attempt,
                # but no update afterwards since optimizer step was missing.
                # Manually update here -> Need to update inf stats first.
                scaler = getattr(self._precision, "scaler", None)
                if scaler is not None:
                    scaler._check_inf_per_device(self.optimizer)
                    scaler.update()
            finally:
                i += 1


class _MyFabricGradVal(BoringFabric):
    def after_backward(self, model, optimizer):
        for p in model.parameters():
            if p.grad is not None and (torch.isnan(p.grad).any().item() or torch.isinf(p.grad).any().item()):
                raise RuntimeError("Nonfinite grads")

        self.clip_gradients(model, optimizer, clip_val=1e-10)

        parameters = model.parameters()
        grad_max_list = [torch.max(p.grad.detach().abs()) for p in parameters]
        grad_max = torch.max(torch.stack(grad_max_list))
        torch.testing.assert_close(grad_max.abs(), torch.tensor(1e-10, device=self.device))
        print("done")

    def run(self):
        # 10 retries
        i = 0
        while True:
            try:
                super().run()
                break
            except RuntimeError as e:
                # nonfinite grads -> skip and continue
                # this may repeat until the scaler finds a factor where overflow is avoided,
                # so the while loop should eventually break
                # stop after a max of 10 tries
                if i > 10 or not str(e).startswith("Nonfinite grads"):
                    raise e

                # unscale was already called by last attempt,
                # but no update afterwards since optimizer step was missing.
                # Manually update here -> Need to update inf stats first.
                scaler = getattr(self._precision, "scaler", None)
                if scaler is not None:
                    scaler._check_inf_per_device(self.optimizer)
                    scaler.update()
            finally:
                i += 1


@pytest.mark.parametrize(
    "precision",
    [
        "32-true",
        pytest.param("16-mixed", marks=RunIf(min_cuda_gpus=1)),
        pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True)),
    ],
)
@pytest.mark.parametrize("clip_type", ["norm", "val"])
def test_single_device_grad_clipping(clip_type, precision):
    if clip_type == "norm":
        clipping_test_cls = _MyFabricGradNorm
    else:
        clipping_test_cls = _MyFabricGradVal
    fabric = clipping_test_cls(accelerator="auto", devices=1, precision=precision)
    fabric.run()
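The assertion in `_MyFabricGradNorm` relies on a property of `torch.nn.utils.clip_grad_norm_`: when the pre-clip total norm exceeds `max_norm`, the post-clip total 2-norm equals `max_norm`. A tiny standalone check of that property (not part of the test suite):

import torch

p = torch.nn.Parameter(torch.ones(4))
p.grad = torch.full((4,), 10.0)                  # pre-clip total norm is 20.0, well above the limit
torch.nn.utils.clip_grad_norm_(p, max_norm=0.05)
print(torch.linalg.vector_norm(p.grad, 2))       # ~0.05, the value asserted in the tests above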
@@ -893,3 +893,33 @@ def test_all_reduce():
    # dict
    fabric.all_reduce({"a": torch.tensor(4), "b": [torch.tensor(5)], "c": "string"})
    fabric._strategy.all_reduce.assert_has_calls([call(torch.tensor(4), **defaults), call(torch.tensor(5), **defaults)])


@pytest.mark.parametrize("clip_val,max_norm", [(1e-3, None), (None, 1)])
def test_grad_clipping(clip_val, max_norm):
    fabric = Fabric()

    fabric.strategy.clip_gradients_norm = Mock()
    fabric.strategy.clip_gradients_value = Mock()

    torch_model = nn.Linear(1, 1)
    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1e-3)

    model, optimizer = fabric.setup(torch_model, torch_optimizer)

    loss = model(torch.rand(1, 1).to(fabric.device))
    fabric.backward(loss)

    fabric.strategy.clip_gradients_value.assert_not_called()
    fabric.strategy.clip_gradients_norm.assert_not_called()

    fabric.clip_gradients(model, optimizer, max_norm=max_norm, clip_val=clip_val)

    if clip_val is not None:
        fabric.strategy.clip_gradients_value.assert_called_once_with(torch_model, torch_optimizer, clip_val=clip_val)
        fabric.strategy.clip_gradients_norm.assert_not_called()
    else:
        fabric.strategy.clip_gradients_value.assert_not_called()
        fabric.strategy.clip_gradients_norm.assert_called_once_with(
            torch_model, torch_optimizer, max_norm=max_norm, norm_type=2.0, error_if_nonfinite=True
        )