lightning/tests/tests_fabric/strategies/test_fsdp.py

460 lines
20 KiB
Python

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datetime import timedelta
from re import escape
from unittest import mock
from unittest.mock import ANY, MagicMock, Mock
import lightning.fabric
import pytest
import torch
import torch.nn as nn
from lightning.fabric.plugins import HalfPrecision
from lightning.fabric.plugins.environments import LightningEnvironment
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.strategies.fsdp import (
_FSDPBackwardSyncControl,
_get_full_state_dict_context,
_has_meta_device_parameters,
_is_sharded_checkpoint,
)
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1, _TORCH_GREATER_EQUAL_2_2
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
from torch.optim import Adam
from tests_fabric.helpers.runif import RunIf
def test_fsdp_custom_mixed_precision():
"""Test that passing a custom mixed precision config works."""
config = MixedPrecision()
strategy = FSDPStrategy(mixed_precision=config)
assert strategy.mixed_precision_config == config
def test_fsdp_cpu_offload():
"""Test the different ways cpu offloading can be enabled."""
# bool
strategy = FSDPStrategy(cpu_offload=True)
assert strategy.cpu_offload == CPUOffload(offload_params=True)
# dataclass
config = CPUOffload()
strategy = FSDPStrategy(cpu_offload=config)
assert strategy.cpu_offload == config
def test_fsdp_sharding_strategy():
"""Test the different ways the sharding strategy can be set."""
from torch.distributed.fsdp import ShardingStrategy
# default
strategy = FSDPStrategy()
assert strategy.sharding_strategy == ShardingStrategy.FULL_SHARD
# enum
strategy = FSDPStrategy(sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)
assert strategy.sharding_strategy == ShardingStrategy.SHARD_GRAD_OP
# string
strategy = FSDPStrategy(sharding_strategy="NO_SHARD")
assert strategy.sharding_strategy == ShardingStrategy.NO_SHARD
strategy = FSDPStrategy(sharding_strategy="no_shard")
assert strategy.sharding_strategy == ShardingStrategy.NO_SHARD
@RunIf(min_torch="2.0")
@pytest.mark.parametrize("sharding_strategy", ["HYBRID_SHARD", "_HYBRID_SHARD_ZERO2"])
def test_fsdp_hybrid_shard_configuration(sharding_strategy):
"""Test that the hybrid sharding strategies can only be used with automatic wrapping or a manually specified pg."""
with pytest.raises(RuntimeError, match="The hybrid sharding strategy requires you to pass at least one of"):
FSDPStrategy(sharding_strategy=sharding_strategy)
strategy = FSDPStrategy(auto_wrap_policy={nn.Linear}, sharding_strategy=sharding_strategy)
assert strategy.sharding_strategy.name == sharding_strategy
process_group = (Mock(), Mock())
strategy = FSDPStrategy(sharding_strategy=sharding_strategy, process_group=process_group)
assert strategy.sharding_strategy.name == sharding_strategy
assert strategy._fsdp_kwargs["process_group"] is process_group
device_mesh = Mock()
strategy = FSDPStrategy(sharding_strategy=sharding_strategy, device_mesh=device_mesh)
assert strategy.sharding_strategy.name == sharding_strategy
assert strategy._fsdp_kwargs["device_mesh"] is device_mesh
with pytest.raises(ValueError, match="process_group.* device_mesh=.* are mutually exclusive"):
FSDPStrategy(sharding_strategy=sharding_strategy, process_group=process_group, device_mesh=device_mesh)
def test_fsdp_checkpoint_io_unsupported():
"""Test that the FSDP strategy does not support the `CheckpointIO` plugin."""
strategy = FSDPStrategy()
with pytest.raises(NotImplementedError, match="does not use the `CheckpointIO` plugin"):
_ = strategy.checkpoint_io
with pytest.raises(NotImplementedError, match="does not support setting a `CheckpointIO` plugin"):
strategy.checkpoint_io = Mock()
@pytest.mark.parametrize("torch_ge_2_0", [False, True])
def test_fsdp_setup_optimizer_validation(torch_ge_2_0):
"""Test that `setup_optimizer()` validates the param groups and reference to FSDP parameters."""
module = nn.Linear(2, 2)
with mock.patch("lightning.fabric.strategies.fsdp._TORCH_GREATER_EQUAL_2_0", torch_ge_2_0):
strategy = FSDPStrategy(parallel_devices=[torch.device("cpu")])
bad_optimizer = Adam(module.parameters())
if torch_ge_2_0:
strategy.setup_optimizer(bad_optimizer)
else:
with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameter"):
strategy.setup_optimizer(bad_optimizer)
@RunIf(min_torch="2.0.0")
@mock.patch("lightning.fabric.strategies.fsdp.FSDPStrategy.setup_module")
def test_fsdp_setup_use_orig_params(_):
module = nn.Linear(2, 2)
optimizer = Adam(module.parameters())
strategy = FSDPStrategy(parallel_devices=[torch.device("cpu")], use_orig_params=False)
assert not strategy._fsdp_kwargs["use_orig_params"]
with pytest.raises(ValueError, match=r"`FSDPStrategy\(use_orig_params=False\)` but this is not supported"):
strategy.setup_module_and_optimizers(module, optimizer)
strategy = FSDPStrategy(parallel_devices=[torch.device("cpu")])
assert strategy._fsdp_kwargs["use_orig_params"]
strategy.setup_module_and_optimizers(module, optimizer)
assert strategy._fsdp_kwargs["use_orig_params"]
def test_fsdp_no_backward_sync():
"""Test that the backward sync control calls `.no_sync()`, and only on a module wrapped in
FullyShardedDataParallel."""
strategy = FSDPStrategy()
assert isinstance(strategy._backward_sync_control, _FSDPBackwardSyncControl)
with pytest.raises(
TypeError, match="is only possible if the module passed to .* is wrapped in `FullyShardedDataParallel`"
), strategy._backward_sync_control.no_backward_sync(Mock()):
pass
module = MagicMock(spec=FullyShardedDataParallel)
with strategy._backward_sync_control.no_backward_sync(module):
pass
module.no_sync.assert_called_once()
def test_fsdp_activation_checkpointing_support(monkeypatch):
"""Test that we error out if activation checkpointing requires a newer PyTorch version."""
monkeypatch.setattr(lightning.fabric.strategies.fsdp, "_TORCH_GREATER_EQUAL_2_1", False)
with pytest.raises(ValueError, match="activation_checkpointing_policy` requires torch >= 2.1.0"):
FSDPStrategy(activation_checkpointing_policy=Mock())
def test_fsdp_activation_checkpointing():
"""Test that the FSDP strategy can apply activation checkpointing to the given layers."""
class Block1(nn.Linear):
pass
class Block2(nn.Linear):
pass
class Model(nn.Module):
def __init__(self):
super().__init__()
self.layer0 = nn.Sequential(Block1(4, 4), Block1(5, 5))
self.layer1 = Block2(2, 2)
self.layer2 = nn.Linear(3, 3)
if _TORCH_GREATER_EQUAL_2_1:
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
strategy = FSDPStrategy(activation_checkpointing_policy={Block1})
assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)
strategy = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({Block1, Block2}))
assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)
else:
strategy = FSDPStrategy(activation_checkpointing=Block1)
assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}
strategy = FSDPStrategy(activation_checkpointing=[Block1, Block2])
assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}
strategy = FSDPStrategy(activation_checkpointing_policy={Block1})
assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}
strategy = FSDPStrategy(activation_checkpointing_policy={Block1, Block2})
assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}
strategy._parallel_devices = [torch.device("cuda", 0)]
with mock.patch("torch.distributed.fsdp.FullyShardedDataParallel", new=MagicMock), mock.patch(
"torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing"
) as apply_mock:
wrapped = strategy.setup_module(Model())
apply_mock.assert_called_with(wrapped, checkpoint_wrapper_fn=ANY, **strategy._activation_checkpointing_kwargs)
def test_fsdp_forbidden_precision_raises():
with pytest.raises(TypeError, match="can only work with the `FSDPPrecision"):
FSDPStrategy(precision=HalfPrecision())
strategy = FSDPStrategy()
with pytest.raises(TypeError, match="can only work with the `FSDPPrecision"):
strategy.precision = HalfPrecision()
def test_fsdp_grad_clipping_norm_error():
strategy = FSDPStrategy()
with pytest.raises(
TypeError,
match="only possible if the module.*is wrapped in `FullyShardedDataParallel`",
):
strategy.clip_gradients_norm(Mock(), Mock(), Mock())
@RunIf(min_torch="2.0.0")
def test_fsdp_save_checkpoint_storage_options(tmp_path):
"""Test that the FSDP strategy does not accept storage options for saving checkpoints."""
strategy = FSDPStrategy()
with pytest.raises(TypeError, match=escape("FSDPStrategy.save_checkpoint(..., storage_options=...)` is not")):
strategy.save_checkpoint(path=tmp_path, state=Mock(), storage_options=Mock())
@RunIf(min_torch="2.0.0")
@mock.patch("lightning.fabric.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)
@mock.patch("lightning.fabric.strategies.fsdp._get_full_state_dict_context")
@mock.patch("lightning.fabric.strategies.fsdp._get_sharded_state_dict_context")
@mock.patch("lightning.fabric.strategies.fsdp.torch.save")
@mock.patch("lightning.fabric.strategies.fsdp.shutil")
def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___, tmp_path):
strategy = FSDPStrategy(state_dict_type="full")
# state_dict_type='full', path exists, path is not a sharded checkpoint: error
path = tmp_path / "not-empty"
path.mkdir()
(path / "file").touch()
assert not _is_sharded_checkpoint(path)
with pytest.raises(IsADirectoryError, match="exists and is a directory"):
strategy.save_checkpoint(path=path, state=Mock())
# state_dict_type='full', path exists, path is a sharded checkpoint: no error (overwrite)
path = tmp_path / "sharded-checkpoint"
path.mkdir()
(path / "meta.pt").touch()
assert _is_sharded_checkpoint(path)
model = Mock(spec=FullyShardedDataParallel)
model.modules.return_value = [model]
strategy.save_checkpoint(path=path, state={"model": model})
shutil_mock.rmtree.assert_called_once_with(path)
# state_dict_type='full', path exists, path is a file: no error (overwrite)
path = tmp_path / "file.pt"
path.touch()
model = Mock(spec=FullyShardedDataParallel)
model.modules.return_value = [model]
torch_save_mock.reset_mock()
strategy.save_checkpoint(path=path, state={"model": model})
torch_save_mock.assert_called_once()
strategy = FSDPStrategy(state_dict_type="sharded")
save_mock = mock.patch(
"torch.distributed.checkpoint.save"
if _TORCH_GREATER_EQUAL_2_2
else "torch.distributed.checkpoint.save_state_dict"
)
# state_dict_type='sharded', path exists, path is a folder: no error (overwrite)
path = tmp_path / "not-empty-2"
path.mkdir()
(path / "file").touch()
model = Mock(spec=FullyShardedDataParallel)
model.modules.return_value = [model]
with save_mock:
strategy.save_checkpoint(path=path, state={"model": model})
assert (path / "file").exists()
# state_dict_type='sharded', path exists, path is a file: no error (overwrite)
path = tmp_path / "file-2.pt"
path.touch()
model = Mock(spec=FullyShardedDataParallel)
model.modules.return_value = [model]
with save_mock:
strategy.save_checkpoint(path=path, state={"model": model})
assert path.is_dir()
@RunIf(min_torch="2.0.0")
@mock.patch("lightning.fabric.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)
def test_fsdp_save_checkpoint_one_fsdp_module_required(tmp_path):
"""Test that the FSDP strategy can only save one FSDP model per checkpoint."""
strategy = FSDPStrategy()
# missing FSDP model
with pytest.raises(ValueError, match="Could not find a FSDP model in the provided checkpoint state."):
strategy.save_checkpoint(path=tmp_path, state={})
with pytest.raises(ValueError, match="Could not find a FSDP model in the provided checkpoint state."):
strategy.load_checkpoint(path=tmp_path, state={"model": torch.nn.Linear(3, 3)})
# multiple FSDP models
model1 = Mock(spec=FullyShardedDataParallel)
model1.modules.return_value = [model1]
model2 = Mock(spec=FullyShardedDataParallel)
model2.modules.return_value = [model2]
with pytest.raises(ValueError, match="Found multiple FSDP models in the given state."):
strategy.save_checkpoint(path=tmp_path, state={"model1": model1, "model2": model2})
@RunIf(min_torch="2.0.0")
def test_fsdp_load_checkpoint_no_state(tmp_path):
"""Test that the FSDP strategy can't load the full state without access to a model instance from the user."""
strategy = FSDPStrategy()
with pytest.raises(ValueError, match=escape("Got FSDPStrategy.load_checkpoint(..., state=None")):
strategy.load_checkpoint(path=tmp_path, state=None)
with pytest.raises(ValueError, match=escape("Got FSDPStrategy.load_checkpoint(..., state={})")):
strategy.load_checkpoint(path=tmp_path, state={})
@RunIf(min_torch="2.0.0")
@mock.patch("lightning.fabric.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)
@mock.patch("lightning.fabric.strategies.fsdp._lazy_load", Mock())
def test_fsdp_load_checkpoint_one_fsdp_module_required(tmp_path):
"""Test that the FSDP strategy can only load one FSDP model per checkpoint."""
strategy = FSDPStrategy()
# missing FSDP model
with pytest.raises(ValueError, match="Could not find a FSDP model in the provided checkpoint state."):
strategy.load_checkpoint(path=tmp_path, state={"other": "data"})
with pytest.raises(ValueError, match="Could not find a FSDP model in the provided checkpoint state."):
strategy.load_checkpoint(path=tmp_path, state={"model": torch.nn.Linear(3, 3)})
# multiple FSDP models
model1 = Mock(spec=FullyShardedDataParallel)
model1.modules.return_value = [model1]
model2 = Mock(spec=FullyShardedDataParallel)
model2.modules.return_value = [model2]
with pytest.raises(ValueError, match="Found multiple FSDP models in the given state."):
strategy.load_checkpoint(path=tmp_path, state={"model1": model1, "model2": model2})
# A raw nn.Module instead of a dictionary is ok
model = Mock(spec=nn.Module)
path = tmp_path / "full.ckpt"
path.touch()
strategy.load_checkpoint(path=path, state=model)
@RunIf(min_torch="2.0.0")
@mock.patch("lightning.fabric.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)
def test_fsdp_save_checkpoint_unknown_state_dict_type(tmp_path):
strategy = FSDPStrategy(state_dict_type="invalid")
model = Mock(spec=FullyShardedDataParallel)
model.modules.return_value = [model]
with pytest.raises(ValueError, match="Unknown state_dict_type"):
strategy.save_checkpoint(path=tmp_path, state={"model": model})
@RunIf(min_torch="2.0.0")
def test_fsdp_load_unknown_checkpoint_type(tmp_path):
"""Test that the strategy validates the contents at the checkpoint path."""
strategy = FSDPStrategy()
model = Mock(spec=FullyShardedDataParallel)
model.modules.return_value = [model]
path = tmp_path / "empty_dir" # neither a single file nor a directory with meta file
path.mkdir()
with pytest.raises(ValueError, match="does not point to a valid checkpoint"):
strategy.load_checkpoint(path=path, state={"model": model})
@RunIf(min_torch="2.0.0")
def test_fsdp_load_raw_checkpoint_validate_single_file(tmp_path):
"""Test that we validate the given checkpoint is a single file when loading a raw PyTorch state-dict checkpoint."""
strategy = FSDPStrategy()
model = Mock(spec=nn.Module)
path = tmp_path / "folder"
path.mkdir()
with pytest.raises(ValueError, match="The given path must be a single file containing the full state dict"):
strategy.load_checkpoint(path=path, state=model)
@RunIf(min_torch="2.0.0")
def test_fsdp_load_raw_checkpoint_optimizer_unsupported(tmp_path):
"""Validate that the FSDP strategy does not yet support loading the raw PyTorch state-dict for an optimizer."""
strategy = FSDPStrategy()
optimizer = Mock(spec=torch.optim.Optimizer)
with pytest.raises(
NotImplementedError, match="Loading a single optimizer object from a checkpoint is not supported"
):
strategy.load_checkpoint(path=tmp_path, state=optimizer)
@mock.patch("torch.distributed.init_process_group")
def test_set_timeout(init_process_group_mock):
"""Test that the timeout gets passed to the ``torch.distributed.init_process_group`` function."""
test_timedelta = timedelta(seconds=30)
strategy = FSDPStrategy(timeout=test_timedelta, parallel_devices=[torch.device("cpu")])
strategy.cluster_environment = LightningEnvironment()
strategy.accelerator = Mock()
strategy.setup_environment()
process_group_backend = strategy._get_process_group_backend()
global_rank = strategy.cluster_environment.global_rank()
world_size = strategy.cluster_environment.world_size()
init_process_group_mock.assert_called_with(
process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
)
def test_has_meta_device_parameters():
"""Test that the `_has_meta_device_parameters` function can find meta-device parameters in models and
optimizers."""
# nn.Module
module = nn.Linear(2, 2)
meta_module = nn.Linear(2, 2, device="meta")
assert not _has_meta_device_parameters(module)
assert _has_meta_device_parameters(meta_module)
assert _has_meta_device_parameters(nn.Sequential(module, meta_module, nn.ReLU()))
# optim.Optimizer
optimizer = torch.optim.SGD(module.parameters(), lr=0.1)
meta_optimizer = torch.optim.SGD(meta_module.parameters(), lr=0.1)
assert not _has_meta_device_parameters(optimizer)
assert _has_meta_device_parameters(meta_optimizer)
# unsupported objects
with pytest.raises(TypeError, match="Expected `torch.nn.Module` or `torch.optim.Optimizer`"):
_has_meta_device_parameters(None)
@RunIf(min_torch="2.0")
@pytest.mark.parametrize("torch_ge_2_1", [True, False])
@mock.patch("torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel.set_state_dict_type")
def test_get_full_state_dict_context_offload(set_type_mock, monkeypatch, torch_ge_2_1):
"""Test that the state dict context manager handles CPU offloading depending on the PyTorch version."""
monkeypatch.setattr("lightning.fabric.strategies.fsdp._TORCH_GREATER_EQUAL_2_1", torch_ge_2_1)
with _get_full_state_dict_context(module=Mock(spec=FullyShardedDataParallel), world_size=1):
assert set_type_mock.call_args_list[0][0][2].offload_to_cpu is torch_ge_2_1 # model config
assert set_type_mock.call_args_list[0][0][3].offload_to_cpu is torch_ge_2_1 # optim config
set_type_mock.reset_mock()
with _get_full_state_dict_context(module=Mock(spec=FullyShardedDataParallel), world_size=4):
assert set_type_mock.call_args_list[0][0][2].offload_to_cpu # model config
assert set_type_mock.call_args_list[0][0][3].offload_to_cpu # optim config