Handle `set_to_none` when using DeepSpeed optimizer in Lite (#16275)

Adrian Wälchli 2023-01-09 15:01:11 +01:00 committed by GitHub
parent b195b7c116
commit c656307127
3 changed files with 50 additions and 1 deletion


@@ -31,6 +31,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for managing callbacks via `Fabric(callbacks=...)` and emitting events through `Fabric.call()` ([#16074](https://github.com/Lightning-AI/lightning/issues/16074))
 
+- Added support for a consistent `.zero_grad(set_to_none=...)` on the wrapped optimizer regardless of which strategy is used ([#16275](https://github.com/Lightning-AI/lightning/issues/16275))
+
+
 ### Changed
 
 - Renamed the class `LightningLite` to `Fabric` ([#15932](https://github.com/Lightning-AI/lightning/issues/15932), [#15938](https://github.com/Lightning-AI/lightning/issues/15938))
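
For context, a minimal sketch of the behavior this changelog entry describes, assuming the public `Fabric` API from this release (the `lightning.fabric` import path and the `accelerator` argument shown here are illustrative, not part of this diff):

```python
import torch
from lightning.fabric import Fabric  # import path assumed for this release

# Works the same whether or not the strategy replaces the optimizer (e.g. DeepSpeed).
fabric = Fabric(accelerator="cpu")
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)

output = model(torch.randn(2, 4))
fabric.backward(output.sum())
optimizer.step()
# The wrapped optimizer accepts the standard PyTorch argument; if the underlying
# optimizer spells it differently (e.g. `set_grads_to_None`), Fabric translates it.
optimizer.zero_grad(set_to_none=True)
```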


@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 from typing import Any, Callable, Dict, Generator, Iterator, Mapping, Optional, overload, TypeVar, Union
 
 import torch
@@ -44,7 +45,9 @@ class _FabricOptimizer:
         """
         # `__del__` is skipped in case the optimizer has implemented custom destructor logic which we would
         # not want to call on destruction of the `_FabricOptimizer`
-        self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("state_dict", "step", "__del__")}
+        self.__dict__ = {
+            k: v for k, v in optimizer.__dict__.items() if k not in ("state_dict", "step", "zero_grad", "__del__")
+        }
         self.__class__ = type("Fabric" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})
         self._optimizer = optimizer
         self._strategy = strategy
@@ -68,6 +71,10 @@
             **kwargs,
         )
+
+    def zero_grad(self, **kwargs: Any) -> None:
+        kwargs = _process_optimizer_zero_grad_kwargs(self.optimizer, kwargs)
+        self.optimizer.zero_grad(**kwargs)
 
 
 class _FabricModule(_DeviceDtypeModuleMixin):
     def __init__(
@@ -175,3 +182,10 @@ class _FabricDataLoader:
 
         for item in iterator:
             yield move_data_to_device(item, self._device)
+
+
+def _process_optimizer_zero_grad_kwargs(optimizer: Optimizer, kwargs: Dict[str, Any]) -> Dict[str, Any]:
+    if "set_to_none" in kwargs and "set_grads_to_None" in inspect.signature(optimizer.zero_grad).parameters:
+        # Some optimizers out there, for example DeepSpeedZeroOptimizer, use a different name than PyTorch
+        kwargs["set_grads_to_None"] = kwargs.pop("set_to_none")
+    return kwargs
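
To illustrate the mechanism used by `_process_optimizer_zero_grad_kwargs`, here is a self-contained sketch of the same signature inspection outside of Lightning. `LegacyZeroGradOptimizer` and `translate_zero_grad_kwargs` are hypothetical stand-ins for this illustration; the former mimics optimizers such as DeepSpeed's ZeRO optimizer, whose `zero_grad` historically took `set_grads_to_None` instead of PyTorch's `set_to_none`:

```python
import inspect

import torch


class LegacyZeroGradOptimizer(torch.optim.SGD):
    # Hypothetical stand-in: spells PyTorch's `set_to_none` flag as `set_grads_to_None`,
    # like some third-party optimizers (e.g. older DeepSpeed ZeRO optimizers).
    def zero_grad(self, set_grads_to_None: bool = False) -> None:
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                if set_grads_to_None:
                    p.grad = None
                else:
                    p.grad.detach_()
                    p.grad.zero_()


def translate_zero_grad_kwargs(optimizer, kwargs):
    # Same idea as `_process_optimizer_zero_grad_kwargs` above: rename the keyword
    # argument when the wrapped optimizer's `zero_grad` uses the non-standard spelling.
    if "set_to_none" in kwargs and "set_grads_to_None" in inspect.signature(optimizer.zero_grad).parameters:
        kwargs["set_grads_to_None"] = kwargs.pop("set_to_none")
    return kwargs


optimizer = LegacyZeroGradOptimizer(torch.nn.Linear(1, 1).parameters(), lr=0.1)
print(translate_zero_grad_kwargs(optimizer, {"set_to_none": True}))  # {'set_grads_to_None': True}
print(translate_zero_grad_kwargs(optimizer, {}))                     # {}
```

Note that the constructor change above also adds `"zero_grad"` to the keys excluded when copying the wrapped optimizer's `__dict__`, presumably so that an instance-level `zero_grad` attribute on the wrapped optimizer cannot shadow the wrapper's own `zero_grad` method.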


@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from unittest import mock
 from unittest.mock import call, Mock
 
 import pytest
@@ -291,3 +292,34 @@ def test_lite_optimizer_steps():
     lite_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=strategy)
     lite_optimizer.step()
     strategy.optimizer_step.assert_called_once_with(strategy.model)
+
+
+def test_fabric_optimizer_zero_grad_kwargs():
+    """Test that Fabric can adapt the `.zero_grad()` arguments to the underlying optimizer."""
+
+    # Test PyTorch's standard `.zero_grad()` signature
+    with mock.patch("torch.optim.SGD.zero_grad") as zero_grad_mock:
+        optimizer = torch.optim.SGD(torch.nn.Linear(1, 1).parameters(), 0.1)
+        fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=Mock())
+        fabric_optimizer.zero_grad()
+        zero_grad_mock.assert_called_with()
+        fabric_optimizer.zero_grad(set_to_none=False)
+        zero_grad_mock.assert_called_with(set_to_none=False)
+        fabric_optimizer.zero_grad(set_to_none=True)
+        zero_grad_mock.assert_called_with(set_to_none=True)
+
+    # Test weird `.zero_grad()` signatures from other libraries
+    custom_zero_grad = Mock()
+
+    class CustomSGD(torch.optim.SGD):
+        def zero_grad(self, set_grads_to_None=False):
+            custom_zero_grad(set_grads_to_None=set_grads_to_None)
+
+    optimizer = CustomSGD(torch.nn.Linear(1, 1).parameters(), 0.1)
+    fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=Mock())
+    fabric_optimizer.zero_grad()
+    custom_zero_grad.assert_called_with(set_grads_to_None=False)
+    fabric_optimizer.zero_grad(set_to_none=False)
+    custom_zero_grad.assert_called_with(set_grads_to_None=False)
+    fabric_optimizer.zero_grad(set_to_none=True)
+    custom_zero_grad.assert_called_with(set_grads_to_None=True)