# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from copy import deepcopy
from pathlib import Path
from unittest import mock
from unittest.mock import Mock

import pytest
import torch
import torch.nn as nn
from lightning.fabric import Fabric
from lightning.fabric.plugins import FSDPPrecision
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities.imports import (
    _TORCH_GREATER_EQUAL_2_0,
    _TORCH_GREATER_EQUAL_2_1,
)
from lightning.fabric.utilities.load import _load_distributed_checkpoint
from lightning.fabric.wrappers import _FabricOptimizer
from torch.distributed.fsdp import FlatParameter, FullyShardedDataParallel, OptimStateKeyType
from torch.distributed.fsdp.wrap import always_wrap_policy, wrap
from torch.nn import Parameter
from torch.utils.data import DataLoader

from tests_fabric.helpers.datasets import RandomDataset
from tests_fabric.helpers.runif import RunIf
from tests_fabric.test_fabric import BoringModel


class BasicTrainer:
    """Implements a basic training loop for the end-to-end tests."""

    def __init__(self, fabric):
        self.fabric = fabric
        self.model = self.optimizer = self.dataloader = None

    def get_model(self):
        return nn.Linear(32, 2)

    def step(self, model, batch):
        output = model(batch)
        return torch.nn.functional.mse_loss(output, torch.ones_like(output))

    def run(self) -> None:
        with self.fabric.init_module():
            model = self.get_model()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
        model, optimizer = self.fabric.setup(model, optimizer)

        dataloader = DataLoader(RandomDataset(32, 64))
        dataloader = self.fabric.setup_dataloaders(dataloader)

        self.model = model
        self.optimizer = optimizer
        self.dataloader = dataloader

        model.train()

        data_iter = iter(dataloader)
        batch = next(data_iter)
        loss = self.step(model, batch)
        self.fabric.backward(loss)
        optimizer.step()
        optimizer.zero_grad()


class _Trainer(BasicTrainer):
    def get_model(self):
        model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))
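        # with the `always_wrap_policy` used in the tests below, each of the 3 submodules and the root module
        # end up in their own FSDP wrapper, hence 4 wrapped modules (verified in `step`)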
        self.num_wrapped = 4
        return model

    def step(self, model, batch):
        wrapped_layers = [m for m in model.modules() if isinstance(m, FullyShardedDataParallel)]
        assert len(wrapped_layers) == self.num_wrapped
        assert (self.num_wrapped == 4) == isinstance(model._forward_module, FullyShardedDataParallel)

        precision = self.fabric._precision
        assert isinstance(precision, FSDPPrecision)
        if precision.precision == "16-mixed":
            param_dtype = torch.float32
            reduce_dtype = buffer_dtype = torch.float16
        elif precision.precision == "bf16-mixed":
            param_dtype = torch.float32
            reduce_dtype = buffer_dtype = torch.bfloat16
        elif precision.precision == "16-true":
            param_dtype = reduce_dtype = buffer_dtype = torch.float16
        elif precision.precision == "bf16-true":
            param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
        else:
            raise ValueError(f"Unknown precision {precision.precision}")

        for layer in wrapped_layers:
            assert layer.mixed_precision.param_dtype == param_dtype
            assert layer.mixed_precision.reduce_dtype == reduce_dtype
            assert layer.mixed_precision.buffer_dtype == buffer_dtype

        output = model(batch)
        return torch.nn.functional.mse_loss(output, torch.ones_like(output))


class _TrainerManualWrapping(_Trainer):
    def get_model(self):
        model = super().get_model()
        for i, layer in enumerate(model):
            if i % 2 == 0:
                model[i] = wrap(layer)
        self.num_wrapped = 2
        return model


@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
@pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
@pytest.mark.parametrize("manual_wrapping", [True, False])
def test_fsdp_train_save_load(tmp_path, manual_wrapping, precision):
    """Test FSDP training, saving and loading with different wrapping and precision settings."""
    trainer_cls = _TrainerManualWrapping if manual_wrapping else _Trainer
    fabric = Fabric(
        accelerator="cuda",
        strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
        devices=2,
        precision=precision,
    )
    fabric.launch()
    trainer = trainer_cls(fabric)
    trainer.run()

    checkpoint_path = fabric.broadcast(str(tmp_path / "fsdp-checkpoint"))

    params_before = deepcopy(list(trainer.model.parameters()))
    state = {"model": trainer.model, "optimizer": trainer.optimizer, "steps": 1}
    fabric.save(checkpoint_path, state)
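    # With the sharded state-dict type (the default here), the checkpoint is a directory: one `.distcp` shard
    # per rank, the `.metadata` file written by `torch.distributed.checkpoint`, and `meta.pt` holding the
    # remaining user state (here "steps")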
    assert set(os.listdir(checkpoint_path)) == {"meta.pt", ".metadata", "__0_0.distcp", "__1_0.distcp"}

    # re-init all objects and resume
    fabric = Fabric(
        accelerator="cuda",
        strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
        devices=2,
        precision=precision,
    )
    fabric.launch()
    trainer = trainer_cls(fabric)
    trainer.run()

    # check correctness with loaded state
    state = {"model": trainer.model, "optimizer": trainer.optimizer, "steps": 0}
    metadata = fabric.load(checkpoint_path, state)
    for p0, p1 in zip(params_before, trainer.model.parameters()):
        torch.testing.assert_close(p0, p1, atol=0, rtol=0, equal_nan=True)

    # check user data in state reloaded
    assert state["steps"] == 1
    assert not metadata

    # attempt to load a key not in the metadata checkpoint
    state = {"model": trainer.model, "coconut": 11}
    with pytest.raises(KeyError, match="The requested state contains a key 'coconut' that does not exist"):
        fabric.load(checkpoint_path, state)

    # `strict=False` ignores the missing key
    state = {"model": trainer.model, "coconut": 11}
    fabric.load(checkpoint_path, state, strict=False)
    assert state["coconut"] == 11


@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
def test_fsdp_save_full_state_dict(tmp_path):
    """Test that FSDP saves the full state into a single file with `state_dict_type="full"`."""
    fabric = Fabric(
        accelerator="cuda",
        strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy, state_dict_type="full"),
        devices=2,
    )
    fabric.launch()
    trainer = BasicTrainer(fabric)
    trainer.run()

    checkpoint_path = Path(fabric.broadcast(str(tmp_path / "fsdp-checkpoint.pt")))

    state = {"model": trainer.model, "optimizer": trainer.optimizer, "steps": 1}
    fabric.save(checkpoint_path, state)

    checkpoint = torch.load(checkpoint_path)
    assert checkpoint["steps"] == 1
    loaded_state_dict = checkpoint["model"]

    # assert the correct model state was saved
    with FullyShardedDataParallel.summon_full_params(trainer.model):
        state_dict = trainer.model.state_dict()
        assert set(loaded_state_dict.keys()) == set(state_dict.keys())
        for param_name in state_dict:
            assert torch.equal(loaded_state_dict[param_name], state_dict[param_name].cpu())
        params_before = [p.cpu() for p in trainer.model.parameters()]

    # assert the correct optimizer state was saved
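    # (`full_optim_state_dict` gathers the sharded optimizer state; with `rank0_only=False` every rank receives
    # the full dict, so the comparisons below work on all ranks)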
    optimizer_state_before = FullyShardedDataParallel.full_optim_state_dict(
        trainer.model, trainer.optimizer, rank0_only=False
    )
    assert set(checkpoint["optimizer"].keys()) == set(optimizer_state_before.keys()) == {"state", "param_groups"}

    # 1. verify the FSDP state can be loaded back into a FSDP model/strategy directly
    fabric = Fabric(
        accelerator="cuda",
        strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
        devices=2,
    )
    fabric.launch()
    trainer = BasicTrainer(fabric)
    trainer.run()
    metadata = fabric.load(checkpoint_path, {"model": trainer.model, "optimizer": trainer.optimizer})
    assert metadata == {"steps": 1}

    with FullyShardedDataParallel.summon_full_params(trainer.model):
        params_after = list(trainer.model.parameters())
        assert all(torch.equal(p0.cpu(), p1.cpu()) for p0, p1 in zip(params_before, params_after))

    # assert the correct optimizer state was loaded
    optimizer_state_after = FullyShardedDataParallel.full_optim_state_dict(
        trainer.model, trainer.optimizer, rank0_only=False
    )
    assert set(optimizer_state_after.keys()) == set(optimizer_state_before.keys()) == {"state", "param_groups"}
    torch.testing.assert_close(optimizer_state_after["state"], optimizer_state_before["state"], atol=0, rtol=0)
    assert optimizer_state_after["param_groups"] == optimizer_state_before["param_groups"]

    # run a step to verify the optimizer state is correct
    trainer.run()

    # 2. verify the FSDP state can be loaded back into a single-device model/strategy
    fabric = Fabric(accelerator="cpu", devices=1)
    trainer = BasicTrainer(fabric)
    trainer.run()
    metadata = fabric.load(checkpoint_path, {"model": trainer.model, "optimizer": trainer.optimizer})
    assert metadata == {"steps": 1}
    params_after = list(trainer.model.parameters())
    assert all(torch.equal(p0, p1) for p0, p1 in zip(params_before, params_after))

    # get optimizer state after loading
    normal_checkpoint_path = Path(fabric.broadcast(str(tmp_path / "normal-checkpoint.pt")))
    fabric.save(normal_checkpoint_path, {"model": trainer.model, "optimizer": trainer.optimizer, "steps": 2})
    optimizer_state_after = torch.load(normal_checkpoint_path)["optimizer"]
    optimizer_state_after = FullyShardedDataParallel.rekey_optim_state_dict(
        optimizer_state_after, optim_state_key_type=OptimStateKeyType.PARAM_NAME, model=trainer.model
    )

    # assert the correct optimizer state was loaded
    assert set(optimizer_state_after.keys()) == set(optimizer_state_before.keys()) == {"state", "param_groups"}
    torch.testing.assert_close(optimizer_state_after["state"], optimizer_state_before["state"], atol=0, rtol=0)

    # run a step to verify the optimizer state is correct
    trainer.run()

    # 3. verify that the states from a single-device model/strategy can be loaded into a FSDP model/strategy
    fabric = Fabric(
        accelerator="cuda",
        strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
        devices=2,
    )
    fabric.launch()
    trainer = BasicTrainer(fabric)
    trainer.run()
    metadata = fabric.load(normal_checkpoint_path, {"model": trainer.model, "optimizer": trainer.optimizer})
    assert metadata == {"steps": 2}

    with FullyShardedDataParallel.summon_full_params(trainer.model):
        params_after = list(trainer.model.parameters())
        assert all(torch.equal(p0.cpu(), p1.cpu()) for p0, p1 in zip(params_before, params_after))

    # assert the correct optimizer state was loaded
    optimizer_state_after = FullyShardedDataParallel.full_optim_state_dict(
        trainer.model, trainer.optimizer, rank0_only=False
    )
    assert set(optimizer_state_after.keys()) == set(optimizer_state_before.keys()) == {"state", "param_groups"}
    torch.testing.assert_close(optimizer_state_after["state"], optimizer_state_before["state"], atol=0, rtol=0)
    assert optimizer_state_after["param_groups"] == optimizer_state_before["param_groups"]

    # run a step to verify the optimizer state is correct
    trainer.run()


@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
def test_fsdp_load_full_state_dict_into_sharded_model(tmp_path):
    """Test that the strategy can load a full-state checkpoint into a FSDP sharded model."""
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    fabric = Fabric(accelerator="cuda", devices=1)
    fabric.seed_everything(0)
    trainer = BasicTrainer(fabric)
    trainer.run()

    # Save a full-state-dict checkpoint
    checkpoint_path = Path(fabric.broadcast(str(tmp_path / "full-checkpoint.pt")))
    state = {"model": trainer.model, "optimizer": trainer.optimizer, "steps": 1}
    fabric.save(checkpoint_path, state)

    # Gather all weights and store a copy manually
    with FSDP.summon_full_params(trainer.model, writeback=False, rank0_only=False):
        params_before = torch.cat([p.cpu().view(-1) for p in trainer.model.parameters()])

    # Create a FSDP sharded model
    fabric = Fabric(
        accelerator="cuda",
        strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
        devices=2,
    )
    fabric.launch()
    trainer = BasicTrainer(fabric)
    trainer.run()

    state = {"model": trainer.model, "optimizer": trainer.optimizer, "steps": 44}
    fabric.load(checkpoint_path, state)
    assert state["steps"] == 1

    # Gather all weights and compare
    with FSDP.summon_full_params(trainer.model, writeback=False, rank0_only=False):
        params_after = torch.cat([p.cpu().view(-1) for p in trainer.model.parameters()])
        assert torch.equal(params_before, params_after)

    # Create a raw state-dict checkpoint to test `Fabric.load_raw` too
    raw_checkpoint_path = checkpoint_path.with_name("model-state-dict")
    if fabric.global_rank == 0:
        checkpoint = torch.load(checkpoint_path)
        torch.save(checkpoint["model"], raw_checkpoint_path)
    fabric.barrier()

    trainer.run()
    fabric.load_raw(raw_checkpoint_path, trainer.model)

    # Gather all weights and compare
    with FSDP.summon_full_params(trainer.model, writeback=False, rank0_only=False):
        params_after = torch.cat([p.cpu().view(-1) for p in trainer.model.parameters()])
        assert torch.equal(params_before, params_after)


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize("move_to_device", [True, False])
@mock.patch("lightning.fabric.wrappers._FabricModule")
def test_setup_module_move_to_device(fabric_module_mock, move_to_device):
    """Test that `move_to_device` does nothing: FSDP decides onto which device each (sharded) parameter gets
    moved."""
    strategy = FSDPStrategy(auto_wrap_policy=always_wrap_policy)
    fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy)
    fabric.launch()

    model = torch.nn.Linear(10, 10, bias=False)  # total params: 10 * 10 = 100
    fabric_model = fabric.setup_module(model, move_to_device=move_to_device)
    fabric_module_mock.assert_not_called()

    assert len(list(fabric_model.parameters())) == 1
    # the linear layer got sharded and each part is on the expected device
    assert next(fabric_model.parameters()).device == torch.device("cuda", fabric.local_rank)
    assert next(fabric_model.parameters()).numel() == 50
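    # 10 * 10 = 100 weight elements sharded evenly across the 2 ranks, i.e. 50 elements per rank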
    if _TORCH_GREATER_EQUAL_2_0:
        # In PyTorch >= 2.0 we set `use_orig_params=True` and don't see flattened parameters
        assert isinstance(next(fabric_model.parameters()), Parameter)
    else:
        assert isinstance(next(fabric_model.parameters()), FlatParameter)

    # The _DeviceDtypeModuleMixin currently can't represent the device in a meaningful way for models with pieces on
    # different devices
    assert fabric_model.device == torch.device("cuda", fabric.local_rank)
    assert fabric.device == torch.device("cuda", fabric.local_rank)


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="2.0.0")
def test_setup_with_orig_params_and_multiple_param_groups():
    """Test that Fabric sets `use_orig_params` for the user when jointly setting up model and optimizer."""
    strategy = FSDPStrategy(auto_wrap_policy=always_wrap_policy)
    fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy)
    fabric.launch()

    model = torch.nn.Sequential(
        torch.nn.Linear(10, 10, bias=False),
        torch.nn.Linear(5, 2, bias=False),
    )
    optimizer = torch.optim.Adam(
        [
            {"params": model[0].parameters(), "lr": 1e-2},
            {"params": model[1].parameters(), "lr": 1e-6},
        ]
    )

    # set up model and optimizer jointly
    wrapped_model, wrapped_optimizer = fabric.setup(model, optimizer)

    assert fabric.strategy._fsdp_kwargs["use_orig_params"]
    assert isinstance(wrapped_optimizer, _FabricOptimizer)
    assert len(wrapped_optimizer.param_groups) == 2
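    # each Linear layer got wrapped in its own FSDP instance; the optimizer param groups should still line up
    # with the layers' (original, unflattened) weights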
    for i in range(2):
        layer = wrapped_model._forward_module.module[i]
        assert isinstance(layer, FullyShardedDataParallel)
        assert torch.equal(wrapped_optimizer.param_groups[i]["params"][0], layer.weight)

        # A regular parameter as a view into the flattened parameters
        assert isinstance(layer.weight, torch.nn.Parameter)
        assert not isinstance(layer.weight, FlatParameter)
@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.1.0", dynamo=True, skip_windows=True)
|
|
@mock.patch(
|
|
"lightning.fabric.wrappers.torch.compile",
|
|
Mock(wraps=(torch.compile if _TORCH_GREATER_EQUAL_2_0 else None)),
|
|
)
|
|
@mock.patch.dict(os.environ, {})
|
|
def test_reapply_compile():
|
|
"""Test that Fabric can rewrap a compiled module such that compilation happens over the FSDP-wrapper."""
|
|
from torch._dynamo import OptimizedModule
|
|
|
|
strategy = FSDPStrategy(auto_wrap_policy=always_wrap_policy)
|
|
fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy)
|
|
fabric.launch()
|
|
|
|
model = BoringModel()
|
|
compile_kwargs = {"mode": "reduce-overhead"}
|
|
compiled_model = torch.compile(model, **compile_kwargs)
|
|
torch.compile.reset_mock()
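    # `torch.compile` is patched above with a mock that wraps the real function; resetting it here lets us
    # assert below that `setup()` invokes compile again with the same kwargs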

    fabric_model = fabric.setup(compiled_model, _reapply_compile=True)

    assert isinstance(fabric_model._forward_module, OptimizedModule)
    assert isinstance(fabric_model._forward_module._orig_mod, FullyShardedDataParallel)

    # Assert we called compile again with the same arguments, but on the FSDP-wrapped module
    torch.compile.assert_called_with(fabric_model._forward_module._orig_mod, **compile_kwargs)

    assert fabric_model._original_module == model
    assert fabric_model._forward_module._orig_mod.module == model
    assert fabric_model.device == fabric.device

    # Smoke-testing forward to ensure we don't get compilation errors
    for _ in range(3):
        fabric_model(torch.randn(2, 32, device=fabric.device)).sum().backward()


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize(
    ("precision", "expected_dtype"),
    [
        ("32-true", torch.float32),
        ("16-true", torch.float16),
        pytest.param("bf16-true", torch.bfloat16, marks=RunIf(bf16_cuda=True)),
    ],
)
def test_module_init_context(precision, expected_dtype):
    """Test that the module under the init-context gets moved to the right device and dtype."""
    fabric = Fabric(
        accelerator="cuda",
        devices=2,
        strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
        precision=precision,
    )
    fabric.launch()

    def _run_setup_assertions(empty_init, expected_device):
        with fabric.init_module(empty_init=empty_init):
            model = torch.nn.Linear(100, 100, bias=False)

        # The model is on the CPU/meta-device until after `.setup()`
        assert model.weight.device == expected_device
        assert model.weight.dtype == expected_dtype
        model = fabric.setup(model)
        # Parameters get sharded in `.setup()` and moved to the target device
        assert model.weight.device == torch.device("cuda", fabric.local_rank)
        assert model.weight.dtype == expected_dtype

    # Case 1: No empty init
    _run_setup_assertions(empty_init=False, expected_device=torch.device("cpu"))

    if _TORCH_GREATER_EQUAL_2_1:
        # Case 2: Empty-init with PyTorch >= 2.1 supports meta device
        _run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))
    else:
        # Case 2: Empty-init with PyTorch < 2.1 only supports `torch.empty()`-init
        _run_setup_assertions(empty_init=True, expected_device=torch.device("cpu"))


@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
def test_fsdp_save_filter(tmp_path):
    fabric = Fabric(accelerator="cuda", strategy=FSDPStrategy(state_dict_type="full"), devices=2)
    fabric.launch()
    model = nn.Linear(32, 2)
    model = fabric.setup_module(model)

    tmp_path = Path(fabric.broadcast(str(tmp_path)))
    state = {"model": model}
    filter = {"model": lambda k, v: "bias" in k}
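    # the filter callable is applied to each (key, value) pair of the module's state dict and keeps only the
    # entries whose key contains "bias"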
    checkpoint_path = tmp_path / "full.pth"
    fabric.save(checkpoint_path, state, filter=filter)
    checkpoint = torch.load(checkpoint_path)["model"]
    assert set(checkpoint) == {"bias"}
    assert isinstance(checkpoint["bias"], torch.Tensor)

    fabric.strategy._state_dict_type = "sharded"
    checkpoint_path = tmp_path / "sharded"
    with pytest.raises(NotImplementedError, match="doesn't support loading sharded filtered"):
        fabric.save(checkpoint_path, state, filter=filter)


@RunIf(min_cuda_gpus=1)
def test_fsdp_manual_activation_checkpointing():
    model = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.Linear(1, 1))
    strategy = FSDPStrategy(activation_checkpointing_policy={torch.nn.Linear})
    fabric = Fabric(devices=1, accelerator="cuda", strategy=strategy)
    fabric.launch()

    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        CheckpointWrapper,
        apply_activation_checkpointing,
    )

    # manually apply activation checkpointing
    apply_activation_checkpointing(model)

    wrappers = {name for name, mod in model.named_modules() if isinstance(mod, CheckpointWrapper)}
    assert wrappers == {"0", "1"}

    # let fabric set up the model; it shouldn't apply activation checkpointing again
    with pytest.warns(match="is configured, but the model already contains checkpointed"):
        model = fabric.setup(model)

    wrappers = {name for name, mod in model._forward_module.named_modules() if isinstance(mod, CheckpointWrapper)}
    assert wrappers == {"_fsdp_wrapped_module.0", "_fsdp_wrapped_module.1"}


@RunIf(min_cuda_gpus=1)
def test_rewrap_warnings():
    from torch.distributed.fsdp import FullyShardedDataParallel
    from torch.distributed.fsdp.wrap import wrap

    strategy = FSDPStrategy(auto_wrap_policy={torch.nn.Linear})
    fabric = Fabric(devices=1, accelerator="cuda", strategy=strategy)
    fabric.launch()
    with fabric.init_module():
        model = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.ReLU(), wrap(torch.nn.Linear(1, 1)))
    with pytest.warns(match="the model is already wrapped"):
        model = fabric.setup(model)
    assert not isinstance(model._forward_module, FullyShardedDataParallel)
    assert isinstance(model._forward_module[2], FullyShardedDataParallel)

    if not _TORCH_GREATER_EQUAL_2_1:
        return

    with fabric.init_module(empty_init=True):
        model = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.ReLU(), wrap(torch.nn.Linear(1, 1)))
    assert model[0].weight.is_meta
    with pytest.warns(match="there are still parameters on the meta device"):
        fabric_model = fabric.setup(model)
    assert next(fabric_model.parameters()).is_meta


@RunIf(min_cuda_gpus=2, standalone=True)
@pytest.mark.parametrize(
    "precision",
    [
        "32-true",
        pytest.param("16-mixed"),
        pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True)),
    ],
)
@pytest.mark.parametrize(
    "clip_type",
    [
        pytest.param("norm", marks=pytest.mark.skip("FSDP gradient clipping by norm is not correct.")),
        "val",
    ],
)
def test_clip_gradients(clip_type, precision):
    if clip_type == "norm" and precision == "16-mixed":
        pytest.skip(reason="Clipping by norm with 16-mixed is numerically unstable.")

    strategy = FSDPStrategy(auto_wrap_policy={torch.nn.Linear})
    fabric = Fabric(accelerator="auto", devices=2, precision=precision, strategy=strategy)
    fabric.launch()

    in_features, out_features = 32, 2
    model = torch.nn.Linear(in_features, out_features, bias=False)
    model.weight.data.fill_(0.01)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    model, optimizer = fabric.setup(model, optimizer)

    batch = torch.full((1, in_features), 0.1, device=fabric.device)
    loss = model(batch).sum()

    # The example is constructed such that the gradients are all the same
    fabric.backward(loss)

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    if clip_type == "norm":
        with FSDP.summon_full_params(model._forward_module, with_grads=True):
            norm = torch.linalg.vector_norm(model.weight.grad.detach().cpu(), 2, dtype=torch.float32).item()
            new_norm = norm / 10
        fabric.clip_gradients(model, optimizer, max_norm=new_norm * 10)
        with FSDP.summon_full_params(model._forward_module, with_grads=True):
            assert torch.allclose(
                torch.linalg.vector_norm(model.weight.grad.detach().cpu(), 2, dtype=torch.float32),
                torch.tensor(new_norm),
            )
    elif clip_type == "val":
        val = model.weight.grad[0].item()
        new_val = val / 2.0
        fabric.clip_gradients(model, optimizer, clip_val=new_val)
        assert torch.allclose(model.weight.grad, torch.full_like(model.weight.grad, new_val))
    else:
        raise AssertionError(f"Unknown clip type: {clip_type}")

    optimizer.step()
    optimizer.zero_grad()


# TODO: Support checkpoint consolidation with PyTorch >= 2.2
@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.1.0", max_torch="2.2.0")
def test_save_sharded_and_consolidate_and_load(tmp_path):
    """Test the consolidation of a FSDP-sharded checkpoint into a single file."""

    fabric = Fabric(
        accelerator="cuda",
        strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy, state_dict_type="sharded"),
        devices=2,
    )
    fabric.launch()

    model = BoringModel()
    optimizer = torch.optim.Adam(model.parameters())
    model, optimizer = fabric.setup(model, optimizer)
    state = {"model": model, "optimizer": optimizer, "steps": 1}

    # run one iteration to init the state of the optimizer
    model(torch.rand(1, 32, device=fabric.device)).sum().backward()
    optimizer.step()

    checkpoint_path_sharded = fabric.broadcast(str(tmp_path / "checkpoint_sharded"))
    fabric.save(checkpoint_path_sharded, state)
    assert set(os.listdir(checkpoint_path_sharded)) == {"meta.pt", ".metadata", "__0_0.distcp", "__1_0.distcp"}

    # consolidate the checkpoint to a single file
    checkpoint_path_full = fabric.broadcast(str(tmp_path / "checkpoint_full.pt"))
    if fabric.global_rank == 0:
        checkpoint = _load_distributed_checkpoint(Path(checkpoint_path_sharded))
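        # `_load_distributed_checkpoint` merges the per-rank `.distcp` shards into a single in-memory state dict,
        # which is then written out with `torch.save` as a regular consolidated checkpoint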
        torch.save(checkpoint, checkpoint_path_full)
    fabric.barrier()

    # re-init and load from full checkpoint
    fabric = Fabric(
        accelerator="cuda",
        strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
        devices=2,
    )

    # Hack: we already called launch() on another Fabric instance above
    fabric._launched = True

    model = BoringModel()
    optimizer = torch.optim.Adam(model.parameters())
    model, optimizer = fabric.setup(model, optimizer)
    state = {"model": model, "optimizer": optimizer, "steps": 1}
    fabric.load(checkpoint_path_full, state)