lightning/tests/tests_fabric/strategies/test_fsdp_integration.py

496 lines
21 KiB
Python

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from copy import deepcopy
from pathlib import Path
from unittest import mock
import pytest
import torch
from torch.nn import Parameter
from lightning.fabric import Fabric
from lightning.fabric.plugins import FSDPPrecision
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities.imports import (
_TORCH_GREATER_EQUAL_1_12,
_TORCH_GREATER_EQUAL_2_0,
_TORCH_GREATER_EQUAL_2_1,
)
from lightning.fabric.wrappers import _FabricOptimizer
from tests_fabric.helpers.models import BoringFabric
from tests_fabric.helpers.runif import RunIf
from tests_fabric.test_fabric import BoringModel
if _TORCH_GREATER_EQUAL_1_12:
from torch.distributed.fsdp import FlatParameter, FullyShardedDataParallel, OptimStateKeyType
from torch.distributed.fsdp.wrap import always_wrap_policy, wrap
class _MyFabric(BoringFabric):
def get_model(self):
model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))
self.num_wrapped = 4
return model
def step(self, model, batch):
wrapped_layers = [m for m in model.modules() if isinstance(m, FullyShardedDataParallel)]
assert len(wrapped_layers) == self.num_wrapped
assert (self.num_wrapped == 4) == isinstance(model._forward_module, FullyShardedDataParallel)
precision = self._precision
assert isinstance(precision, FSDPPrecision)
if precision.precision == "16-mixed":
param_dtype = torch.float32
reduce_dtype = buffer_dtype = torch.float16
elif precision.precision == "bf16-mixed":
param_dtype = torch.float32
reduce_dtype = buffer_dtype = torch.bfloat16
elif precision.precision == "16-true":
param_dtype = reduce_dtype = buffer_dtype = torch.float16
elif precision.precision == "bf16-true":
param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
else:
raise ValueError(f"Unknown precision {precision.precision}")
for layer in wrapped_layers:
assert layer.mixed_precision.param_dtype == param_dtype
assert layer.mixed_precision.reduce_dtype == reduce_dtype
assert layer.mixed_precision.buffer_dtype == buffer_dtype
output = model(batch)
return torch.nn.functional.mse_loss(output, torch.ones_like(output))
class _MyFabricManualWrapping(_MyFabric):
def get_model(self):
model = super().get_model()
for i, layer in enumerate(model):
if i % 2 == 0:
model[i] = wrap(layer)
self.num_wrapped = 2
return model
@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
@pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
@pytest.mark.parametrize("manual_wrapping", [True, False])
def test_fsdp_train_save_load(tmp_path, manual_wrapping, precision):
"""Test FSDP training, saving and loading with different wrapping and precision settings."""
fabric_cls = _MyFabricManualWrapping if manual_wrapping else _MyFabric
fabric = fabric_cls(
accelerator="cuda", strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy), devices=2, precision=precision
)
fabric.run()
checkpoint_path = fabric.broadcast(str(tmp_path / "fsdp-checkpoint"))
params_before = deepcopy(list(fabric.model.parameters()))
state = {"model": fabric.model, "optimizer": fabric.optimizer, "steps": 1}
fabric.save(checkpoint_path, state)
assert set(os.listdir(checkpoint_path)) == {"meta.pt", ".metadata", "__0_0.distcp", "__1_0.distcp"}
# re-init all objects and resume
fabric = fabric_cls(
accelerator="cuda", strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy), devices=2, precision=precision
)
fabric.run()
# check correctness with loaded state
state = {"model": fabric.model, "optimizer": fabric.optimizer, "steps": 0}
metadata = fabric.load(checkpoint_path, state)
for p0, p1 in zip(params_before, fabric.model.parameters()):
torch.testing.assert_close(p0, p1, atol=0, rtol=0, equal_nan=True)
# check user data in state reloaded
assert state["steps"] == 1
assert not metadata
# attempt to load a key not in the metadata checkpoint
state = {"model": fabric.model, "coconut": 11}
with pytest.raises(KeyError, match="The requested state contains a key 'coconut' that does not exist"):
fabric.load(checkpoint_path, state)
# `strict=False` ignores the missing key
state = {"model": fabric.model, "coconut": 11}
fabric.load(checkpoint_path, state, strict=False)
assert state["coconut"] == 11
@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
def test_fsdp_save_full_state_dict(tmp_path):
"""Test that FSDP saves the full state into a single file with `state_dict_type="full"`."""
fabric = BoringFabric(
accelerator="cuda",
strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy, state_dict_type="full"),
devices=2,
)
fabric.run()
checkpoint_path = Path(fabric.broadcast(str(tmp_path / "fsdp-checkpoint.pt")))
state = {"model": fabric.model, "optimizer": fabric.optimizer, "steps": 1}
fabric.save(checkpoint_path, state)
checkpoint = torch.load(checkpoint_path)
assert checkpoint["steps"] == 1
loaded_state_dict = checkpoint["model"]
# assert the correct state model was saved
with FullyShardedDataParallel.summon_full_params(fabric.model):
state_dict = fabric.model.state_dict()
assert set(loaded_state_dict.keys()) == set(state_dict.keys())
for param_name in state_dict:
assert torch.equal(loaded_state_dict[param_name], state_dict[param_name].cpu())
params_before = [p.cpu() for p in fabric.model.parameters()]
# assert the correct optimizer state was saved
optimizer_state_before = FullyShardedDataParallel.full_optim_state_dict(
fabric.model, fabric.optimizer, rank0_only=False
)
assert set(checkpoint["optimizer"].keys()) == set(optimizer_state_before.keys()) == {"state", "param_groups"}
# 1. verify the FSDP state can be loaded back into a FSDP model/strategy directly
fabric = BoringFabric(
accelerator="cuda",
strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
devices=2,
)
fabric.run()
metadata = fabric.load(checkpoint_path, {"model": fabric.model, "optimizer": fabric.optimizer})
assert metadata == {"steps": 1}
with FullyShardedDataParallel.summon_full_params(fabric.model):
params_after = list(fabric.model.parameters())
assert all(torch.equal(p0.cpu(), p1.cpu()) for p0, p1 in zip(params_before, params_after))
# assert the correct optimizer state was loaded
optimizer_state_after = FullyShardedDataParallel.full_optim_state_dict(
fabric.model, fabric.optimizer, rank0_only=False
)
assert set(optimizer_state_after.keys()) == set(optimizer_state_before.keys()) == {"state", "param_groups"}
torch.testing.assert_close(optimizer_state_after["state"], optimizer_state_before["state"], atol=0, rtol=0)
assert optimizer_state_after["param_groups"] == optimizer_state_before["param_groups"]
# run a step to verify the optimizer state is correct
fabric.run()
# 2. verify the FSDP state can be loaded back into a single-device model/strategy
fabric = BoringFabric(accelerator="cpu", devices=1)
fabric.run()
metadata = fabric.load(checkpoint_path, {"model": fabric.model, "optimizer": fabric.optimizer})
assert metadata == {"steps": 1}
params_after = list(fabric.model.parameters())
assert all(torch.equal(p0, p1) for p0, p1 in zip(params_before, params_after))
# get optimizer state after loading
normal_checkpoint_path = Path(fabric.broadcast(str(tmp_path / "normal-checkpoint.pt")))
fabric.save(normal_checkpoint_path, {"model": fabric.model, "optimizer": fabric.optimizer, "steps": 2})
optimizer_state_after = torch.load(normal_checkpoint_path)["optimizer"]
optimizer_state_after = FullyShardedDataParallel.rekey_optim_state_dict(
optimizer_state_after, optim_state_key_type=OptimStateKeyType.PARAM_NAME, model=fabric.model
)
# assert the correct optimizer state was loaded
assert set(optimizer_state_after.keys()) == set(optimizer_state_before.keys()) == {"state", "param_groups"}
torch.testing.assert_close(optimizer_state_after["state"], optimizer_state_before["state"], atol=0, rtol=0)
# run a step to verify the optimizer state is correct
fabric.run()
# 3. verify that a single-device model/strategy states can be loaded into a FSDP model/strategy
fabric = BoringFabric(
accelerator="cuda",
strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
devices=2,
)
fabric.run()
metadata = fabric.load(normal_checkpoint_path, {"model": fabric.model, "optimizer": fabric.optimizer})
assert metadata == {"steps": 2}
with FullyShardedDataParallel.summon_full_params(fabric.model):
params_after = list(fabric.model.parameters())
assert all(torch.equal(p0.cpu(), p1.cpu()) for p0, p1 in zip(params_before, params_after))
# assert the correct optimizer state was loaded
optimizer_state_after = FullyShardedDataParallel.full_optim_state_dict(
fabric.model, fabric.optimizer, rank0_only=False
)
assert set(optimizer_state_after.keys()) == set(optimizer_state_before.keys()) == {"state", "param_groups"}
torch.testing.assert_close(optimizer_state_after["state"], optimizer_state_before["state"], atol=0, rtol=0)
assert optimizer_state_after["param_groups"] == optimizer_state_before["param_groups"]
# run a step to verify the optimizer state is correct
fabric.run()
@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
def test_fsdp_load_full_state_dict_into_sharded_model(tmp_path):
"""Test that the strategy can load a full-state checkpoint into a FSDP sharded model."""
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
fabric = BoringFabric(accelerator="cuda", devices=1)
fabric.seed_everything(0)
fabric.run()
# Save a full-state-dict checkpoint
checkpoint_path = Path(fabric.broadcast(str(tmp_path / "full-checkpoint.pt")))
state = {"model": fabric.model, "optimizer": fabric.optimizer, "steps": 1}
fabric.save(checkpoint_path, state)
# Gather all weights and store a copy manually
with FSDP.summon_full_params(fabric.model, writeback=False, rank0_only=False):
params_before = torch.cat([p.cpu().view(-1) for p in fabric.model.parameters()])
# Create a FSDP sharded model
fabric = BoringFabric(
accelerator="cuda",
strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
devices=2,
)
fabric.run()
state = {"model": fabric.model, "optimizer": fabric.optimizer, "steps": 44}
fabric.load(checkpoint_path, state)
assert state["steps"] == 1
# Gather all weights and compare
with FSDP.summon_full_params(fabric.model, writeback=False, rank0_only=False):
params_after = torch.cat([p.cpu().view(-1) for p in fabric.model.parameters()])
assert torch.equal(params_before, params_after)
# Create a raw state-dict checkpoint to test `Fabric.load_raw` too
raw_checkpoint_path = checkpoint_path.with_name("model-state-dict")
if fabric.global_rank == 0:
checkpoint = torch.load(checkpoint_path)
torch.save(checkpoint["model"], raw_checkpoint_path)
fabric.barrier()
fabric.run()
fabric.load_raw(raw_checkpoint_path, fabric.model)
# Gather all weights and compare
with FSDP.summon_full_params(fabric.model, writeback=False, rank0_only=False):
params_after = torch.cat([p.cpu().view(-1) for p in fabric.model.parameters()])
assert torch.equal(params_before, params_after)
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
@pytest.mark.parametrize("move_to_device", [True, False])
@mock.patch("lightning.fabric.wrappers._FabricModule")
def test_setup_module_move_to_device(fabric_module_mock, move_to_device):
"""Test that `move_to_device` does nothing, FSDP decides which device parameters get moved to which device
(sharding)."""
strategy = FSDPStrategy(auto_wrap_policy=always_wrap_policy)
fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy)
fabric.launch()
model = torch.nn.Linear(10, 10, bias=False) # total params: 10 * 10 = 100
fabric_model = fabric.setup_module(model, move_to_device=move_to_device)
fabric_module_mock.assert_not_called()
assert len(list(fabric_model.parameters())) == 1
# the linear layer got sharded and each part is on the expected device
assert next(fabric_model.parameters()).device == torch.device("cuda", fabric.local_rank)
assert next(fabric_model.parameters()).numel() == 50
if _TORCH_GREATER_EQUAL_2_0:
# In PyTorch >= 2.0 we set `use_orig_params=True` and don't see flattened parameters
assert isinstance(next(fabric_model.parameters()), Parameter)
else:
assert isinstance(next(fabric_model.parameters()), FlatParameter)
# The _DeviceDtypeModuleMixin currently can't represent the device in a meaningful way for sharded models
assert fabric_model.device == torch.device("cpu")
assert fabric.device == torch.device("cuda", fabric.local_rank)
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="2.0.0")
def test_setup_with_orig_params_and_multiple_param_groups():
"""Test that Fabric sets `use_orig_params` for the user when jointly setting up model and optimizer."""
strategy = FSDPStrategy(auto_wrap_policy=always_wrap_policy)
fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy)
fabric.launch()
model = torch.nn.Sequential(
torch.nn.Linear(10, 10, bias=False),
torch.nn.Linear(5, 2, bias=False),
)
optimizer = torch.optim.Adam(
[
{"params": model[0].parameters(), "lr": 1e-2},
{"params": model[1].parameters(), "lr": 1e-6},
]
)
# set up model and optimizer jointly
wrapped_model, wrapped_optimizer = fabric.setup(model, optimizer)
assert fabric.strategy._fsdp_kwargs["use_orig_params"]
assert isinstance(wrapped_optimizer, _FabricOptimizer)
assert len(wrapped_optimizer.param_groups) == 2
for i in range(2):
layer = wrapped_model._forward_module.module[i]
assert isinstance(layer, FullyShardedDataParallel)
assert torch.equal(wrapped_optimizer.param_groups[i]["params"][0], layer.weight)
# A regular parameter as a view into the flattened parameters
assert isinstance(layer.weight, torch.nn.Parameter)
assert not isinstance(layer.weight, FlatParameter)
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, dynamo=True)
@mock.patch.dict(os.environ, {})
@pytest.mark.parametrize(
"compile_after_setup",
[
False,
# https://github.com/pytorch/pytorch/issues/97811
pytest.param(True, marks=RunIf(min_python="3.9")),
],
)
def test_compile(compile_after_setup):
"""Test that the model can be compiled before and after the model is wrapped in FSDP."""
model = BoringModel()
strategy = FSDPStrategy(auto_wrap_policy=always_wrap_policy)
fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy)
fabric.launch()
if not compile_after_setup:
model = torch.compile(model)
model = fabric.setup(model)
if compile_after_setup:
model = torch.compile(model)
for _ in range(3):
model(torch.rand(2, 32, device=fabric.device)).sum().backward()
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize(
("precision", "expected_dtype"),
[
("32-true", torch.float32),
("16-true", torch.float16),
pytest.param("bf16-true", torch.bfloat16, marks=RunIf(bf16_cuda=True)),
],
)
def test_module_init_context(precision, expected_dtype):
"""Test that the module under the init-context gets moved to the right device and dtype."""
fabric = Fabric(
accelerator="cuda",
devices=2,
strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
precision=precision,
)
fabric.launch()
def _run_setup_assertions(empty_init, expected_device):
with fabric.init_module(empty_init=empty_init):
model = torch.nn.Linear(100, 100, bias=False)
# The model is on the CPU/meta-device until after `.setup()``
assert model.weight.device == expected_device
assert model.weight.dtype == expected_dtype
model = fabric.setup(model)
# Parameters get sharded in `.setup()` and moved to the target device
assert model.weight.device == torch.device("cuda", fabric.local_rank)
assert model.weight.dtype == expected_dtype
# Case 1: No empty init
_run_setup_assertions(empty_init=False, expected_device=torch.device("cpu"))
if _TORCH_GREATER_EQUAL_2_1:
# Case 2: Empty-init with PyTorch >= 2.1 supports meta device
_run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))
else:
# Case 2: Empty-init with PyTorch < 2.1 only supports `torch.empty()`-init
_run_setup_assertions(empty_init=True, expected_device=torch.device("cpu"))
@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
def test_fsdp_save_filter(tmp_path):
fabric = BoringFabric(accelerator="cuda", strategy=FSDPStrategy(state_dict_type="full"), devices=2)
fabric.launch()
model = fabric.get_model()
model = fabric.setup_module(model)
tmp_path = Path(fabric.broadcast(str(tmp_path)))
state = {"model": model}
filter = {"model": lambda k, v: "bias" in k}
checkpoint_path = tmp_path / "full.pth"
fabric.save(checkpoint_path, state, filter=filter)
checkpoint = torch.load(checkpoint_path)["model"]
assert set(checkpoint) == {"bias"}
assert isinstance(checkpoint["bias"], torch.Tensor)
fabric.strategy._state_dict_type = "sharded"
checkpoint_path = tmp_path / "sharded"
with pytest.raises(NotImplementedError, match="doesn't support loading sharded filtered"):
fabric.save(checkpoint_path, state, filter=filter)
@RunIf(min_torch="1.13", min_cuda_gpus=1)
def test_fsdp_manual_activation_checkpointing():
model = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.Linear(1, 1))
strategy = FSDPStrategy(activation_checkpointing_policy={torch.nn.Linear})
fabric = Fabric(devices=1, accelerator="cuda", strategy=strategy)
fabric.launch()
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
CheckpointWrapper,
)
# manually apply activation checkpointing
apply_activation_checkpointing(model)
wrappers = {name for name, mod in model.named_modules() if isinstance(mod, CheckpointWrapper)}
assert wrappers == {"0", "1"}
# let fabric set up the model, it shouldn't apply activation checkpointing again
with pytest.warns(match="is configured, but the model already contains checkpointed"):
model = fabric.setup(model)
wrappers = {name for name, mod in model._forward_module.named_modules() if isinstance(mod, CheckpointWrapper)}
assert wrappers == {"_fsdp_wrapped_module.0", "_fsdp_wrapped_module.1"}
@RunIf(min_torch="1.12", min_cuda_gpus=1)
def test_rewrap_warnings():
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.distributed.fsdp.wrap import wrap
strategy = FSDPStrategy(auto_wrap_policy={torch.nn.Linear})
fabric = Fabric(devices=1, accelerator="cuda", strategy=strategy)
fabric.launch()
with fabric.init_module():
model = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.ReLU(), wrap(torch.nn.Linear(1, 1)))
with pytest.warns(match="the model is already wrapped"):
model = fabric.setup(model)
assert not isinstance(model._forward_module, FullyShardedDataParallel)
assert isinstance(model._forward_module[2], FullyShardedDataParallel)
if not _TORCH_GREATER_EQUAL_2_1:
return
with fabric.init_module(empty_init=True):
model = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.ReLU(), wrap(torch.nn.Linear(1, 1)))
assert model[0].weight.is_meta
with pytest.warns(match="there are still parameters on the meta device"):
fabric_model = fabric.setup(model)
assert next(fabric_model.parameters()).is_meta