# lightning/tests/tests_fabric/strategies/test_fsdp_integration.py
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from copy import deepcopy
from pathlib import Path
from unittest import mock
from unittest.mock import Mock

import pytest
import torch
import torch.nn as nn
from lightning.fabric import Fabric
from lightning.fabric.plugins import FSDPPrecision
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities.imports import (
_TORCH_GREATER_EQUAL_2_0,
_TORCH_GREATER_EQUAL_2_1,
)
from lightning.fabric.utilities.load import _load_distributed_checkpoint
from lightning.fabric.wrappers import _FabricOptimizer
from torch.distributed.fsdp import FlatParameter, FullyShardedDataParallel, OptimStateKeyType
from torch.distributed.fsdp.wrap import always_wrap_policy, wrap
from torch.nn import Parameter
from torch.utils.data import DataLoader

from tests_fabric.helpers.datasets import RandomDataset
from tests_fabric.helpers.runif import RunIf
from tests_fabric.test_fabric import BoringModel


class BasicTrainer:
"""Implements a basic training loop for the end-to-end tests."""
def __init__(self, fabric):
self.fabric = fabric
self.model = self.optimizer = self.dataloader = None
def get_model(self):
return nn.Linear(32, 2)
def step(self, model, batch):
output = model(batch)
return torch.nn.functional.mse_loss(output, torch.ones_like(output))
def run(self) -> None:
with self.fabric.init_module():
model = self.get_model()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
model, optimizer = self.fabric.setup(model, optimizer)
dataloader = DataLoader(RandomDataset(32, 64))
dataloader = self.fabric.setup_dataloaders(dataloader)
self.model = model
self.optimizer = optimizer
self.dataloader = dataloader
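        # run a single training iteration: forward, backward, optimizer step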
model.train()
data_iter = iter(dataloader)
batch = next(data_iter)
loss = self.step(model, batch)
self.fabric.backward(loss)
optimizer.step()
optimizer.zero_grad()


class _Trainer(BasicTrainer):
def get_model(self):
model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))
self.num_wrapped = 4
return model
def step(self, model, batch):
wrapped_layers = [m for m in model.modules() if isinstance(m, FullyShardedDataParallel)]
assert len(wrapped_layers) == self.num_wrapped
assert (self.num_wrapped == 4) == isinstance(model._forward_module, FullyShardedDataParallel)
precision = self.fabric._precision
assert isinstance(precision, FSDPPrecision)
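        # expected dtypes: "mixed" precision keeps parameters in float32 and uses half precision for gradient
        # reduction and buffers, while "true" precision uses half precision for everything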
if precision.precision == "16-mixed":
param_dtype = torch.float32
reduce_dtype = buffer_dtype = torch.float16
elif precision.precision == "bf16-mixed":
param_dtype = torch.float32
reduce_dtype = buffer_dtype = torch.bfloat16
elif precision.precision == "16-true":
param_dtype = reduce_dtype = buffer_dtype = torch.float16
elif precision.precision == "bf16-true":
param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
else:
raise ValueError(f"Unknown precision {precision.precision}")
for layer in wrapped_layers:
assert layer.mixed_precision.param_dtype == param_dtype
assert layer.mixed_precision.reduce_dtype == reduce_dtype
assert layer.mixed_precision.buffer_dtype == buffer_dtype
output = model(batch)
return torch.nn.functional.mse_loss(output, torch.ones_like(output))


class _TrainerManualWrapping(_Trainer):
def get_model(self):
model = super().get_model()
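        # wrap only the layers at even indices (the two Linear layers); the ReLU in between stays unwrapped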
for i, layer in enumerate(model):
if i % 2 == 0:
model[i] = wrap(layer)
self.num_wrapped = 2
return model


@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
@pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
@pytest.mark.parametrize("manual_wrapping", [True, False])
def test_fsdp_train_save_load(tmp_path, manual_wrapping, precision):
"""Test FSDP training, saving and loading with different wrapping and precision settings."""
trainer_cls = _TrainerManualWrapping if manual_wrapping else _Trainer
fabric = Fabric(
accelerator="cuda",
strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
devices=2,
precision=precision,
)
fabric.launch()
trainer = trainer_cls(fabric)
trainer.run()
checkpoint_path = fabric.broadcast(str(tmp_path / "fsdp-checkpoint"))
params_before = deepcopy(list(trainer.model.parameters()))
state = {"model": trainer.model, "optimizer": trainer.optimizer, "steps": 1}
fabric.save(checkpoint_path, state)
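    # a sharded checkpoint is saved as a directory containing one shard file per rank plus metadata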
assert set(os.listdir(checkpoint_path)) == {"meta.pt", ".metadata", "__0_0.distcp", "__1_0.distcp"}
# re-init all objects and resume
fabric = Fabric(
accelerator="cuda",
strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
devices=2,
precision=precision,
)
fabric.launch()
trainer = trainer_cls(fabric)
trainer.run()
# check correctness with loaded state
state = {"model": trainer.model, "optimizer": trainer.optimizer, "steps": 0}
metadata = fabric.load(checkpoint_path, state)
for p0, p1 in zip(params_before, trainer.model.parameters()):
torch.testing.assert_close(p0, p1, atol=0, rtol=0, equal_nan=True)
# check user data in state reloaded
assert state["steps"] == 1
assert not metadata
# attempt to load a key not in the metadata checkpoint
state = {"model": trainer.model, "coconut": 11}
with pytest.raises(KeyError, match="The requested state contains a key 'coconut' that does not exist"):
fabric.load(checkpoint_path, state)
# `strict=False` ignores the missing key
state = {"model": trainer.model, "coconut": 11}
fabric.load(checkpoint_path, state, strict=False)
assert state["coconut"] == 11


@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
def test_fsdp_save_full_state_dict(tmp_path):
"""Test that FSDP saves the full state into a single file with `state_dict_type="full"`."""
fabric = Fabric(
accelerator="cuda",
strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy, state_dict_type="full"),
devices=2,
)
fabric.launch()
trainer = BasicTrainer(fabric)
trainer.run()
checkpoint_path = Path(fabric.broadcast(str(tmp_path / "fsdp-checkpoint.pt")))
state = {"model": trainer.model, "optimizer": trainer.optimizer, "steps": 1}
fabric.save(checkpoint_path, state)
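    # the full checkpoint is a single file that can be read with plain `torch.load`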
checkpoint = torch.load(checkpoint_path)
assert checkpoint["steps"] == 1
loaded_state_dict = checkpoint["model"]
# assert the correct state model was saved
with FullyShardedDataParallel.summon_full_params(trainer.model):
state_dict = trainer.model.state_dict()
assert set(loaded_state_dict.keys()) == set(state_dict.keys())
for param_name in state_dict:
assert torch.equal(loaded_state_dict[param_name], state_dict[param_name].cpu())
params_before = [p.cpu() for p in trainer.model.parameters()]
# assert the correct optimizer state was saved
optimizer_state_before = FullyShardedDataParallel.full_optim_state_dict(
trainer.model, trainer.optimizer, rank0_only=False
)
assert set(checkpoint["optimizer"].keys()) == set(optimizer_state_before.keys()) == {"state", "param_groups"}
# 1. verify the FSDP state can be loaded back into a FSDP model/strategy directly
fabric = Fabric(
accelerator="cuda",
strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
devices=2,
)
fabric.launch()
trainer = BasicTrainer(fabric)
trainer.run()
metadata = fabric.load(checkpoint_path, {"model": trainer.model, "optimizer": trainer.optimizer})
assert metadata == {"steps": 1}
with FullyShardedDataParallel.summon_full_params(trainer.model):
params_after = list(trainer.model.parameters())
assert all(torch.equal(p0.cpu(), p1.cpu()) for p0, p1 in zip(params_before, params_after))
# assert the correct optimizer state was loaded
optimizer_state_after = FullyShardedDataParallel.full_optim_state_dict(
trainer.model, trainer.optimizer, rank0_only=False
)
assert set(optimizer_state_after.keys()) == set(optimizer_state_before.keys()) == {"state", "param_groups"}
torch.testing.assert_close(optimizer_state_after["state"], optimizer_state_before["state"], atol=0, rtol=0)
assert optimizer_state_after["param_groups"] == optimizer_state_before["param_groups"]
# run a step to verify the optimizer state is correct
trainer.run()
# 2. verify the FSDP state can be loaded back into a single-device model/strategy
fabric = Fabric(accelerator="cpu", devices=1)
trainer = BasicTrainer(fabric)
trainer.run()
metadata = fabric.load(checkpoint_path, {"model": trainer.model, "optimizer": trainer.optimizer})
assert metadata == {"steps": 1}
params_after = list(trainer.model.parameters())
assert all(torch.equal(p0, p1) for p0, p1 in zip(params_before, params_after))
# get optimizer state after loading
normal_checkpoint_path = Path(fabric.broadcast(str(tmp_path / "normal-checkpoint.pt")))
fabric.save(normal_checkpoint_path, {"model": trainer.model, "optimizer": trainer.optimizer, "steps": 2})
optimizer_state_after = torch.load(normal_checkpoint_path)["optimizer"]
optimizer_state_after = FullyShardedDataParallel.rekey_optim_state_dict(
optimizer_state_after, optim_state_key_type=OptimStateKeyType.PARAM_NAME, model=trainer.model
)
# assert the correct optimizer state was loaded
assert set(optimizer_state_after.keys()) == set(optimizer_state_before.keys()) == {"state", "param_groups"}
torch.testing.assert_close(optimizer_state_after["state"], optimizer_state_before["state"], atol=0, rtol=0)
# run a step to verify the optimizer state is correct
trainer.run()
# 3. verify that a single-device model/strategy states can be loaded into a FSDP model/strategy
fabric = Fabric(
accelerator="cuda",
strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
devices=2,
)
fabric.launch()
trainer = BasicTrainer(fabric)
trainer.run()
metadata = fabric.load(normal_checkpoint_path, {"model": trainer.model, "optimizer": trainer.optimizer})
assert metadata == {"steps": 2}
with FullyShardedDataParallel.summon_full_params(trainer.model):
params_after = list(trainer.model.parameters())
assert all(torch.equal(p0.cpu(), p1.cpu()) for p0, p1 in zip(params_before, params_after))
# assert the correct optimizer state was loaded
optimizer_state_after = FullyShardedDataParallel.full_optim_state_dict(
trainer.model, trainer.optimizer, rank0_only=False
)
assert set(optimizer_state_after.keys()) == set(optimizer_state_before.keys()) == {"state", "param_groups"}
torch.testing.assert_close(optimizer_state_after["state"], optimizer_state_before["state"], atol=0, rtol=0)
assert optimizer_state_after["param_groups"] == optimizer_state_before["param_groups"]
# run a step to verify the optimizer state is correct
trainer.run()


@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
def test_fsdp_load_full_state_dict_into_sharded_model(tmp_path):
"""Test that the strategy can load a full-state checkpoint into a FSDP sharded model."""
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
fabric = Fabric(accelerator="cuda", devices=1)
fabric.seed_everything(0)
trainer = BasicTrainer(fabric)
trainer.run()
# Save a full-state-dict checkpoint
checkpoint_path = Path(fabric.broadcast(str(tmp_path / "full-checkpoint.pt")))
state = {"model": trainer.model, "optimizer": trainer.optimizer, "steps": 1}
fabric.save(checkpoint_path, state)
# Gather all weights and store a copy manually
with FSDP.summon_full_params(trainer.model, writeback=False, rank0_only=False):
params_before = torch.cat([p.cpu().view(-1) for p in trainer.model.parameters()])
# Create a FSDP sharded model
fabric = Fabric(
accelerator="cuda",
strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
devices=2,
)
fabric.launch()
trainer = BasicTrainer(fabric)
trainer.run()
state = {"model": trainer.model, "optimizer": trainer.optimizer, "steps": 44}
fabric.load(checkpoint_path, state)
assert state["steps"] == 1
# Gather all weights and compare
with FSDP.summon_full_params(trainer.model, writeback=False, rank0_only=False):
params_after = torch.cat([p.cpu().view(-1) for p in trainer.model.parameters()])
assert torch.equal(params_before, params_after)
# Create a raw state-dict checkpoint to test `Fabric.load_raw` too
raw_checkpoint_path = checkpoint_path.with_name("model-state-dict")
if fabric.global_rank == 0:
checkpoint = torch.load(checkpoint_path)
torch.save(checkpoint["model"], raw_checkpoint_path)
fabric.barrier()
trainer.run()
fabric.load_raw(raw_checkpoint_path, trainer.model)
# Gather all weights and compare
with FSDP.summon_full_params(trainer.model, writeback=False, rank0_only=False):
params_after = torch.cat([p.cpu().view(-1) for p in trainer.model.parameters()])
assert torch.equal(params_before, params_after)


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize("move_to_device", [True, False])
@mock.patch("lightning.fabric.wrappers._FabricModule")
def test_setup_module_move_to_device(fabric_module_mock, move_to_device):
    """Test that `move_to_device` has no effect: FSDP decides which device each (sharded) parameter gets moved to."""
strategy = FSDPStrategy(auto_wrap_policy=always_wrap_policy)
fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy)
fabric.launch()
model = torch.nn.Linear(10, 10, bias=False) # total params: 10 * 10 = 100
fabric_model = fabric.setup_module(model, move_to_device=move_to_device)
fabric_module_mock.assert_not_called()
assert len(list(fabric_model.parameters())) == 1
# the linear layer got sharded and each part is on the expected device
assert next(fabric_model.parameters()).device == torch.device("cuda", fabric.local_rank)
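    # 100 parameters sharded across 2 devices -> 50 parameters per rank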
assert next(fabric_model.parameters()).numel() == 50
if _TORCH_GREATER_EQUAL_2_0:
# In PyTorch >= 2.0 we set `use_orig_params=True` and don't see flattened parameters
assert isinstance(next(fabric_model.parameters()), Parameter)
else:
assert isinstance(next(fabric_model.parameters()), FlatParameter)
# The _DeviceDtypeModuleMixin currently can't represent the device in a meaningful way for models with pieces on
# different devices
assert fabric_model.device == torch.device("cuda", fabric.local_rank)
assert fabric.device == torch.device("cuda", fabric.local_rank)


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="2.0.0")
def test_setup_with_orig_params_and_multiple_param_groups():
"""Test that Fabric sets `use_orig_params` for the user when jointly setting up model and optimizer."""
strategy = FSDPStrategy(auto_wrap_policy=always_wrap_policy)
fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy)
fabric.launch()
model = torch.nn.Sequential(
torch.nn.Linear(10, 10, bias=False),
torch.nn.Linear(5, 2, bias=False),
)
optimizer = torch.optim.Adam(
[
{"params": model[0].parameters(), "lr": 1e-2},
{"params": model[1].parameters(), "lr": 1e-6},
]
)
# set up model and optimizer jointly
wrapped_model, wrapped_optimizer = fabric.setup(model, optimizer)
assert fabric.strategy._fsdp_kwargs["use_orig_params"]
assert isinstance(wrapped_optimizer, _FabricOptimizer)
assert len(wrapped_optimizer.param_groups) == 2
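    # each layer is wrapped in its own FSDP instance, and its weight remains a regular (non-flattened) parameter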
for i in range(2):
layer = wrapped_model._forward_module.module[i]
assert isinstance(layer, FullyShardedDataParallel)
assert torch.equal(wrapped_optimizer.param_groups[i]["params"][0], layer.weight)
# A regular parameter as a view into the flattened parameters
assert isinstance(layer.weight, torch.nn.Parameter)
assert not isinstance(layer.weight, FlatParameter)


@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.1.0", dynamo=True, skip_windows=True)
@mock.patch(
"lightning.fabric.wrappers.torch.compile",
Mock(wraps=(torch.compile if _TORCH_GREATER_EQUAL_2_0 else None)),
)
@mock.patch.dict(os.environ, {})
def test_reapply_compile():
"""Test that Fabric can rewrap a compiled module such that compilation happens over the FSDP-wrapper."""
from torch._dynamo import OptimizedModule
strategy = FSDPStrategy(auto_wrap_policy=always_wrap_policy)
fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy)
fabric.launch()
model = BoringModel()
compile_kwargs = {"mode": "reduce-overhead"}
compiled_model = torch.compile(model, **compile_kwargs)
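    # reset the mock so the assertion below only captures the `torch.compile` call made by Fabric during setup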
torch.compile.reset_mock()
fabric_model = fabric.setup(compiled_model, _reapply_compile=True)
assert isinstance(fabric_model._forward_module, OptimizedModule)
assert isinstance(fabric_model._forward_module._orig_mod, FullyShardedDataParallel)
# Assert we called compile again with the same arguments, but on the FSDP-wrapped module
torch.compile.assert_called_with(fabric_model._forward_module._orig_mod, **compile_kwargs)
assert fabric_model._original_module == model
assert fabric_model._forward_module._orig_mod.module == model
assert fabric_model.device == fabric.device
# Smoke-testing forward to ensure we don't get compilation errors
for _ in range(3):
fabric_model(torch.randn(2, 32, device=fabric.device)).sum().backward()


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize(
("precision", "expected_dtype"),
[
("32-true", torch.float32),
("16-true", torch.float16),
pytest.param("bf16-true", torch.bfloat16, marks=RunIf(bf16_cuda=True)),
],
)
def test_module_init_context(precision, expected_dtype):
"""Test that the module under the init-context gets moved to the right device and dtype."""
fabric = Fabric(
accelerator="cuda",
devices=2,
strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
precision=precision,
)
fabric.launch()
def _run_setup_assertions(empty_init, expected_device):
with fabric.init_module(empty_init=empty_init):
model = torch.nn.Linear(100, 100, bias=False)
        # The model is on the CPU/meta-device until after `.setup()`
assert model.weight.device == expected_device
assert model.weight.dtype == expected_dtype
model = fabric.setup(model)
# Parameters get sharded in `.setup()` and moved to the target device
assert model.weight.device == torch.device("cuda", fabric.local_rank)
assert model.weight.dtype == expected_dtype
# Case 1: No empty init
_run_setup_assertions(empty_init=False, expected_device=torch.device("cpu"))
if _TORCH_GREATER_EQUAL_2_1:
# Case 2: Empty-init with PyTorch >= 2.1 supports meta device
_run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))
else:
# Case 2: Empty-init with PyTorch < 2.1 only supports `torch.empty()`-init
_run_setup_assertions(empty_init=True, expected_device=torch.device("cpu"))


@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
def test_fsdp_save_filter(tmp_path):
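    """Test that the `filter` argument to `fabric.save` is applied for full checkpoints and rejected for sharded ones."""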
fabric = Fabric(accelerator="cuda", strategy=FSDPStrategy(state_dict_type="full"), devices=2)
fabric.launch()
model = nn.Linear(32, 2)
model = fabric.setup_module(model)
tmp_path = Path(fabric.broadcast(str(tmp_path)))
state = {"model": model}
filter = {"model": lambda k, v: "bias" in k}
checkpoint_path = tmp_path / "full.pth"
fabric.save(checkpoint_path, state, filter=filter)
checkpoint = torch.load(checkpoint_path)["model"]
assert set(checkpoint) == {"bias"}
assert isinstance(checkpoint["bias"], torch.Tensor)
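    # filtered saving is not supported for sharded checkpoints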
fabric.strategy._state_dict_type = "sharded"
checkpoint_path = tmp_path / "sharded"
with pytest.raises(NotImplementedError, match="doesn't support loading sharded filtered"):
fabric.save(checkpoint_path, state, filter=filter)


@RunIf(min_cuda_gpus=1)
def test_fsdp_manual_activation_checkpointing():
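    """Test that Fabric does not re-apply activation checkpointing if the model already contains checkpoint wrappers."""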
model = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.Linear(1, 1))
strategy = FSDPStrategy(activation_checkpointing_policy={torch.nn.Linear})
fabric = Fabric(devices=1, accelerator="cuda", strategy=strategy)
fabric.launch()
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointWrapper,
apply_activation_checkpointing,
)
# manually apply activation checkpointing
apply_activation_checkpointing(model)
wrappers = {name for name, mod in model.named_modules() if isinstance(mod, CheckpointWrapper)}
assert wrappers == {"0", "1"}
    # let fabric set up the model; it shouldn't apply activation checkpointing again
with pytest.warns(match="is configured, but the model already contains checkpointed"):
model = fabric.setup(model)
wrappers = {name for name, mod in model._forward_module.named_modules() if isinstance(mod, CheckpointWrapper)}
assert wrappers == {"_fsdp_wrapped_module.0", "_fsdp_wrapped_module.1"}


@RunIf(min_cuda_gpus=1)
def test_rewrap_warnings():
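    """Test the warnings emitted when the model passed to `setup` is already (partially) wrapped or still has
    parameters on the meta device."""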
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.distributed.fsdp.wrap import wrap
strategy = FSDPStrategy(auto_wrap_policy={torch.nn.Linear})
fabric = Fabric(devices=1, accelerator="cuda", strategy=strategy)
fabric.launch()
with fabric.init_module():
model = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.ReLU(), wrap(torch.nn.Linear(1, 1)))
with pytest.warns(match="the model is already wrapped"):
model = fabric.setup(model)
assert not isinstance(model._forward_module, FullyShardedDataParallel)
assert isinstance(model._forward_module[2], FullyShardedDataParallel)
if not _TORCH_GREATER_EQUAL_2_1:
return
with fabric.init_module(empty_init=True):
model = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.ReLU(), wrap(torch.nn.Linear(1, 1)))
assert model[0].weight.is_meta
with pytest.warns(match="there are still parameters on the meta device"):
fabric_model = fabric.setup(model)
assert next(fabric_model.parameters()).is_meta


@RunIf(min_cuda_gpus=2, standalone=True)
@pytest.mark.parametrize(
"precision",
[
"32-true",
pytest.param("16-mixed"),
pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True)),
],
)
@pytest.mark.parametrize(
"clip_type",
[
pytest.param("norm", marks=pytest.mark.skip("FSDP gradient clipping by norm is not correct.")),
"val",
],
)
def test_clip_gradients(clip_type, precision):
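    """Test gradient clipping by value (and by norm, where supported) with FSDP."""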
if clip_type == "norm" and precision == "16-mixed":
pytest.skip(reason="Clipping by norm with 16-mixed is numerically unstable.")
strategy = FSDPStrategy(auto_wrap_policy={torch.nn.Linear})
fabric = Fabric(accelerator="auto", devices=2, precision=precision, strategy=strategy)
fabric.launch()
in_features, out_features = 32, 2
model = torch.nn.Linear(in_features, out_features, bias=False)
model.weight.data.fill_(0.01)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)
batch = torch.full((1, in_features), 0.1, device=fabric.device)
loss = model(batch).sum()
# The example is constructed such that the gradients are all the same
fabric.backward(loss)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
if clip_type == "norm":
with FSDP.summon_full_params(model._forward_module, with_grads=True):
norm = torch.linalg.vector_norm(model.weight.grad.detach().cpu(), 2, dtype=torch.float32).item()
new_norm = norm / 10
fabric.clip_gradients(model, optimizer, max_norm=new_norm * 10)
with FSDP.summon_full_params(model._forward_module, with_grads=True):
assert torch.allclose(
torch.linalg.vector_norm(model.weight.grad.detach().cpu(), 2, dtype=torch.float32),
torch.tensor(new_norm),
)
elif clip_type == "val":
val = model.weight.grad[0].item()
new_val = val / 2.0
fabric.clip_gradients(model, optimizer, clip_val=new_val)
assert torch.allclose(model.weight.grad, torch.full_like(model.weight.grad, new_val))
else:
raise AssertionError(f"Unknown clip type: {clip_type}")
optimizer.step()
optimizer.zero_grad()


# TODO: Support checkpoint consolidation with PyTorch >= 2.2
@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.1.0", max_torch="2.2.0")
def test_save_sharded_and_consolidate_and_load(tmp_path):
"""Test the consolidation of a FSDP-sharded checkpoint into a single file."""
fabric = Fabric(
accelerator="cuda",
strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy, state_dict_type="sharded"),
devices=2,
)
fabric.launch()
model = BoringModel()
optimizer = torch.optim.Adam(model.parameters())
model, optimizer = fabric.setup(model, optimizer)
state = {"model": model, "optimizer": optimizer, "steps": 1}
# run one iteration to init the state of the optimizer
model(torch.rand(1, 32, device=fabric.device)).sum().backward()
optimizer.step()
checkpoint_path_sharded = fabric.broadcast(str(tmp_path / "checkpoint_sharded"))
fabric.save(checkpoint_path_sharded, state)
assert set(os.listdir(checkpoint_path_sharded)) == {"meta.pt", ".metadata", "__0_0.distcp", "__1_0.distcp"}
# consolidate the checkpoint to a single file
checkpoint_path_full = fabric.broadcast(str(tmp_path / "checkpoint_full.pt"))
if fabric.global_rank == 0:
checkpoint = _load_distributed_checkpoint(Path(checkpoint_path_sharded))
torch.save(checkpoint, checkpoint_path_full)
fabric.barrier()
# re-init and load from full checkpoint
fabric = Fabric(
accelerator="cuda",
strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
devices=2,
)
# Hack: we already called launch() on another Fabric instance above
fabric._launched = True
model = BoringModel()
optimizer = torch.optim.Adam(model.parameters())
model, optimizer = fabric.setup(model, optimizer)
state = {"model": model, "optimizer": optimizer, "steps": 1}
fabric.load(checkpoint_path_full, state)