Meta device initialization for FSDP in Fabric (#18122)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Adrian Wälchli 2023-08-02 13:58:32 +02:00 committed by GitHub
parent c5cb532694
commit 50e01c7012
6 changed files with 135 additions and 34 deletions


@ -40,7 +40,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Handles initialization for FSDP models before wrapping and the Zero stage 3 initialization for DeepSpeed before sharding
- Added supports for empty weight initialization with `Fabric.init_module(empty_init=True)` for checkpoint loading ([#17627](https://github.com/Lightning-AI/lightning/pull/17627))
- Added support for empty weight initialization with `Fabric.init_module(empty_init=True)` for checkpoint loading ([#17627](https://github.com/Lightning-AI/lightning/pull/17627))
- Added support for meta-device initialization with `Fabric.init_module(empty_init=True)` in FSDP ([#18122](https://github.com/Lightning-AI/lightning/pull/18122))
- Added `lightning.fabric.plugins.Precision.init_context()` and `lightning.fabric.strategies.Strategy.module_init_context()` context managers to control model and tensor instantiation ([#17462](https://github.com/Lightning-AI/lightning/pull/17462))
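Below is a hedged usage sketch of the meta-device initialization described in the changelog entries above (the model, sizes, and device count are illustrative; it assumes PyTorch >= 2.1, CUDA, and the FSDP strategy). With `empty_init=True`, parameters stay on the meta device until `setup_module()` materializes and shards them, which is also why the optimizer is created only afterwards:

import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cuda", devices=2, strategy="fsdp")
fabric.launch()

with fabric.init_module(empty_init=True):
    # Parameters are created on the meta device; no host or GPU memory is allocated yet
    model = torch.nn.Linear(1024, 1024)

# Materializes the meta-device parameters, shards them, and moves them to the GPU
model = fabric.setup_module(model)

# Create the optimizer only after setup so it references the real (sharded) parameters
optimizer = torch.optim.Adam(model.parameters())
optimizer = fabric.setup_optimizers(optimizer)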


@ -41,6 +41,7 @@ from lightning.fabric.strategies import (
Strategy,
XLAStrategy,
)
from lightning.fabric.strategies.fsdp import _has_meta_device_parameters
from lightning.fabric.strategies.launchers import _MultiProcessingLauncher, _XLALauncher
from lightning.fabric.strategies.strategy import _Sharded, TBroadcast
from lightning.fabric.utilities import move_data_to_device
@ -224,8 +225,9 @@ class Fabric:
module = _FabricModule(module, self._precision, original_module=original_module)
# Update the _DeviceDtypeModuleMixin's device parameter
module.to(self.device if move_to_device else next(module.parameters(), torch.tensor(0)).device)
if not _has_meta_device_parameters(module):
# Update the _DeviceDtypeModuleMixin's device parameter
module.to(self.device if move_to_device else next(module.parameters(), torch.tensor(0)).device)
optimizers = [
_FabricOptimizer(optimizer=optimizer, strategy=self._strategy, callbacks=self._callbacks)
@ -383,7 +385,7 @@ class Fabric:
if isinstance(self._strategy, DeepSpeedStrategy):
if model is None:
if self._models_setup == 0:
raise RuntimeError("No models were set up for backward. Did you forget to call `self.setup()`?")
raise RuntimeError("No models were set up for backward. Did you forget to call `fabric.setup()`?")
if self._models_setup > 1:
raise ValueError(
"When using multiple models + deepspeed, please provide the model used to perform"
@ -588,14 +590,14 @@ class Fabric:
Example::
# Accumulate gradient 8 batches at a time
with self.no_backward_sync(model, enabled=(batch_idx % 8 != 0)):
with fabric.no_backward_sync(model, enabled=(batch_idx % 8 != 0)):
output = model(input)
loss = ...
self.backward(loss)
fabric.backward(loss)
...
For those strategies that don't support it, a warning is emitted. For single-device strategies, it is a no-op.
Both the model's `.forward()` and the `self.backward()` call need to run under this context.
Both the model's `.forward()` and the `fabric.backward()` call need to run under this context.
Args:
module: The module for which to control the gradient synchronization.
@ -605,8 +607,8 @@ class Fabric:
module = _unwrap_compiled(module)
if not isinstance(module, _FabricModule):
raise TypeError(
"You need to set up the model first before you can call `self.no_backward_sync()`:"
" `model = self.setup(model, ...)`"
"You need to set up the model first before you can call `fabric.no_backward_sync()`:"
" `model = fabric.setup(model, ...)`"
)
if not enabled or isinstance(self._strategy, (SingleDeviceStrategy, XLAStrategy)):
context = nullcontext()
@ -956,12 +958,20 @@ class Fabric:
if any(isinstance(opt, _FabricOptimizer) for opt in optimizers):
raise ValueError("An optimizer should be passed only once to the `setup` method.")
if isinstance(self._strategy, FSDPStrategy) and not _TORCH_GREATER_EQUAL_2_0:
raise RuntimeError(
f"The `{type(self).__name__}` requires the model and optimizer(s) to be set up separately."
" Create and set up the model first through `model = self.setup_module(model)`. Then create the"
" optimizer and set it up: `optimizer = self.setup_optimizer(optimizer)`."
)
if isinstance(self._strategy, FSDPStrategy):
if not _TORCH_GREATER_EQUAL_2_0:
raise RuntimeError(
f"The `{type(self).__name__}` requires the model and optimizer(s) to be set up separately."
" Create and set up the model first through `model = fabric.setup_module(model)`. Then create the"
" optimizer and set it up: `optimizer = fabric.setup_optimizers(optimizer)`."
)
if any(_has_meta_device_parameters(optimizer) for optimizer in optimizers):
raise RuntimeError(
"The optimizer has references to the model's meta-device parameters. Materializing them is"
" is currently not supported unless you to set up the model and optimizer(s) separately."
" Create and set up the model first through `model = fabric.setup_module(model)`. Then create the"
" optimizer and set it up: `optimizer = fabric.setup_optimizers(optimizer)`."
)
def _validate_setup_module(self, module: nn.Module) -> None:
self._validate_launched()
@ -982,6 +992,13 @@ class Fabric:
if any(isinstance(opt, _FabricOptimizer) for opt in optimizers):
raise ValueError("An optimizer should be passed only once to the `setup_optimizers` method.")
if any(_has_meta_device_parameters(optimizer) for optimizer in optimizers):
raise RuntimeError(
"The optimizer has references to the model's meta-device parameters. Materializing them is"
" is currently not supported. Create the optimizer after setting up the model, then call"
" `fabric.setup_optimizers(optimizer)`."
)
def _validate_setup_dataloaders(self, dataloaders: Sequence[DataLoader]) -> None:
self._validate_launched()
if not dataloaders:


@ -35,7 +35,7 @@ from typing import (
import torch
from torch import Tensor
from torch.nn import Module
from torch.nn import Module, Parameter
from torch.optim import Optimizer
from typing_extensions import TypeGuard
@ -264,6 +264,11 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
from torch.distributed.fsdp import FullyShardedDataParallel
if any(isinstance(mod, FullyShardedDataParallel) for mod in module.modules()):
# The user has wrapped their submodules manually, don't apply the auto wrap policy.
if _has_meta_device_parameters(module):
rank_zero_warn(
"The model is already wrapped in `FSDP` but there are still parameters on the meta device."
)
if "auto_wrap_policy" in self._fsdp_kwargs:
rank_zero_warn(
"A FSDP `auto_wrap_policy` is set, but the model is already wrapped. The policy will be ignored."
@ -317,9 +322,16 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
@contextmanager
def module_init_context(self, empty_init: Optional[bool] = None) -> Generator[None, None, None]:
# TODO: Use the meta device and reset parameters after https://github.com/pytorch/pytorch/issues/90465
# is resolved. For now, the module will get moved to the device in `setup_module`.
empty_init_context = _EmptyInit(enabled=bool(empty_init)) if _TORCH_GREATER_EQUAL_1_13 else nullcontext()
empty_init_context: Union[torch.device, _EmptyInit, nullcontext]
if _TORCH_GREATER_EQUAL_2_1 and empty_init:
# Materialization happens in `setup`. When modules get wrapped by FSDP, the sequence of operations is:
# 1) materialize module 2) call `reset_parameters()` 3) shard the module.
# These operations are applied to each submodule 'bottom up' in the module hierarchy.
empty_init_context = torch.device("meta")
elif _TORCH_GREATER_EQUAL_1_13:
empty_init_context = _EmptyInit(enabled=bool(empty_init))
else:
empty_init_context = nullcontext()
with empty_init_context, self.precision.init_context(), self.module_sharded_context():
yield
@ -841,6 +853,16 @@ def _load_raw_module_state(state_dict: Dict[str, Any], module: Module, strict: b
module.load_state_dict(state_dict, strict=strict)
def _has_meta_device_parameters(obj: Union[Module, Optimizer]) -> bool:
if isinstance(obj, Optimizer):
return any(
t.is_meta for param_group in obj.param_groups for t in param_group["params"] if isinstance(t, Parameter)
)
if isinstance(obj, Module):
return any(t.is_meta for t in obj.parameters())
raise TypeError(f"Expected `torch.nn.Module` or `torch.optim.Optimizer`, got: {type(obj).__name__}")
def _no_op() -> None:
pass
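For context on the ordering comment in `module_init_context` above (materialize, call `reset_parameters()`, then shard), here is a hedged plain-PyTorch sketch of the mechanism the meta-device path relies on. It assumes the script is launched with `torchrun`, CUDA is available, and PyTorch >= 2.1; the layer sizes are illustrative:

import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

with torch.device("meta"):
    # No memory is allocated here; every parameter lives on the meta device
    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.ReLU(), torch.nn.Linear(1024, 1024))

# FSDP materializes the module (empty storage + `reset_parameters()`) and then shards it
sharded = FSDP(model, device_id=torch.cuda.current_device())
assert not any(p.is_meta for p in sharded.parameters())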


@ -29,7 +29,11 @@ import lightning.fabric
from lightning.fabric import Fabric
from lightning.fabric.plugins.environments import LightningEnvironment
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.strategies.fsdp import _FSDPBackwardSyncControl, fsdp_overlap_step_with_backward
from lightning.fabric.strategies.fsdp import (
_FSDPBackwardSyncControl,
_has_meta_device_parameters,
fsdp_overlap_step_with_backward,
)
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_2_1
from tests_fabric.helpers.runif import RunIf
from tests_fabric.strategies.test_single_device import _MyFabricGradNorm
@ -395,6 +399,25 @@ def test_set_timeout(init_process_group_mock):
)
def test_has_meta_device_parameters():
"""Test that the `_has_meta_device_parameters` function can find meta-device parameters in models and
optimizers."""
# nn.Module
module = nn.Linear(2, 2)
meta_module = nn.Linear(2, 2, device="meta")
assert not _has_meta_device_parameters(module)
assert _has_meta_device_parameters(meta_module)
assert _has_meta_device_parameters(nn.Sequential(module, meta_module, nn.ReLU()))
# optim.Optimizer
optimizer = torch.optim.SGD(module.parameters(), lr=0.1)
meta_optimizer = torch.optim.SGD(meta_module.parameters(), lr=0.1)
assert not _has_meta_device_parameters(optimizer)
assert _has_meta_device_parameters(meta_optimizer)
# unsupported objects
with pytest.raises(TypeError, match="Expected `torch.nn.Module` or `torch.optim.Optimizer`"):
_has_meta_device_parameters(None)
class SubBlock(nn.Sequential):
def __init__(self, feature_dim: int) -> None:
super().__init__(


@ -23,7 +23,11 @@ from torch.nn import Parameter
from lightning.fabric import Fabric
from lightning.fabric.plugins import FSDPPrecision
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.imports import (
_TORCH_GREATER_EQUAL_1_12,
_TORCH_GREATER_EQUAL_2_0,
_TORCH_GREATER_EQUAL_2_1,
)
from lightning.fabric.wrappers import _FabricOptimizer
from tests_fabric.helpers.models import BoringFabric
from tests_fabric.helpers.runif import RunIf
@ -393,21 +397,27 @@ def test_module_init_context(precision, expected_dtype):
)
fabric.launch()
with fabric.init_module():
model = torch.nn.Linear(100, 100, bias=False)
def _run_setup_assertions(empty_init, expected_device):
with fabric.init_module(empty_init=empty_init):
model = torch.nn.Linear(100, 100, bias=False)
# The model is on the CPU until after `.setup()``
# TODO: Support initialization on meta device
expected_device = torch.device("cpu")
assert model.weight.device == expected_device
assert model.weight.dtype == expected_dtype
# The model is on the CPU/meta-device until after `.setup()`
assert model.weight.device == expected_device
assert model.weight.dtype == expected_dtype
model = fabric.setup(model)
# Parameters get sharded in `.setup()` and moved to the target device
assert model.weight.device == torch.device("cuda", fabric.local_rank)
assert model.weight.dtype == expected_dtype
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)
# Case 1: No empty init
_run_setup_assertions(empty_init=False, expected_device=torch.device("cpu"))
# Parameters get sharded in `.setup()` and moved to the target device
assert model.weight.device == torch.device("cuda", fabric.local_rank)
assert model.weight.dtype == expected_dtype
if _TORCH_GREATER_EQUAL_2_1:
# Case 2: Empty-init with PyTorch >= 2.1 supports meta device
_run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))
else:
# Case 2: Empty-init with PyTorch < 2.1 only supports `torch.empty()`-init
_run_setup_assertions(empty_init=True, expected_device=torch.device("cpu"))
@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
@ -460,7 +470,7 @@ def test_fsdp_manual_activation_checkpointing():
@RunIf(min_torch="1.12", min_cuda_gpus=1)
def test_rewrap_warning():
def test_rewrap_warnings():
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.distributed.fsdp.wrap import wrap
@ -473,3 +483,13 @@ def test_rewrap_warning():
model = fabric.setup(model)
assert not isinstance(model._forward_module, FullyShardedDataParallel)
assert isinstance(model._forward_module[2], FullyShardedDataParallel)
if not _TORCH_GREATER_EQUAL_2_1:
return
with fabric.init_module(empty_init=True):
model = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.ReLU(), wrap(torch.nn.Linear(1, 1)))
assert model[0].weight.is_meta
with pytest.warns(match="there are still parameters on the meta device"):
fabric_model = fabric.setup(model)
assert next(fabric_model.parameters()).is_meta


@ -273,6 +273,22 @@ def test_setup_optimizers_not_supported(strategy_cls):
fabric.setup_optimizers(optimizer)
@RunIf(min_cuda_gpus=1, min_torch="2.1")
def test_setup_optimizer_on_meta_device():
"""Test that the setup-methods validate that the optimizer doesn't have references to meta-device
parameters."""
fabric = Fabric(strategy="fsdp", devices=1)
fabric._launched = True # pretend we have launched multiple processes
with fabric.init_module(empty_init=True):
model = nn.Linear(1, 2)
assert model.weight.is_meta
optimizer = torch.optim.Adam(model.parameters()) # optimizer references meta device params
with pytest.raises(RuntimeError, match="The optimizer has references to the model's meta-device parameters"):
fabric.setup(model, optimizer)
with pytest.raises(RuntimeError, match="The optimizer has references to the model's meta-device parameters"):
fabric.setup_optimizers(optimizer)
def test_setup_tracks_num_models():
"""Test that setup() tracks how many times it has setup a model."""
fabric = Fabric(devices=1)