Handle `set_to_none` when using DeepSpeed optimizer in Lite (#16275)
parent b195b7c116 · commit c656307127
@@ -31,6 +31,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for managing callbacks via `Fabric(callbacks=...)` and emitting events through `Fabric.call()` ([#16074](https://github.com/Lightning-AI/lightning/issues/16074))
 
+- Added support for a consistent `.zero_grad(set_to_none=...)` on the wrapped optimizer regardless of which strategy is used ([#16275](https://github.com/Lightning-AI/lightning/issues/16275))
+
+
 ### Changed
 
 - Renamed the class `LightningLite` to `Fabric` ([#15932](https://github.com/Lightning-AI/lightning/issues/15932), [#15938](https://github.com/Lightning-AI/lightning/issues/15938))
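For context, a minimal usage sketch of what the new changelog entry above describes. This is not part of the diff; the `Fabric` import path and setup calls are assumed from the public Fabric API:

```python
# Sketch only: assumes a recent Lightning install exposing `lightning.fabric`.
import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu")  # or e.g. strategy="deepspeed" on supported hardware
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)

loss = model(torch.randn(2, 4)).sum()
fabric.backward(loss)
optimizer.step()
# The same call works regardless of which argument name the wrapped optimizer uses internally:
optimizer.zero_grad(set_to_none=True)
```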
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 from typing import Any, Callable, Dict, Generator, Iterator, Mapping, Optional, overload, TypeVar, Union
 
 import torch
@@ -44,7 +45,9 @@ class _FabricOptimizer:
         """
         # `__del__` is skipped in case the optimizer has implemented custom destructor logic which we would
         # not want to call on destruction of the `_FabricOptimizer`
-        self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("state_dict", "step", "__del__")}
+        self.__dict__ = {
+            k: v for k, v in optimizer.__dict__.items() if k not in ("state_dict", "step", "zero_grad", "__del__")
+        }
         self.__class__ = type("Fabric" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})
         self._optimizer = optimizer
         self._strategy = strategy
@@ -68,6 +71,10 @@ class _FabricOptimizer:
             **kwargs,
         )
 
+    def zero_grad(self, **kwargs: Any) -> None:
+        kwargs = _process_optimizer_zero_grad_kwargs(self.optimizer, kwargs)
+        self.optimizer.zero_grad(**kwargs)
+
 
 class _FabricModule(_DeviceDtypeModuleMixin):
     def __init__(
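A note on why `zero_grad` is now excluded from the copied `__dict__` in the hunk above: anything copied into the wrapper's instance `__dict__` would shadow the method of the same name defined on `_FabricOptimizer`, because plain functions on a class are non-data descriptors. A minimal standalone illustration of that lookup rule (not Lightning code):

```python
class Wrapper:
    def zero_grad(self) -> str:
        return "wrapper zero_grad"


w = Wrapper()
print(w.zero_grad())  # "wrapper zero_grad"

# If a same-named callable ends up in the instance __dict__ (as it could when copying
# attributes from a wrapped optimizer), it wins the lookup and the class method is bypassed:
w.__dict__["zero_grad"] = lambda: "shadowed by instance attribute"
print(w.zero_grad())  # "shadowed by instance attribute"
```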
@@ -175,3 +182,10 @@ class _FabricDataLoader:
 
         for item in iterator:
             yield move_data_to_device(item, self._device)
+
+
+def _process_optimizer_zero_grad_kwargs(optimizer: Optimizer, kwargs: Dict[str, Any]) -> Dict[str, Any]:
+    if "set_to_none" in kwargs and "set_grads_to_None" in inspect.signature(optimizer.zero_grad).parameters:
+        # Some optimizers out there, for example DeepSpeedZeroOptimizer, use a different name than PyTorch
+        kwargs["set_grads_to_None"] = kwargs.pop("set_to_none")
+    return kwargs
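To make the translation rule in `_process_optimizer_zero_grad_kwargs` concrete, here is a small standalone sketch of how `inspect.signature` reveals the non-standard parameter name. `FakeDeepSpeedStyleOptimizer` is a hypothetical stand-in for illustration, not the real DeepSpeed class:

```python
import inspect

import torch


class FakeDeepSpeedStyleOptimizer(torch.optim.SGD):
    # Hypothetical optimizer mimicking DeepSpeedZeroOptimizer's argument name.
    def zero_grad(self, set_grads_to_None=False):
        super().zero_grad(set_to_none=set_grads_to_None)


optimizer = FakeDeepSpeedStyleOptimizer(torch.nn.Linear(1, 1).parameters(), lr=0.1)
params = inspect.signature(optimizer.zero_grad).parameters
print("set_grads_to_None" in params)  # True  -> Fabric renames `set_to_none` before the call
print("set_to_none" in params)        # False -> passing the standard keyword would raise a TypeError
```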
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from unittest import mock
 from unittest.mock import call, Mock
 
 import pytest
@@ -291,3 +292,34 @@ def test_lite_optimizer_steps():
     lite_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=strategy)
     lite_optimizer.step()
     strategy.optimizer_step.assert_called_once_with(strategy.model)
+
+
+def test_fabric_optimizer_zero_grad_kwargs():
+    """Test that Fabric can adapt the `.zero_grad()` arguments to the underlying optimizer."""
+
+    # Test PyTorch's standard `.zero_grad()` signature
+    with mock.patch("torch.optim.SGD.zero_grad") as zero_grad_mock:
+        optimizer = torch.optim.SGD(torch.nn.Linear(1, 1).parameters(), 0.1)
+        fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=Mock())
+        fabric_optimizer.zero_grad()
+        zero_grad_mock.assert_called_with()
+        fabric_optimizer.zero_grad(set_to_none=False)
+        zero_grad_mock.assert_called_with(set_to_none=False)
+        fabric_optimizer.zero_grad(set_to_none=True)
+        zero_grad_mock.assert_called_with(set_to_none=True)
+
+    # Test weird `.zero_grad()` signatures from other libraries
+    custom_zero_grad = Mock()
+
+    class CustomSGD(torch.optim.SGD):
+        def zero_grad(self, set_grads_to_None=False):
+            custom_zero_grad(set_grads_to_None=set_grads_to_None)
+
+    optimizer = CustomSGD(torch.nn.Linear(1, 1).parameters(), 0.1)
+    fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=Mock())
+    fabric_optimizer.zero_grad()
+    custom_zero_grad.assert_called_with(set_grads_to_None=False)
+    fabric_optimizer.zero_grad(set_to_none=False)
+    custom_zero_grad.assert_called_with(set_grads_to_None=False)
+    fabric_optimizer.zero_grad(set_to_none=True)
+    custom_zero_grad.assert_called_with(set_grads_to_None=True)