From 50e01c7012f675b33e30152453b23a9fdd4ba786 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 2 Aug 2023 13:58:32 +0200
Subject: [PATCH] Meta device initialization for FSDP in Fabric (#18122)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 src/lightning/fabric/CHANGELOG.md           |  5 +-
 src/lightning/fabric/fabric.py              | 45 +++++++++++------
 src/lightning/fabric/strategies/fsdp.py     | 30 ++++++++++--
 tests/tests_fabric/strategies/test_fsdp.py  | 25 +++++++++-
 .../strategies/test_fsdp_integration.py     | 48 +++++++++++++------
 tests/tests_fabric/test_fabric.py           | 16 +++++++
 6 files changed, 135 insertions(+), 34 deletions(-)

diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index 155d0dd852..07f851b17e 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -40,7 +40,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
   * Handles initialization for FSDP models before wrapping and the Zero stage 3 initialization for DeepSpeed before sharding
 
-- Added supports for empty weight initialization with `Fabric.init_module(empty_init=True)` for checkpoint loading ([#17627](https://github.com/Lightning-AI/lightning/pull/17627))
+- Added support for empty weight initialization with `Fabric.init_module(empty_init=True)` for checkpoint loading ([#17627](https://github.com/Lightning-AI/lightning/pull/17627))
+
+
+- Added support for meta-device initialization with `Fabric.init_module(empty_init=True)` in FSDP ([#18122](https://github.com/Lightning-AI/lightning/pull/18122))
 
 - Added `lightning.fabric.plugins.Precision.init_context()` and `lightning.fabric.strategies.Strategy.module_init_context()` context managers to control model and tensor instantiation ([#17462](https://github.com/Lightning-AI/lightning/pull/17462))
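For context on the two CHANGELOG entries above, the user-facing flow they describe looks roughly like the sketch below. The layer sizes, optimizer, and checkpoint path are placeholders and not part of this patch; the point is that with FSDP and `empty_init=True`, parameters can now be created on the meta device and are only materialized, initialized, and sharded when the module is set up.

```python
import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cuda", devices=2, strategy="fsdp")
fabric.launch()

# On PyTorch >= 2.1, `empty_init=True` with FSDP places parameters on the meta device,
# so even very large models can be instantiated without allocating real memory.
with fabric.init_module(empty_init=True):
    model = torch.nn.Linear(4096, 4096)  # placeholder model

# Materialization, `reset_parameters()`, and sharding all happen during setup.
model = fabric.setup_module(model)

# The optimizer is created afterwards so that it references the materialized, sharded parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
optimizer = fabric.setup_optimizers(optimizer)

# Typical use case: the freshly initialized weights are then overwritten from a checkpoint.
fabric.load("path/to/checkpoint.ckpt", {"model": model, "optimizer": optimizer})
```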
diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py
index 2044f12e5d..1fe9479f67 100644
--- a/src/lightning/fabric/fabric.py
+++ b/src/lightning/fabric/fabric.py
@@ -41,6 +41,7 @@ from lightning.fabric.strategies import (
     Strategy,
     XLAStrategy,
 )
+from lightning.fabric.strategies.fsdp import _has_meta_device_parameters
 from lightning.fabric.strategies.launchers import _MultiProcessingLauncher, _XLALauncher
 from lightning.fabric.strategies.strategy import _Sharded, TBroadcast
 from lightning.fabric.utilities import move_data_to_device
@@ -224,8 +225,9 @@ class Fabric:
 
         module = _FabricModule(module, self._precision, original_module=original_module)
 
-        # Update the _DeviceDtypeModuleMixin's device parameter
-        module.to(self.device if move_to_device else next(module.parameters(), torch.tensor(0)).device)
+        if not _has_meta_device_parameters(module):
+            # Update the _DeviceDtypeModuleMixin's device parameter
+            module.to(self.device if move_to_device else next(module.parameters(), torch.tensor(0)).device)
 
         optimizers = [
             _FabricOptimizer(optimizer=optimizer, strategy=self._strategy, callbacks=self._callbacks)
@@ -383,7 +385,7 @@ class Fabric:
         if isinstance(self._strategy, DeepSpeedStrategy):
             if model is None:
                 if self._models_setup == 0:
-                    raise RuntimeError("No models were set up for backward. Did you forget to call `self.setup()`?")
+                    raise RuntimeError("No models were set up for backward. Did you forget to call `fabric.setup()`?")
                 if self._models_setup > 1:
                     raise ValueError(
                         "When using multiple models + deepspeed, please provide the model used to perform"
@@ -588,14 +590,14 @@ class Fabric:
         Example::

            # Accumulate gradient 8 batches at a time
-            with self.no_backward_sync(model, enabled=(batch_idx % 8 != 0)):
+            with fabric.no_backward_sync(model, enabled=(batch_idx % 8 != 0)):
                 output = model(input)
                 loss = ...
-                self.backward(loss)
+                fabric.backward(loss)
                 ...
 
         For those strategies that don't support it, a warning is emitted. For single-device strategies, it is a no-op.
-        Both the model's `.forward()` and the `self.backward()` call need to run under this context.
+        Both the model's `.forward()` and the `fabric.backward()` call need to run under this context.
 
         Args:
             module: The module for which to control the gradient synchronization.
@@ -605,8 +607,8 @@ class Fabric:
         module = _unwrap_compiled(module)
         if not isinstance(module, _FabricModule):
             raise TypeError(
-                "You need to set up the model first before you can call `self.no_backward_sync()`:"
-                " `model = self.setup(model, ...)`"
+                "You need to set up the model first before you can call `fabric.no_backward_sync()`:"
+                " `model = fabric.setup(model, ...)`"
             )
         if not enabled or isinstance(self._strategy, (SingleDeviceStrategy, XLAStrategy)):
             context = nullcontext()
@@ -956,12 +958,20 @@ class Fabric:
         if any(isinstance(opt, _FabricOptimizer) for opt in optimizers):
             raise ValueError("An optimizer should be passed only once to the `setup` method.")
 
-        if isinstance(self._strategy, FSDPStrategy) and not _TORCH_GREATER_EQUAL_2_0:
-            raise RuntimeError(
-                f"The `{type(self).__name__}` requires the model and optimizer(s) to be set up separately."
-                " Create and set up the model first through `model = self.setup_module(model)`. Then create the"
-                " optimizer and set it up: `optimizer = self.setup_optimizer(optimizer)`."
-            )
+        if isinstance(self._strategy, FSDPStrategy):
+            if not _TORCH_GREATER_EQUAL_2_0:
+                raise RuntimeError(
+                    f"The `{type(self).__name__}` requires the model and optimizer(s) to be set up separately."
+                    " Create and set up the model first through `model = fabric.setup_module(model)`. Then create the"
+                    " optimizer and set it up: `optimizer = fabric.setup_optimizers(optimizer)`."
+                )
+            if any(_has_meta_device_parameters(optimizer) for optimizer in optimizers):
+                raise RuntimeError(
+                    "The optimizer has references to the model's meta-device parameters. Materializing them is"
+                    " currently not supported unless you set up the model and optimizer(s) separately."
+                    " Create and set up the model first through `model = fabric.setup_module(model)`. Then create the"
+                    " optimizer and set it up: `optimizer = fabric.setup_optimizers(optimizer)`."
+                )
 
     def _validate_setup_module(self, module: nn.Module) -> None:
         self._validate_launched()
@@ -982,6 +992,13 @@ class Fabric:
         if any(isinstance(opt, _FabricOptimizer) for opt in optimizers):
             raise ValueError("An optimizer should be passed only once to the `setup_optimizers` method.")
 
+        if any(_has_meta_device_parameters(optimizer) for optimizer in optimizers):
+            raise RuntimeError(
+                "The optimizer has references to the model's meta-device parameters. Materializing them is"
+                " currently not supported. Create the optimizer after setting up the model, then call"
+                " `fabric.setup_optimizers(optimizer)`."
+            )
+
     def _validate_setup_dataloaders(self, dataloaders: Sequence[DataLoader]) -> None:
         self._validate_launched()
         if not dataloaders:
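Both validation errors added above guard the same ordering problem: an optimizer created while the parameters are still on the meta device cannot be materialized through a joint `setup()` call. A rough illustration with a placeholder model and hyperparameters; the commented-out lines show the rejected order:

```python
import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cuda", devices=2, strategy="fsdp")
fabric.launch()

with fabric.init_module(empty_init=True):
    model = torch.nn.Linear(8, 8)  # parameters live on the meta device here

# Rejected: the optimizer would hold references to meta-device parameters, so both
# `fabric.setup(model, optimizer)` and `fabric.setup_optimizers(optimizer)` raise a RuntimeError.
# optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# model, optimizer = fabric.setup(model, optimizer)

# Supported: set up the model first (this materializes and shards the parameters),
# then create and set up the optimizer against the real parameters.
model = fabric.setup_module(model)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer = fabric.setup_optimizers(optimizer)
```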
diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py
index 51c847cbe5..975c7e554a 100644
--- a/src/lightning/fabric/strategies/fsdp.py
+++ b/src/lightning/fabric/strategies/fsdp.py
@@ -35,7 +35,7 @@ from typing import (
 
 import torch
 from torch import Tensor
-from torch.nn import Module
+from torch.nn import Module, Parameter
 from torch.optim import Optimizer
 from typing_extensions import TypeGuard
 
@@ -264,6 +264,11 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
         from torch.distributed.fsdp import FullyShardedDataParallel
 
         if any(isinstance(mod, FullyShardedDataParallel) for mod in module.modules()):
+            # The user has wrapped their submodules manually, don't apply the auto wrap policy.
+            if _has_meta_device_parameters(module):
+                rank_zero_warn(
+                    "The model is already wrapped in `FSDP` but there are still parameters on the meta device."
+                )
             if "auto_wrap_policy" in self._fsdp_kwargs:
                 rank_zero_warn(
                     "A FSDP `auto_wrap_policy` is set, but the model is already wrapped. The policy will be ignored."
@@ -317,9 +322,16 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
 
     @contextmanager
     def module_init_context(self, empty_init: Optional[bool] = None) -> Generator[None, None, None]:
-        # TODO: Use the meta device and reset parameters after https://github.com/pytorch/pytorch/issues/90465
-        # is resolved. For now, the module will get moved to the device in `setup_module`.
-        empty_init_context = _EmptyInit(enabled=bool(empty_init)) if _TORCH_GREATER_EQUAL_1_13 else nullcontext()
+        empty_init_context: Union[torch.device, _EmptyInit, nullcontext]
+        if _TORCH_GREATER_EQUAL_2_1 and empty_init:
+            # Materialization happens in `setup`. When modules get wrapped by FSDP, the sequence of operations is:
+            # 1) materialize module 2) call `reset_parameters()` 3) shard the module.
+            # These operations are applied to each submodule 'bottom up' in the module hierarchy.
+            empty_init_context = torch.device("meta")
+        elif _TORCH_GREATER_EQUAL_1_13:
+            empty_init_context = _EmptyInit(enabled=bool(empty_init))
+        else:
+            empty_init_context = nullcontext()
         with empty_init_context, self.precision.init_context(), self.module_sharded_context():
             yield
 
@@ -841,6 +853,16 @@ def _load_raw_module_state(state_dict: Dict[str, Any], module: Module, strict: b
         module.load_state_dict(state_dict, strict=strict)
 
 
+def _has_meta_device_parameters(obj: Union[Module, Optimizer]) -> bool:
+    if isinstance(obj, Optimizer):
+        return any(
+            t.is_meta for param_group in obj.param_groups for t in param_group["params"] if isinstance(t, Parameter)
+        )
+    if isinstance(obj, Module):
+        return any(t.is_meta for t in obj.parameters())
+    raise TypeError(f"Expected `torch.nn.Module` or `torch.optim.Optimizer`, got: {type(obj).__name__}")
+
+
 def _no_op() -> None:
     pass
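The comment added in `module_init_context` summarizes what FSDP does on PyTorch >= 2.1 when it meets meta-device parameters during wrapping: materialize, call `reset_parameters()`, then shard. The first two steps can also be spelled out by hand with public PyTorch APIs; the following sketch is independent of this patch and uses a plain CPU target with no sharding:

```python
import torch
from torch import nn

# Create the module on the meta device: no storage is allocated, and the usual weight
# initialization carries no data because operations on meta tensors are shape-only.
with torch.device("meta"):
    module = nn.Linear(16, 16)
assert module.weight.is_meta

# Step 1: materialize, i.e. allocate uninitialized storage on the target device.
module.to_empty(device="cpu")

# Step 2: run the module's own initialization to fill that storage.
module.reset_parameters()
assert not module.weight.is_meta

# Step 3 (done by FSDP, not shown here): flatten and shard the now-initialized parameters.
```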
diff --git a/tests/tests_fabric/strategies/test_fsdp.py b/tests/tests_fabric/strategies/test_fsdp.py
index d2838044c3..dd782d636f 100644
--- a/tests/tests_fabric/strategies/test_fsdp.py
+++ b/tests/tests_fabric/strategies/test_fsdp.py
@@ -29,7 +29,11 @@ import lightning.fabric
 from lightning.fabric import Fabric
 from lightning.fabric.plugins.environments import LightningEnvironment
 from lightning.fabric.strategies import FSDPStrategy
-from lightning.fabric.strategies.fsdp import _FSDPBackwardSyncControl, fsdp_overlap_step_with_backward
+from lightning.fabric.strategies.fsdp import (
+    _FSDPBackwardSyncControl,
+    _has_meta_device_parameters,
+    fsdp_overlap_step_with_backward,
+)
 from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_2_1
 from tests_fabric.helpers.runif import RunIf
 from tests_fabric.strategies.test_single_device import _MyFabricGradNorm
@@ -395,6 +399,25 @@ def test_set_timeout(init_process_group_mock):
     )
 
 
+def test_has_meta_device_parameters():
+    """Test that the `_has_meta_device_parameters` function can find meta-device parameters in models and
+    optimizers."""
+    # nn.Module
+    module = nn.Linear(2, 2)
+    meta_module = nn.Linear(2, 2, device="meta")
+    assert not _has_meta_device_parameters(module)
+    assert _has_meta_device_parameters(meta_module)
+    assert _has_meta_device_parameters(nn.Sequential(module, meta_module, nn.ReLU()))
+    # optim.Optimizer
+    optimizer = torch.optim.SGD(module.parameters(), lr=0.1)
+    meta_optimizer = torch.optim.SGD(meta_module.parameters(), lr=0.1)
+    assert not _has_meta_device_parameters(optimizer)
+    assert _has_meta_device_parameters(meta_optimizer)
+    # unsupported objects
+    with pytest.raises(TypeError, match="Expected `torch.nn.Module` or `torch.optim.Optimizer`"):
+        _has_meta_device_parameters(None)
+
+
 class SubBlock(nn.Sequential):
     def __init__(self, feature_dim: int) -> None:
         super().__init__(
diff --git a/tests/tests_fabric/strategies/test_fsdp_integration.py b/tests/tests_fabric/strategies/test_fsdp_integration.py
index 3a9b0e3ad4..b366e4c3c6 100644
--- a/tests/tests_fabric/strategies/test_fsdp_integration.py
+++ b/tests/tests_fabric/strategies/test_fsdp_integration.py
@@ -23,7 +23,11 @@ from torch.nn import Parameter
 from lightning.fabric import Fabric
 from lightning.fabric.plugins import FSDPPrecision
 from lightning.fabric.strategies import FSDPStrategy
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_2_0
+from lightning.fabric.utilities.imports import (
+    _TORCH_GREATER_EQUAL_1_12,
+    _TORCH_GREATER_EQUAL_2_0,
+    _TORCH_GREATER_EQUAL_2_1,
+)
 from lightning.fabric.wrappers import _FabricOptimizer
 from tests_fabric.helpers.models import BoringFabric
 from tests_fabric.helpers.runif import RunIf
@@ -393,21 +397,27 @@ def test_module_init_context(precision, expected_dtype):
     )
     fabric.launch()
 
-    with fabric.init_module():
-        model = torch.nn.Linear(100, 100, bias=False)
+    def _run_setup_assertions(empty_init, expected_device):
+        with fabric.init_module(empty_init=empty_init):
+            model = torch.nn.Linear(100, 100, bias=False)
 
-    # The model is on the CPU until after `.setup()``
-    # TODO: Support initialization on meta device
-    expected_device = torch.device("cpu")
-    assert model.weight.device == expected_device
-    assert model.weight.dtype == expected_dtype
+        # The model is on the CPU/meta-device until after `.setup()`
+        assert model.weight.device == expected_device
+        assert model.weight.dtype == expected_dtype
+        model = fabric.setup(model)
+        # Parameters get sharded in `.setup()` and moved to the target device
+        assert model.weight.device == torch.device("cuda", fabric.local_rank)
+        assert model.weight.dtype == expected_dtype
 
-    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
-    model, optimizer = fabric.setup(model, optimizer)
+    # Case 1: No empty init
+    _run_setup_assertions(empty_init=False, expected_device=torch.device("cpu"))
 
-    # Parameters get sharded in `.setup()` and moved to the target device
-    assert model.weight.device == torch.device("cuda", fabric.local_rank)
-    assert model.weight.dtype == expected_dtype
+    if _TORCH_GREATER_EQUAL_2_1:
+        # Case 2: Empty-init with PyTorch >= 2.1 supports meta device
+        _run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))
+    else:
+        # Case 2: Empty-init with PyTorch < 2.1 only supports `torch.empty()`-init
+        _run_setup_assertions(empty_init=True, expected_device=torch.device("cpu"))
 
 
 @RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
@@ -460,7 +470,7 @@ def test_fsdp_manual_activation_checkpointing():
 
 
 @RunIf(min_torch="1.12", min_cuda_gpus=1)
-def test_rewrap_warning():
+def test_rewrap_warnings():
     from torch.distributed.fsdp import FullyShardedDataParallel
     from torch.distributed.fsdp.wrap import wrap
 
@@ -473,3 +483,13 @@
     model = fabric.setup(model)
     assert not isinstance(model._forward_module, FullyShardedDataParallel)
     assert isinstance(model._forward_module[2], FullyShardedDataParallel)
+
+    if not _TORCH_GREATER_EQUAL_2_1:
+        return
+
+    with fabric.init_module(empty_init=True):
+        model = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.ReLU(), wrap(torch.nn.Linear(1, 1)))
+    assert model[0].weight.is_meta
+    with pytest.warns(match="there are still parameters on the meta device"):
+        fabric_model = fabric.setup(model)
+    assert next(fabric_model.parameters()).is_meta
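The rewritten `test_module_init_context` above pins down what users see per PyTorch version: with `empty_init=True`, parameters stay on the meta device until setup on PyTorch >= 2.1, while older versions fall back to `torch.empty()`-style allocation on the CPU. A condensed restatement outside the test harness; the Fabric configuration mirrors the test and is otherwise a placeholder:

```python
import torch
from lightning.fabric import Fabric
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1

fabric = Fabric(accelerator="cuda", devices=2, strategy="fsdp")
fabric.launch()

with fabric.init_module(empty_init=True):
    model = torch.nn.Linear(100, 100, bias=False)

if _TORCH_GREATER_EQUAL_2_1:
    # Parameters sit on the meta device and are materialized, initialized, and sharded in `setup()`.
    assert model.weight.is_meta
else:
    # Parameters are allocated on the CPU but left uninitialized; they must be overwritten,
    # for example by loading a checkpoint after `setup()`.
    assert model.weight.device.type == "cpu"

model = fabric.setup(model)
assert model.weight.device.type == "cuda"
```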
diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py
index d8867694e7..ae99397de3 100644
--- a/tests/tests_fabric/test_fabric.py
+++ b/tests/tests_fabric/test_fabric.py
@@ -273,6 +273,22 @@ def test_setup_optimizers_not_supported(strategy_cls):
         fabric.setup_optimizers(optimizer)
 
 
+@RunIf(min_cuda_gpus=1, min_torch="2.1")
+def test_setup_optimizer_on_meta_device():
+    """Test that the setup-methods validate that the optimizer doesn't have references to meta-device
+    parameters."""
+    fabric = Fabric(strategy="fsdp", devices=1)
+    fabric._launched = True  # pretend we have launched multiple processes
+    with fabric.init_module(empty_init=True):
+        model = nn.Linear(1, 2)
+    assert model.weight.is_meta
+    optimizer = torch.optim.Adam(model.parameters())  # optimizer references meta device params
+    with pytest.raises(RuntimeError, match="The optimizer has references to the model's meta-device parameters"):
+        fabric.setup(model, optimizer)
+    with pytest.raises(RuntimeError, match="The optimizer has references to the model's meta-device parameters"):
+        fabric.setup_optimizers(optimizer)
+
+
 def test_setup_tracks_num_models():
     """Test that setup() tracks how many times it has setup a model."""
     fabric = Fabric(devices=1)