Meta device initialization for FSDP in Fabric (#18122)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Parent: c5cb532694
Commit: 50e01c7012
@@ -40,7 +40,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
   * Handles initialization for FSDP models before wrapping and the Zero stage 3 initialization for DeepSpeed before sharding
 
-- Added supports for empty weight initialization with `Fabric.init_module(empty_init=True)` for checkpoint loading ([#17627](https://github.com/Lightning-AI/lightning/pull/17627))
+- Added support for empty weight initialization with `Fabric.init_module(empty_init=True)` for checkpoint loading ([#17627](https://github.com/Lightning-AI/lightning/pull/17627))
+
+- Added support for meta-device initialization with `Fabric.init_module(empty_init=True)` in FSDP ([#18122](https://github.com/Lightning-AI/lightning/pull/18122))
 
 - Added `lightning.fabric.plugins.Precision.init_context()` and `lightning.fabric.strategies.Strategy.module_init_context()` context managers to control model and tensor instantiation ([#17462](https://github.com/Lightning-AI/lightning/pull/17462))
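For reference, the user-facing pattern these entries describe looks roughly as follows. This is a sketch rather than code taken from the commit; it assumes the FSDP strategy on CUDA with PyTorch 2.1 or newer and uses a plain `torch.nn.Linear` as the model, mirroring the tests further down.

    import torch
    from lightning.fabric import Fabric

    fabric = Fabric(accelerator="cuda", devices=2, strategy="fsdp")
    fabric.launch()

    # With `empty_init=True` on torch >= 2.1, parameters are created on the meta device,
    # so no memory is allocated while the model is instantiated.
    with fabric.init_module(empty_init=True):
        model = torch.nn.Linear(100, 100, bias=False)

    # Parameters are materialized, sharded, and moved to the target device in `setup()`.
    model = fabric.setup(model)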
@@ -41,6 +41,7 @@ from lightning.fabric.strategies import (
     Strategy,
     XLAStrategy,
 )
+from lightning.fabric.strategies.fsdp import _has_meta_device_parameters
 from lightning.fabric.strategies.launchers import _MultiProcessingLauncher, _XLALauncher
 from lightning.fabric.strategies.strategy import _Sharded, TBroadcast
 from lightning.fabric.utilities import move_data_to_device
@@ -224,8 +225,9 @@ class Fabric:
 
         module = _FabricModule(module, self._precision, original_module=original_module)
 
-        # Update the _DeviceDtypeModuleMixin's device parameter
-        module.to(self.device if move_to_device else next(module.parameters(), torch.tensor(0)).device)
+        if not _has_meta_device_parameters(module):
+            # Update the _DeviceDtypeModuleMixin's device parameter
+            module.to(self.device if move_to_device else next(module.parameters(), torch.tensor(0)).device)
 
         optimizers = [
             _FabricOptimizer(optimizer=optimizer, strategy=self._strategy, callbacks=self._callbacks)
@@ -383,7 +385,7 @@ class Fabric:
         if isinstance(self._strategy, DeepSpeedStrategy):
             if model is None:
                 if self._models_setup == 0:
-                    raise RuntimeError("No models were set up for backward. Did you forget to call `self.setup()`?")
+                    raise RuntimeError("No models were set up for backward. Did you forget to call `fabric.setup()`?")
                 if self._models_setup > 1:
                     raise ValueError(
                         "When using multiple models + deepspeed, please provide the model used to perform"
@@ -588,14 +590,14 @@ class Fabric:
         Example::
 
             # Accumulate gradient 8 batches at a time
-            with self.no_backward_sync(model, enabled=(batch_idx % 8 != 0)):
+            with fabric.no_backward_sync(model, enabled=(batch_idx % 8 != 0)):
                output = model(input)
                loss = ...
-               self.backward(loss)
+               fabric.backward(loss)
                ...
 
         For those strategies that don't support it, a warning is emitted. For single-device strategies, it is a no-op.
-        Both the model's `.forward()` and the `self.backward()` call need to run under this context.
+        Both the model's `.forward()` and the `fabric.backward()` call need to run under this context.
 
         Args:
             module: The module for which to control the gradient synchronization.
@@ -605,8 +607,8 @@ class Fabric:
         module = _unwrap_compiled(module)
         if not isinstance(module, _FabricModule):
             raise TypeError(
-                "You need to set up the model first before you can call `self.no_backward_sync()`:"
-                " `model = self.setup(model, ...)`"
+                "You need to set up the model first before you can call `fabric.no_backward_sync()`:"
+                " `model = fabric.setup(model, ...)`"
             )
         if not enabled or isinstance(self._strategy, (SingleDeviceStrategy, XLAStrategy)):
             context = nullcontext()
@@ -956,12 +958,20 @@ class Fabric:
         if any(isinstance(opt, _FabricOptimizer) for opt in optimizers):
             raise ValueError("An optimizer should be passed only once to the `setup` method.")
 
-        if isinstance(self._strategy, FSDPStrategy) and not _TORCH_GREATER_EQUAL_2_0:
-            raise RuntimeError(
-                f"The `{type(self).__name__}` requires the model and optimizer(s) to be set up separately."
-                " Create and set up the model first through `model = self.setup_module(model)`. Then create the"
-                " optimizer and set it up: `optimizer = self.setup_optimizer(optimizer)`."
-            )
+        if isinstance(self._strategy, FSDPStrategy):
+            if not _TORCH_GREATER_EQUAL_2_0:
+                raise RuntimeError(
+                    f"The `{type(self).__name__}` requires the model and optimizer(s) to be set up separately."
+                    " Create and set up the model first through `model = fabric.setup_module(model)`. Then create the"
+                    " optimizer and set it up: `optimizer = fabric.setup_optimizers(optimizer)`."
+                )
+            if any(_has_meta_device_parameters(optimizer) for optimizer in optimizers):
+                raise RuntimeError(
+                    "The optimizer has references to the model's meta-device parameters. Materializing them"
+                    " is currently not supported unless you set up the model and optimizer(s) separately."
+                    " Create and set up the model first through `model = fabric.setup_module(model)`. Then create the"
+                    " optimizer and set it up: `optimizer = fabric.setup_optimizers(optimizer)`."
+                )
 
     def _validate_setup_module(self, module: nn.Module) -> None:
         self._validate_launched()
@@ -982,6 +992,13 @@ class Fabric:
         if any(isinstance(opt, _FabricOptimizer) for opt in optimizers):
             raise ValueError("An optimizer should be passed only once to the `setup_optimizers` method.")
 
+        if any(_has_meta_device_parameters(optimizer) for optimizer in optimizers):
+            raise RuntimeError(
+                "The optimizer has references to the model's meta-device parameters. Materializing them"
+                " is currently not supported. Create the optimizer after setting up the model, then call"
+                " `fabric.setup_optimizers(optimizer)`."
+            )
+
     def _validate_setup_dataloaders(self, dataloaders: Sequence[DataLoader]) -> None:
         self._validate_launched()
         if not dataloaders:
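These checks enforce an ordering constraint: an optimizer created from meta-device parameters cannot be set up, because its parameter references become stale once FSDP materializes and shards the model. A sketch of the sequence the error messages recommend, again assuming the FSDP strategy and a plain linear model:

    import torch
    from lightning.fabric import Fabric

    fabric = Fabric(accelerator="cuda", devices=2, strategy="fsdp")
    fabric.launch()

    with fabric.init_module(empty_init=True):
        model = torch.nn.Linear(100, 100)  # parameters live on the meta device for now

    # Set up the model first: FSDP materializes and shards the parameters here.
    model = fabric.setup_module(model)

    # Only now create the optimizer, so it references the materialized parameters.
    optimizer = torch.optim.Adam(model.parameters())
    optimizer = fabric.setup_optimizers(optimizer)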
@@ -35,7 +35,7 @@ from typing import (
 
 import torch
 from torch import Tensor
-from torch.nn import Module
+from torch.nn import Module, Parameter
 from torch.optim import Optimizer
 from typing_extensions import TypeGuard
 
@@ -264,6 +264,11 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
         from torch.distributed.fsdp import FullyShardedDataParallel
 
         if any(isinstance(mod, FullyShardedDataParallel) for mod in module.modules()):
             # The user has wrapped their submodules manually, don't apply the auto wrap policy.
+            if _has_meta_device_parameters(module):
+                rank_zero_warn(
+                    "The model is already wrapped in `FSDP` but there are still parameters on the meta device."
+                )
             if "auto_wrap_policy" in self._fsdp_kwargs:
                 rank_zero_warn(
                     "A FSDP `auto_wrap_policy` is set, but the model is already wrapped. The policy will be ignored."
@@ -317,9 +322,16 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
 
     @contextmanager
     def module_init_context(self, empty_init: Optional[bool] = None) -> Generator[None, None, None]:
-        # TODO: Use the meta device and reset parameters after https://github.com/pytorch/pytorch/issues/90465
-        # is resolved. For now, the module will get moved to the device in `setup_module`.
-        empty_init_context = _EmptyInit(enabled=bool(empty_init)) if _TORCH_GREATER_EQUAL_1_13 else nullcontext()
+        empty_init_context: Union[torch.device, _EmptyInit, nullcontext]
+        if _TORCH_GREATER_EQUAL_2_1 and empty_init:
+            # Materialization happens in `setup`. When modules get wrapped by FSDP, the sequence of operations is:
+            # 1) materialize module 2) call `reset_parameters()` 3) shard the module.
+            # These operations are applied to each submodule 'bottom up' in the module hierarchy.
+            empty_init_context = torch.device("meta")
+        elif _TORCH_GREATER_EQUAL_1_13:
+            empty_init_context = _EmptyInit(enabled=bool(empty_init))
+        else:
+            empty_init_context = nullcontext()
         with empty_init_context, self.precision.init_context(), self.module_sharded_context():
             yield
 
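The comment in the new branch refers to the sequence FSDP applies to each submodule: materialize, then `reset_parameters()`, then shard. A simplified standalone illustration of the first two steps, shown here for orientation and not taken from FSDP's internals (requires PyTorch 2.0+ for the `torch.device` context manager):

    import torch
    from torch import nn

    # Allocate parameters on the meta device: shapes and dtypes exist, but no storage.
    with torch.device("meta"):
        layer = nn.Linear(4, 4)
    assert layer.weight.is_meta

    # 1) Materialize: allocate uninitialized storage on a real device.
    layer = layer.to_empty(device="cpu")
    # 2) Re-run the module's default initialization on the materialized storage.
    layer.reset_parameters()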
@@ -841,6 +853,16 @@ def _load_raw_module_state(state_dict: Dict[str, Any], module: Module, strict: b
     module.load_state_dict(state_dict, strict=strict)
 
 
+def _has_meta_device_parameters(obj: Union[Module, Optimizer]) -> bool:
+    if isinstance(obj, Optimizer):
+        return any(
+            t.is_meta for param_group in obj.param_groups for t in param_group["params"] if isinstance(t, Parameter)
+        )
+    if isinstance(obj, Module):
+        return any(t.is_meta for t in obj.parameters())
+    raise TypeError(f"Expected `torch.nn.Module` or `torch.optim.Optimizer`, got: {type(obj).__name__}")
+
+
 def _no_op() -> None:
     pass
@@ -29,7 +29,11 @@ import lightning.fabric
 from lightning.fabric import Fabric
 from lightning.fabric.plugins.environments import LightningEnvironment
 from lightning.fabric.strategies import FSDPStrategy
-from lightning.fabric.strategies.fsdp import _FSDPBackwardSyncControl, fsdp_overlap_step_with_backward
+from lightning.fabric.strategies.fsdp import (
+    _FSDPBackwardSyncControl,
+    _has_meta_device_parameters,
+    fsdp_overlap_step_with_backward,
+)
 from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_2_1
 from tests_fabric.helpers.runif import RunIf
 from tests_fabric.strategies.test_single_device import _MyFabricGradNorm
@@ -395,6 +399,25 @@ def test_set_timeout(init_process_group_mock):
     )
 
 
+def test_has_meta_device_parameters():
+    """Test that the `_has_meta_device_parameters` function can find meta-device parameters in models and
+    optimizers."""
+    # nn.Module
+    module = nn.Linear(2, 2)
+    meta_module = nn.Linear(2, 2, device="meta")
+    assert not _has_meta_device_parameters(module)
+    assert _has_meta_device_parameters(meta_module)
+    assert _has_meta_device_parameters(nn.Sequential(module, meta_module, nn.ReLU()))
+    # optim.Optimizer
+    optimizer = torch.optim.SGD(module.parameters(), lr=0.1)
+    meta_optimizer = torch.optim.SGD(meta_module.parameters(), lr=0.1)
+    assert not _has_meta_device_parameters(optimizer)
+    assert _has_meta_device_parameters(meta_optimizer)
+    # unsupported objects
+    with pytest.raises(TypeError, match="Expected `torch.nn.Module` or `torch.optim.Optimizer`"):
+        _has_meta_device_parameters(None)
+
+
 class SubBlock(nn.Sequential):
     def __init__(self, feature_dim: int) -> None:
         super().__init__(
@@ -23,7 +23,11 @@ from torch.nn import Parameter
 from lightning.fabric import Fabric
 from lightning.fabric.plugins import FSDPPrecision
 from lightning.fabric.strategies import FSDPStrategy
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_2_0
+from lightning.fabric.utilities.imports import (
+    _TORCH_GREATER_EQUAL_1_12,
+    _TORCH_GREATER_EQUAL_2_0,
+    _TORCH_GREATER_EQUAL_2_1,
+)
 from lightning.fabric.wrappers import _FabricOptimizer
 from tests_fabric.helpers.models import BoringFabric
 from tests_fabric.helpers.runif import RunIf
@@ -393,21 +397,27 @@ def test_module_init_context(precision, expected_dtype):
     )
     fabric.launch()
 
-    with fabric.init_module():
-        model = torch.nn.Linear(100, 100, bias=False)
+    def _run_setup_assertions(empty_init, expected_device):
+        with fabric.init_module(empty_init=empty_init):
+            model = torch.nn.Linear(100, 100, bias=False)
 
-    # The model is on the CPU until after `.setup()`
-    # TODO: Support initialization on meta device
-    expected_device = torch.device("cpu")
-    assert model.weight.device == expected_device
-    assert model.weight.dtype == expected_dtype
+        # The model is on the CPU/meta-device until after `.setup()`
+        assert model.weight.device == expected_device
+        assert model.weight.dtype == expected_dtype
+        model = fabric.setup(model)
+        # Parameters get sharded in `.setup()` and moved to the target device
+        assert model.weight.device == torch.device("cuda", fabric.local_rank)
+        assert model.weight.dtype == expected_dtype
 
-    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
-    model, optimizer = fabric.setup(model, optimizer)
+    # Case 1: No empty init
+    _run_setup_assertions(empty_init=False, expected_device=torch.device("cpu"))
 
-    # Parameters get sharded in `.setup()` and moved to the target device
-    assert model.weight.device == torch.device("cuda", fabric.local_rank)
-    assert model.weight.dtype == expected_dtype
+    if _TORCH_GREATER_EQUAL_2_1:
+        # Case 2: Empty-init with PyTorch >= 2.1 supports meta device
+        _run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))
+    else:
+        # Case 2: Empty-init with PyTorch < 2.1 only supports `torch.empty()`-init
+        _run_setup_assertions(empty_init=True, expected_device=torch.device("cpu"))
 
 
 @RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
@@ -460,7 +470,7 @@ def test_fsdp_manual_activation_checkpointing():
 
 
 @RunIf(min_torch="1.12", min_cuda_gpus=1)
-def test_rewrap_warning():
+def test_rewrap_warnings():
     from torch.distributed.fsdp import FullyShardedDataParallel
     from torch.distributed.fsdp.wrap import wrap
 
@@ -473,3 +483,13 @@ def test_rewrap_warning():
     model = fabric.setup(model)
     assert not isinstance(model._forward_module, FullyShardedDataParallel)
     assert isinstance(model._forward_module[2], FullyShardedDataParallel)
+
+    if not _TORCH_GREATER_EQUAL_2_1:
+        return
+
+    with fabric.init_module(empty_init=True):
+        model = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.ReLU(), wrap(torch.nn.Linear(1, 1)))
+    assert model[0].weight.is_meta
+    with pytest.warns(match="there are still parameters on the meta device"):
+        fabric_model = fabric.setup(model)
+    assert next(fabric_model.parameters()).is_meta
@@ -273,6 +273,22 @@ def test_setup_optimizers_not_supported(strategy_cls):
         fabric.setup_optimizers(optimizer)
 
 
+@RunIf(min_cuda_gpus=1, min_torch="2.1")
+def test_setup_optimizer_on_meta_device():
+    """Test that the setup-methods validate that the optimizer doesn't have references to meta-device
+    parameters."""
+    fabric = Fabric(strategy="fsdp", devices=1)
+    fabric._launched = True  # pretend we have launched multiple processes
+    with fabric.init_module(empty_init=True):
+        model = nn.Linear(1, 2)
+    assert model.weight.is_meta
+    optimizer = torch.optim.Adam(model.parameters())  # optimizer references meta device params
+    with pytest.raises(RuntimeError, match="The optimizer has references to the model's meta-device parameters"):
+        fabric.setup(model, optimizer)
+    with pytest.raises(RuntimeError, match="The optimizer has references to the model's meta-device parameters"):
+        fabric.setup_optimizers(optimizer)
+
+
 def test_setup_tracks_num_models():
     """Test that setup() tracks how many times it has setup a model."""
     fabric = Fabric(devices=1)