diff --git a/src/lightning_fabric/CHANGELOG.md b/src/lightning_fabric/CHANGELOG.md
index 7312f08c29..afb900294f 100644
--- a/src/lightning_fabric/CHANGELOG.md
+++ b/src/lightning_fabric/CHANGELOG.md
@@ -31,6 +31,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for managing callbacks via `Fabric(callbacks=...)` and emitting events through `Fabric.call()` ([#16074](https://github.com/Lightning-AI/lightning/issues/16074))
 
+- Added support for a consistent `.zero_grad(set_to_none=...)` on the wrapped optimizer regardless of which strategy is used ([#16275](https://github.com/Lightning-AI/lightning/issues/16275))
+
+
 ### Changed
 
 - Renamed the class `LightningLite` to `Fabric` ([#15932](https://github.com/Lightning-AI/lightning/issues/15932), [#15938](https://github.com/Lightning-AI/lightning/issues/15938))
diff --git a/src/lightning_fabric/wrappers.py b/src/lightning_fabric/wrappers.py
index 05c144fdc9..75ea47fbf6 100644
--- a/src/lightning_fabric/wrappers.py
+++ b/src/lightning_fabric/wrappers.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 from typing import Any, Callable, Dict, Generator, Iterator, Mapping, Optional, overload, TypeVar, Union
 
 import torch
@@ -44,7 +45,9 @@ class _FabricOptimizer:
         """
         # `__del__` is skipped in case the optimizer has implemented custom destructor logic which we would
         # not want to call on destruction of the `_FabricOptimizer
-        self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("state_dict", "step", "__del__")}
+        self.__dict__ = {
+            k: v for k, v in optimizer.__dict__.items() if k not in ("state_dict", "step", "zero_grad", "__del__")
+        }
         self.__class__ = type("Fabric" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})
         self._optimizer = optimizer
         self._strategy = strategy
@@ -68,6 +71,10 @@ class _FabricOptimizer:
             **kwargs,
         )
 
+    def zero_grad(self, **kwargs: Any) -> None:
+        kwargs = _process_optimizer_zero_grad_kwargs(self.optimizer, kwargs)
+        self.optimizer.zero_grad(**kwargs)
+
 
 class _FabricModule(_DeviceDtypeModuleMixin):
     def __init__(
@@ -175,3 +182,10 @@ class _FabricDataLoader:
 
         for item in iterator:
             yield move_data_to_device(item, self._device)
+
+
+def _process_optimizer_zero_grad_kwargs(optimizer: Optimizer, kwargs: Dict[str, Any]) -> Dict[str, Any]:
+    if "set_to_none" in kwargs and "set_grads_to_None" in inspect.signature(optimizer.zero_grad).parameters:
+        # Some optimizers out there, for example DeepSpeedZeroOptimizer, use a different name than PyTorch
+        kwargs["set_grads_to_None"] = kwargs.pop("set_to_none")
+    return kwargs
diff --git a/tests/tests_fabric/test_wrappers.py b/tests/tests_fabric/test_wrappers.py
index 7230681f2c..5a4dd5827c 100644
--- a/tests/tests_fabric/test_wrappers.py
+++ b/tests/tests_fabric/test_wrappers.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from unittest import mock
 from unittest.mock import call, Mock
 
 import pytest
@@ -291,3 +292,34 @@ def test_lite_optimizer_steps():
     lite_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=strategy)
     lite_optimizer.step()
     strategy.optimizer_step.assert_called_once_with(strategy.model)
+
+
+def test_fabric_optimizer_zero_grad_kwargs():
+    """Test that Fabric can adapt the `.zero_grad()` arguments to the underlying optimizer."""
+
+    # Test PyTorch's standard `.zero_grad()` signature
+    with mock.patch("torch.optim.SGD.zero_grad") as zero_grad_mock:
+        optimizer = torch.optim.SGD(torch.nn.Linear(1, 1).parameters(), 0.1)
+        fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=Mock())
+        fabric_optimizer.zero_grad()
+        zero_grad_mock.assert_called_with()
+        fabric_optimizer.zero_grad(set_to_none=False)
+        zero_grad_mock.assert_called_with(set_to_none=False)
+        fabric_optimizer.zero_grad(set_to_none=True)
+        zero_grad_mock.assert_called_with(set_to_none=True)
+
+    # Test weird `.zero_grad()` signatures from other libraries
+    custom_zero_grad = Mock()
+
+    class CustomSGD(torch.optim.SGD):
+        def zero_grad(self, set_grads_to_None=False):
+            custom_zero_grad(set_grads_to_None=set_grads_to_None)
+
+    optimizer = CustomSGD(torch.nn.Linear(1, 1).parameters(), 0.1)
+    fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=Mock())
+    fabric_optimizer.zero_grad()
+    custom_zero_grad.assert_called_with(set_grads_to_None=False)
+    fabric_optimizer.zero_grad(set_to_none=False)
+    custom_zero_grad.assert_called_with(set_grads_to_None=False)
+    fabric_optimizer.zero_grad(set_to_none=True)
+    custom_zero_grad.assert_called_with(set_grads_to_None=True)
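
Usage note (a minimal sketch, not part of the patch): this is roughly how the unified `zero_grad(set_to_none=...)` call looks from user code after the change, assuming `from lightning_fabric import Fabric` and a single-device CPU run. The wrapped optimizer returned by `fabric.setup()` forwards the keyword, and `_process_optimizer_zero_grad_kwargs` renames it for optimizers that spell it differently (e.g. DeepSpeed's `set_grads_to_None`).

import torch
from lightning_fabric import Fabric

fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# setup() returns the _FabricModule / _FabricOptimizer wrappers from wrappers.py
model, optimizer = fabric.setup(model, optimizer)

batch = torch.randn(8, 4, device=fabric.device)
loss = model(batch).sum()
fabric.backward(loss)
optimizer.step()
# The same call works regardless of strategy; the wrapper translates `set_to_none`
# when the underlying optimizer uses a different argument name.
optimizer.zero_grad(set_to_none=True)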