import os
from contextlib import nullcontext
from copy import deepcopy
from datetime import timedelta
from functools import partial
from pathlib import Path
from re import escape
from typing import Optional
from unittest import mock
from unittest.mock import ANY, MagicMock, Mock

import pytest
import torch
import torch.nn as nn
from lightning.fabric.plugins.environments import LightningEnvironment
from lightning.fabric.strategies.fsdp import _is_sharded_checkpoint
from lightning.fabric.utilities.imports import (
    _TORCH_GREATER_EQUAL_2_0,
    _TORCH_GREATER_EQUAL_2_1,
)
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.plugins import HalfPrecision
from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision
from lightning.pytorch.strategies import FSDPStrategy
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
from torch.distributed.fsdp.wrap import always_wrap_policy, size_based_auto_wrap_policy, wrap

from tests_pytorch.helpers.runif import RunIf

if _TORCH_GREATER_EQUAL_2_0:
    from torch.distributed.fsdp.wrap import ModuleWrapPolicy
else:
    ModuleWrapPolicy = object
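
# On PyTorch < 2.0, `ModuleWrapPolicy` does not exist; aliasing it to `object` keeps this module
# importable there, while the parametrizations below that actually rely on it are gated with
# `RunIf(min_torch=...)` markers or version checks.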


class TestFSDPModel(BoringModel):
    def __init__(self):
        super().__init__()
        self.layer: Optional[nn.Module] = None

    def _init_model(self) -> None:
        self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))

    def configure_model(self) -> None:
        if self.layer is None:
            self._init_model()
        # the model is already wrapped with FSDP: no need to wrap again!
        if isinstance(self.layer, FullyShardedDataParallel):
            return
        for i, layer in enumerate(self.layer):
            if i % 2 == 0:
                self.layer[i] = wrap(layer)
        self.layer = wrap(self.layer)

    def configure_optimizers(self):
        # There is some issue with SGD optimizer state in FSDP
        return torch.optim.AdamW(self.layer.parameters(), lr=0.1)

    def on_train_batch_start(self, batch, batch_idx):
        assert batch.dtype == torch.float32

    def on_train_batch_end(self, _, batch, batch_idx):
        assert batch.dtype == torch.float32
        self._assert_layer_fsdp_instance()

    def on_test_batch_end(self, _, batch, batch_idx):
        assert batch.dtype == torch.float32
        self._assert_layer_fsdp_instance()

    def on_validation_batch_end(self, _, batch, batch_idx):
        assert batch.dtype == torch.float32
        self._assert_layer_fsdp_instance()

    def on_predict_batch_end(self, _, batch, batch_idx):
        assert batch.dtype == torch.float32
        self._assert_layer_fsdp_instance()

    def _assert_layer_fsdp_instance(self) -> None:
        assert isinstance(self.layer, FullyShardedDataParallel)
        assert isinstance(self.trainer.strategy.precision_plugin, FSDPPrecision)

        if self.trainer.precision == "16-mixed":
            param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
            reduce_dtype = buffer_dtype = torch.float16
        elif self.trainer.precision == "bf16-mixed":
            param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
            reduce_dtype = buffer_dtype = torch.bfloat16
        elif self.trainer.precision == "16-true":
            param_dtype = reduce_dtype = buffer_dtype = torch.float16
        elif self.trainer.precision == "bf16-true":
            param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
        else:
            raise ValueError(f"Unknown precision {self.trainer.precision}")

        assert self.layer.mixed_precision.param_dtype == param_dtype
        assert self.layer.mixed_precision.reduce_dtype == reduce_dtype
        assert self.layer.mixed_precision.buffer_dtype == buffer_dtype

        for layer_num in [0, 2]:
            assert isinstance(self.layer.module[layer_num], FullyShardedDataParallel)
            assert self.layer[layer_num].mixed_precision.param_dtype == param_dtype
            assert self.layer[layer_num].mixed_precision.reduce_dtype == reduce_dtype
            assert self.layer[layer_num].mixed_precision.buffer_dtype == buffer_dtype
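
# The auto-wrapped variants below are driven by `size_based_auto_wrap_policy` with the parametrized
# `wrap_min_params` thresholds (2, 1024, 100000000). The first Linear holds 32 * 32 + 32 = 1056
# parameters and the last one 32 * 2 + 2 = 66, which is what `should_be_wrapped` encodes, so the
# model ends up fully wrapped, partially wrapped, or not wrapped at all depending on the threshold.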


class TestBoringModel(BoringModel):
    def __init__(self, wrap_min_params: int = 2):
        super().__init__()

        self.save_hyperparameters()
        self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))
        self.should_be_wrapped = [(32 * 32 + 32) > wrap_min_params, None, (32 * 2 + 2) > wrap_min_params]

    def configure_optimizers(self):
        parameters = self.parameters() if _TORCH_GREATER_EQUAL_2_0 else self.trainer.model.parameters()

        # SGD's FSDP optimizer state is fixed in https://github.com/pytorch/pytorch/pull/99214
        return torch.optim.AdamW(parameters, lr=0.1)


class TestFSDPModelAutoWrapped(TestBoringModel):
    def on_train_batch_start(self, batch, batch_idx):
        assert batch.dtype == torch.float32

    def on_train_batch_end(self, _, batch, batch_idx):
        assert batch.dtype == torch.float32
        self._assert_layer_fsdp_instance()

    def on_test_batch_end(self, _, batch, batch_idx):
        assert batch.dtype == torch.float32
        self._assert_layer_fsdp_instance()

    def on_validation_batch_end(self, _, batch, batch_idx):
        assert batch.dtype == torch.float32
        self._assert_layer_fsdp_instance()

    def on_predict_batch_end(self, _, batch, batch_idx):
        assert batch.dtype == torch.float32
        self._assert_layer_fsdp_instance()

    def _assert_layer_fsdp_instance(self) -> None:
        assert isinstance(self.layer, torch.nn.Sequential)
        assert isinstance(self.trainer.strategy.precision_plugin, FSDPPrecision)

        if self.trainer.precision == "16-mixed":
            param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
            reduce_dtype = buffer_dtype = torch.float16
        elif self.trainer.precision == "bf16-mixed":
            param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
            reduce_dtype = buffer_dtype = torch.bfloat16
        elif self.trainer.precision == "16-true":
            param_dtype = reduce_dtype = buffer_dtype = torch.float16
        elif self.trainer.precision == "bf16-true":
            param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
        else:
            raise ValueError(f"Unknown precision {self.trainer.precision}")

        for layer_num in [0, 2]:
            if not self.should_be_wrapped[layer_num]:
                # this layer is not wrapped
                assert not isinstance(self.layer[layer_num], FullyShardedDataParallel)
                continue
            assert isinstance(self.layer[layer_num], FullyShardedDataParallel)
            assert self.layer[layer_num].mixed_precision.param_dtype == param_dtype
            assert self.layer[layer_num].mixed_precision.reduce_dtype == reduce_dtype
            assert self.layer[layer_num].mixed_precision.buffer_dtype == buffer_dtype


def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):
    trainer.fit(model)
    model_path = trainer.strategy.broadcast(model_path)
    model_path = model_path if model_path else trainer.checkpoint_callback.last_model_path

    trainer.save_checkpoint(model_path, weights_only=True)

    _assert_save_equality(trainer, model_path, cls=model.__class__)

    with torch.inference_mode():
        # Test entry point
        trainer.test(model)  # model is wrapped, will not call `configure_model`

        # provide the model path: a new unwrapped model is created, the weights are loaded, and then
        # `configure_model` is called to wrap it
        trainer.test(ckpt_path=model_path)

        # Predict entry point
        trainer.predict(model)  # model is wrapped, will not call `configure_model`

        # provide the model path: a new unwrapped model is created, the weights are loaded, and then
        # `configure_model` is called to wrap it
        trainer.predict(ckpt_path=model_path)
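
# Typical usage from the GPU tests below, e.g.:
#     _run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt"))
# which runs fit, saves the checkpoint, and then exercises test/predict both with the in-memory
# wrapped model and from the saved checkpoint path.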


def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel):
    # Use FullySharded to get the state dict for the sake of comparison
    model_state_dict = trainer.strategy.lightning_module_state_dict()

    if trainer.is_global_zero:
        saved_model = cls.load_from_checkpoint(ckpt_path)

        # Assert model parameters are identical after loading
        for ddp_param, shard_param in zip(model_state_dict.values(), saved_model.state_dict().values()):
            assert torch.equal(ddp_param, shard_param)


def test_invalid_on_cpu(tmpdir, cuda_count_0):
    """Test to ensure that we raise Misconfiguration for FSDP on CPU."""
    with pytest.raises(
        MisconfigurationException,
        match=f"You selected strategy to be `{FSDPStrategy.strategy_name}`, but GPU accelerator is not used.",
    ):
        trainer = Trainer(accelerator="cpu", default_root_dir=tmpdir, fast_dev_run=True, strategy="fsdp")
        assert isinstance(trainer.strategy, FSDPStrategy)
        trainer.strategy.setup_environment()


def test_fsdp_custom_mixed_precision():
    """Test to ensure that passing a custom mixed precision config works."""
    config = MixedPrecision()
    strategy = FSDPStrategy(mixed_precision=config)
    assert strategy.mixed_precision_config == config


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
def test_fsdp_strategy_sync_batchnorm(tmpdir):
    """Test to ensure that sync_batchnorm works when using FSDP and GPU, and all stages can be run."""
    model = TestFSDPModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        accelerator="gpu",
        devices=2,
        strategy="fsdp",
        precision="16-mixed",
        max_epochs=1,
        sync_batchnorm=True,
    )
    _run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt"))


@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True)
@pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
def test_fsdp_strategy_checkpoint(tmpdir, precision):
    """Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run."""
    model = TestFSDPModel()
    trainer = Trainer(
        default_root_dir=tmpdir, accelerator="gpu", devices=1, strategy="fsdp", precision=precision, max_epochs=1
    )
    _run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt"))


if _TORCH_GREATER_EQUAL_2_0:

    def custom_auto_wrap_policy(
        module,
        recurse,
        nonwrapped_numel: int,
    ) -> bool:
        return nonwrapped_numel >= 2

else:

    def custom_auto_wrap_policy(
        module,
        recurse,
        unwrapped_params: int,
    ) -> bool:
        return unwrapped_params >= 2
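
# The two definitions mirror the argument rename in PyTorch's size-based policies: the size
# argument passed to a callable `auto_wrap_policy` is `unwrapped_params` before 2.0 and
# `nonwrapped_numel` from 2.0 on. With a threshold of just 2 elements the policy wraps every
# submodule; it is handed to the strategy below via `FSDPStrategy(auto_wrap_policy=custom_auto_wrap_policy)`.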


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize("wrap_min_params", [2, 1024, 100000000])
def test_fsdp_strategy_full_state_dict(tmpdir, wrap_min_params):
    """Test to ensure that the full state dict is extracted when using FSDP strategy.

    Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, or not wrapped at all.

    """
    model = TestFSDPModelAutoWrapped(wrap_min_params=wrap_min_params)
    correct_state_dict = model.state_dict()  # State dict before wrapping

    strategy = FSDPStrategy(auto_wrap_policy=partial(size_based_auto_wrap_policy, min_num_params=wrap_min_params))
    trainer = Trainer(
        default_root_dir=tmpdir,
        accelerator="gpu",
        devices=2,
        strategy=strategy,
        precision="16-mixed",
        max_epochs=1,
        barebones=True,
    )
    trainer.fit(model)

    full_state_dict = trainer.strategy.lightning_module_state_dict()

    if trainer.global_rank != 0:
        assert len(full_state_dict) == 0
        return

    # State dict should contain the same number of keys
    assert len(correct_state_dict) == len(full_state_dict)
    # OrderedDict should return the same keys in the same order
    assert all(_ex == _co for _ex, _co in zip(full_state_dict.keys(), correct_state_dict.keys()))


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize(
    ("model", "strategy", "strategy_cfg"),
    [
        pytest.param(TestFSDPModel(), "fsdp", None, id="manually_wrapped"),
        pytest.param(
            TestFSDPModelAutoWrapped(),
            FSDPStrategy,
            {"auto_wrap_policy": custom_auto_wrap_policy},
            marks=RunIf(max_torch="2.0.0"),
            id="autowrap_1x",
        ),
        pytest.param(
            TestFSDPModelAutoWrapped(),
            FSDPStrategy,
            {"auto_wrap_policy": custom_auto_wrap_policy},
            marks=RunIf(min_torch="2.0.0"),
            id="autowrap_2x",
        ),
        pytest.param(
            TestFSDPModelAutoWrapped(),
            FSDPStrategy,
            {
                "auto_wrap_policy": ModuleWrapPolicy({nn.Linear}) if _TORCH_GREATER_EQUAL_2_1 else None,
                "use_orig_params": True,
            },
            marks=RunIf(min_torch="2.1.0"),
            id="autowrap_use_orig_params",
        ),
    ],
)
def test_fsdp_checkpoint_multi_gpus(tmpdir, model, strategy, strategy_cfg):
    """Test to ensure that checkpoint is saved correctly when using multiple GPUs, and all stages can be run."""
    ck = ModelCheckpoint(save_last=True)

    strategy_cfg = strategy_cfg or {}
    if not isinstance(strategy, str):
        strategy = strategy(**strategy_cfg)

    trainer = Trainer(
        default_root_dir=tmpdir,
        accelerator="gpu",
        devices=2,
        strategy=strategy,
        precision="16-mixed",
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        limit_predict_batches=2,
        callbacks=[ck],
    )
    _run_multiple_stages(trainer, model)


@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True)
@pytest.mark.parametrize("use_orig_params", [None, False, True])
def test_invalid_parameters_in_optimizer(use_orig_params):
    fsdp_kwargs = {}
    if _TORCH_GREATER_EQUAL_2_0 and use_orig_params is not None:
        fsdp_kwargs = {"use_orig_params": use_orig_params}

    trainer = Trainer(
        strategy=FSDPStrategy(**fsdp_kwargs),
        accelerator="cuda",
        devices=1,
        fast_dev_run=1,
    )

    error_context = (
        nullcontext()
        if _TORCH_GREATER_EQUAL_2_0 and (_TORCH_GREATER_EQUAL_2_1 or use_orig_params is not False)
        else pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters")
    )

    class EmptyParametersModel(BoringModel):
        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-2)

    model = EmptyParametersModel()
    with error_context:
        trainer.fit(model)

    class NoFlatParametersModel(BoringModel):
        def configure_optimizers(self):
            layer = torch.nn.Linear(4, 5)
            return torch.optim.Adam(layer.parameters(), lr=1e-2)

    error_context = (
        nullcontext()
        if _TORCH_GREATER_EQUAL_2_0 and use_orig_params is not False
        else pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters")
    )

    model = NoFlatParametersModel()
    with error_context:
        trainer.fit(model)
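
# Note: whether the "optimizer does not seem to reference any FSDP parameters" error is expected
# above depends on the torch version and on `use_orig_params`, which controls whether the optimizer
# can reference the original (non-flattened) parameters; the two `error_context` expressions encode
# the exact combinations.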


@mock.patch("lightning.fabric.strategies.fsdp._TORCH_GREATER_EQUAL_1_13", False)
def test_fsdp_activation_checkpointing_support():
    """Test that we error out if activation checkpointing requires a newer PyTorch version."""
    with pytest.raises(ValueError, match="activation_checkpointing` requires torch >= 1.13.0"):
        FSDPStrategy(activation_checkpointing=Mock())


def test_fsdp_forbidden_precision_raises():
    with pytest.raises(TypeError, match="can only work with the `FSDPPrecision"):
        FSDPStrategy(precision_plugin=HalfPrecision())

    strategy = FSDPStrategy()
    with pytest.raises(TypeError, match="can only work with the `FSDPPrecision"):
        strategy.precision_plugin = HalfPrecision()


@RunIf(min_torch="1.13")
def test_fsdp_activation_checkpointing():
    """Test that the FSDP strategy can apply activation checkpointing to the given layers."""

    class Block1(nn.Linear):
        pass

    class Block2(nn.Linear):
        pass

    class Model(BoringModel):
        def __init__(self):
            super().__init__()
            self.layer0 = nn.Sequential(Block1(4, 4), Block1(5, 5))
            self.layer1 = Block2(2, 2)
            self.layer2 = nn.Linear(3, 3)

    if _TORCH_GREATER_EQUAL_2_1:
        from torch.distributed.fsdp.wrap import ModuleWrapPolicy

        strategy = FSDPStrategy(activation_checkpointing_policy={Block1})
        assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
        assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)

        strategy = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({Block1, Block2}))
        assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
        assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)
    else:
        strategy = FSDPStrategy(activation_checkpointing=Block1)
        assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}

        strategy = FSDPStrategy(activation_checkpointing=[Block1, Block2])
        assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}

        strategy = FSDPStrategy(activation_checkpointing_policy={Block1})
        assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}

        strategy = FSDPStrategy(activation_checkpointing_policy={Block1, Block2})
        assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}

    model = Model()
    strategy._parallel_devices = [torch.device("cuda", 0)]
    strategy._lightning_module = model
    strategy._process_group = Mock()
    with mock.patch("torch.distributed.fsdp.FullyShardedDataParallel", new=MagicMock), mock.patch(
        "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing"
    ) as apply_mock:
        wrapped = strategy._setup_model(model)
    apply_mock.assert_called_with(wrapped, checkpoint_wrapper_fn=ANY, **strategy._activation_checkpointing_kwargs)


def test_fsdp_strategy_cpu_offload():
    """Test the different ways cpu offloading can be enabled."""
    # bool
    strategy = FSDPStrategy(cpu_offload=True)
    assert strategy.cpu_offload == CPUOffload(offload_params=True)

    # dataclass
    config = CPUOffload()
    strategy = FSDPStrategy(cpu_offload=config)
    assert strategy.cpu_offload == config


def test_fsdp_sharding_strategy():
    """Test the different ways the sharding strategy can be set."""
    from torch.distributed.fsdp import ShardingStrategy

    # default
    strategy = FSDPStrategy()
    assert strategy.sharding_strategy == ShardingStrategy.FULL_SHARD

    # enum
    strategy = FSDPStrategy(sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)
    assert strategy.sharding_strategy == ShardingStrategy.SHARD_GRAD_OP

    # string
    strategy = FSDPStrategy(sharding_strategy="NO_SHARD")
    assert strategy.sharding_strategy == ShardingStrategy.NO_SHARD
    strategy = FSDPStrategy(sharding_strategy="no_shard")
    assert strategy.sharding_strategy == ShardingStrategy.NO_SHARD


@RunIf(min_torch="2.0")
@pytest.mark.parametrize("sharding_strategy", ["HYBRID_SHARD", "_HYBRID_SHARD_ZERO2"])
def test_fsdp_hybrid_sharding_strategy(sharding_strategy):
    """Test that the hybrid sharding strategies can only be used with automatic wrapping or a manually specified
    process group."""
    with pytest.raises(RuntimeError, match="The hybrid sharding strategy requires you to either set"):
        FSDPStrategy(sharding_strategy=sharding_strategy)

    strategy = FSDPStrategy(auto_wrap_policy={nn.Linear}, sharding_strategy=sharding_strategy)
    assert strategy.sharding_strategy.name == sharding_strategy

    strategy = FSDPStrategy(sharding_strategy=sharding_strategy, process_group=(Mock(), Mock()))
    assert strategy.sharding_strategy.name == sharding_strategy


def test_fsdp_use_orig_params():
    """Test that Lightning enables `use_orig_params` in PyTorch >= 2.0."""
    with mock.patch("lightning.pytorch.strategies.fsdp._TORCH_GREATER_EQUAL_2_0", False):
        strategy = FSDPStrategy()
        assert "use_orig_params" not in strategy.kwargs

    with mock.patch("lightning.pytorch.strategies.fsdp._TORCH_GREATER_EQUAL_2_0", True):
        strategy = FSDPStrategy()
        assert strategy.kwargs["use_orig_params"]
        strategy = FSDPStrategy(use_orig_params=False)
        assert not strategy.kwargs["use_orig_params"]


@mock.patch("torch.distributed.init_process_group")
def test_set_timeout(init_process_group_mock):
    """Test that the timeout gets passed to the ``torch.distributed.init_process_group`` function."""
    test_timedelta = timedelta(seconds=30)
    strategy = FSDPStrategy(timeout=test_timedelta, parallel_devices=[torch.device("cpu")])
    strategy.cluster_environment = LightningEnvironment()
    strategy.accelerator = Mock()
    strategy.setup_environment()
    process_group_backend = strategy._get_process_group_backend()
    global_rank = strategy.cluster_environment.global_rank()
    world_size = strategy.cluster_environment.world_size()
    init_process_group_mock.assert_called_with(
        process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
    )


@RunIf(min_torch="2.0")
@mock.patch("lightning.pytorch.strategies.fsdp._load_raw_module_state")
def test_fsdp_strategy_load_optimizer_states_multiple(_, tmp_path):
    strategy = FSDPStrategy(parallel_devices=[torch.device("cpu")], state_dict_type="full")
    trainer = Trainer()
    trainer.state.fn = TrainerFn.FITTING
    strategy._lightning_module = Mock(trainer=trainer)
    spec = torch.optim.Optimizer

    # More states than optimizers configured
    strategy.optimizers = [Mock(spec=spec)]
    checkpoint = {"state_dict": {}, "optimizer_states": [{"state": {}}, {"state": {}}]}
    torch.save(checkpoint, tmp_path / "two-states.ckpt")
    with pytest.raises(RuntimeError, match="1 optimizers but the checkpoint contains 2 optimizers to load"):
        strategy.load_checkpoint(tmp_path / "two-states.ckpt")

    # Fewer states than optimizers configured
    strategy.optimizers = [Mock(spec=spec), Mock(spec=spec)]
    checkpoint = {"state_dict": {}, "optimizer_states": [{"state": {}}]}
    torch.save(checkpoint, tmp_path / "one-state.ckpt")
    with pytest.raises(RuntimeError, match="2 optimizers but the checkpoint contains 1 optimizers to load"):
        strategy.load_checkpoint(tmp_path / "one-state.ckpt")
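
# The two tests below check optimizer-state round-tripping in both directions: training with FSDP
# and restoring the checkpoint with DDP verifies that the full optimizer state was saved, while
# training with DDP and restoring with FSDP verifies that it can be loaded back into sharded form.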


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize("wrap_min_params", [2, 1024, 100000000])
def test_fsdp_strategy_save_optimizer_states(tmpdir, wrap_min_params):
    """Test to ensure that the full state dict and optimizer states are saved when using FSDP strategy.

    Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, or not wrapped at all. If the model can
    be restored to DDP, it means that the optimizer states were saved correctly.

    """
    model = TestFSDPModelAutoWrapped(wrap_min_params=wrap_min_params)

    strategy = FSDPStrategy(auto_wrap_policy=partial(size_based_auto_wrap_policy, min_num_params=wrap_min_params))
    trainer = Trainer(
        default_root_dir=tmpdir,
        accelerator="gpu",
        devices=2,
        strategy=strategy,
        precision="16-mixed",
        max_epochs=1,
        barebones=True,
    )

    trainer.fit(model)
    model_path = os.path.join(tmpdir, "last.ckpt")
    model_path = trainer.strategy.broadcast(model_path)
    trainer.save_checkpoint(model_path)

    model_state_dict = trainer.strategy.lightning_module_state_dict()
    optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers())

    if trainer.global_rank != 0:
        assert len(model_state_dict) == 0

    if (trainer.global_rank != 0 and _TORCH_GREATER_EQUAL_2_1) or not _TORCH_GREATER_EQUAL_2_0:
        assert len(optimizer_state_dict) == 0

    if not _TORCH_GREATER_EQUAL_2_0:
        return

    # restore model to ddp
    model = TestBoringModel()
    trainer = Trainer(default_root_dir=tmpdir, accelerator="gpu", devices=2, strategy="ddp", max_epochs=1)

    # This step will restore the model and optimizer states
    trainer.fit(model, ckpt_path=model_path)

    # Get the model and optimizer states from the restored ddp model
    restored_model_state_dict = trainer.strategy.lightning_module_state_dict()
    restored_optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers())

    if trainer.global_rank == 0:
        # assert everything is the same
        assert len(model_state_dict) == len(restored_model_state_dict)
        assert len(optimizer_state_dict) == len(restored_optimizer_state_dict)

        torch.testing.assert_close(model_state_dict, restored_model_state_dict, atol=0, rtol=0)
        torch.testing.assert_close(optimizer_state_dict, restored_optimizer_state_dict, atol=0, rtol=0)

    trainer.strategy.barrier()


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize("wrap_min_params", [2, 1024, 100000000])
def test_fsdp_strategy_load_optimizer_states(tmpdir, wrap_min_params):
    """Test to ensure that the full state dict and optimizer states can be loaded when using FSDP strategy.

    Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, or not wrapped at all. If the DDP model
    can be restored to FSDP, it means that the optimizer states were restored correctly.

    """

    # train the model with ddp first to produce a checkpoint
    model = TestBoringModel()
    trainer = Trainer(default_root_dir=tmpdir, accelerator="gpu", devices=2, strategy="ddp", max_epochs=1)

    trainer.fit(model)
    model_path = os.path.join(tmpdir, "last.ckpt")
    model_path = trainer.strategy.broadcast(model_path)
    trainer.save_checkpoint(model_path)

    # Get the model and optimizer states from the trained ddp model
    model_state_dict = trainer.strategy.lightning_module_state_dict()
    optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers())

    # Build a new FSDP model
    model = TestFSDPModelAutoWrapped(wrap_min_params=wrap_min_params)

    strategy = FSDPStrategy(auto_wrap_policy=partial(size_based_auto_wrap_policy, min_num_params=wrap_min_params))
    trainer = Trainer(
        default_root_dir=tmpdir,
        accelerator="gpu",
        devices=2,
        strategy=strategy,
        precision="16-mixed",
        max_epochs=1,
        barebones=True,
    )

    trainer.fit(model, ckpt_path=model_path)

    restored_model_state_dict = trainer.strategy.lightning_module_state_dict()
    restored_optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers())

    if trainer.global_rank != 0:
        assert len(restored_model_state_dict) == 0

    if (trainer.global_rank != 0 and _TORCH_GREATER_EQUAL_2_1) or not _TORCH_GREATER_EQUAL_2_0:
        assert len(restored_optimizer_state_dict) == 0

    if trainer.global_rank == 0 and _TORCH_GREATER_EQUAL_2_0:
        # assert everything is the same
        assert len(model_state_dict) == len(restored_model_state_dict)
        assert len(optimizer_state_dict) == len(restored_optimizer_state_dict)
        torch.testing.assert_close(model_state_dict, restored_model_state_dict, atol=0, rtol=0)
        torch.testing.assert_close(optimizer_state_dict, restored_optimizer_state_dict, atol=0, rtol=0)

    trainer.strategy.barrier()


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize(
    ("precision", "expected_dtype"),
    [
        ("32-true", torch.float32),
    ],
)
def test_configure_model(precision, expected_dtype):
    """Test that the module under configure_model gets moved to the right device and dtype."""
    trainer = Trainer(
        accelerator="cuda",
        devices=2,
        strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
        precision=precision,
        max_epochs=1,
    )

    class MyModel(BoringModel):
        def configure_model(self):
            self.layer = torch.nn.Linear(32, 2)
            # The model is on the CPU until after `.setup()`
            # TODO: Support initialization on meta device
            expected_device = torch.device("cpu")
            assert self.layer.weight.device == expected_device
            assert self.layer.weight.dtype == expected_dtype

        def configure_optimizers(self):
            # There is some issue with SGD optimizer state in FSDP
            return torch.optim.AdamW(self.layer.parameters(), lr=0.1)

        def on_fit_start(self):
            # Parameters get sharded in `.setup()` and moved to the target device
            assert self.layer.weight.device == torch.device("cuda", self.local_rank)
            assert self.layer.weight.dtype == expected_dtype

    model = MyModel()
    trainer.fit(model)


@mock.patch("lightning.pytorch.strategies.fsdp._TORCH_GREATER_EQUAL_2_0", False)
@mock.patch("lightning.pytorch.strategies.fsdp.torch.load")
@mock.patch("lightning.pytorch.strategies.fsdp._load_raw_module_state")
def test_load_save_optimizer_torch_lt_2_0(_, __, tmp_path):
    strategy = FSDPStrategy(state_dict_type="full")
    with pytest.warns(UserWarning, match="does not support saving the optimizer state"):
        strategy.optimizer_state(Mock())

    file = tmp_path / "test.ckpt"
    file.touch()
    trainer = Trainer()
    trainer.state.fn = TrainerFn.FITTING
    strategy._lightning_module = Mock(trainer=trainer)
    with pytest.warns(UserWarning, match="does not support loading the optimizer state"):
        strategy.load_checkpoint(file)


@mock.patch("lightning.pytorch.strategies.fsdp._TORCH_GREATER_EQUAL_2_0", False)
def test_sharded_state_dict_type_support():
    """Test that the sharded state dict type is not supported on PyTorch < 2.0."""
    with pytest.raises(
        NotImplementedError,
        match=escape("`FSDPStrategy(state_dict_type='sharded')` is not supported in PyTorch < 2.0"),
    ):
        FSDPStrategy(state_dict_type="sharded")


def test_save_checkpoint_storage_options(tmp_path):
    """Test that the FSDP strategy does not accept storage options for saving checkpoints."""
    strategy = FSDPStrategy()
    with pytest.raises(TypeError, match=escape("FSDPStrategy.save_checkpoint(..., storage_options=...)` is not")):
        strategy.save_checkpoint(filepath=tmp_path, checkpoint=Mock(), storage_options=Mock())
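
# The next test covers how `save_checkpoint` treats a pre-existing `filepath` for both
# `state_dict_type` modes:
#   - "full" + existing non-checkpoint directory -> raises IsADirectoryError
#   - "full" + existing sharded checkpoint dir   -> directory is removed and overwritten
#   - "full" + existing file                     -> file is overwritten
#   - "sharded" + existing directory or file     -> saved into/as a directory, no error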


@RunIf(min_torch="2.0.0")
@mock.patch("torch.distributed.checkpoint.save_state_dict", return_value=MagicMock())
@mock.patch("lightning.pytorch.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)
@mock.patch("lightning.pytorch.strategies.fsdp._get_full_state_dict_context", return_value=MagicMock())
@mock.patch("lightning.pytorch.strategies.fsdp._get_sharded_state_dict_context", return_value=MagicMock())
@mock.patch("lightning.fabric.plugins.io.torch_io._atomic_save", return_value=Mock())
@mock.patch("lightning.pytorch.strategies.fsdp.shutil", return_value=MagicMock())
def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___, ____, tmp_path):
    strategy = FSDPStrategy(state_dict_type="full")

    # state_dict_type='full', path exists, path is not a sharded checkpoint: error
    path = tmp_path / "not-empty"
    path.mkdir()
    (path / "file").touch()
    assert not _is_sharded_checkpoint(path)
    with pytest.raises(IsADirectoryError, match="exists and is a directory"):
        strategy.save_checkpoint(Mock(), filepath=path)

    # state_dict_type='full', path exists, path is a sharded checkpoint: no error (overwrite)
    path = tmp_path / "sharded-checkpoint"
    path.mkdir()
    (path / "meta.pt").touch()
    assert _is_sharded_checkpoint(path)
    model = Mock(spec=FullyShardedDataParallel)
    model.modules.return_value = [model]
    strategy.save_checkpoint(Mock(), filepath=path)
    shutil_mock.rmtree.assert_called_once_with(path)

    # state_dict_type='full', path exists, path is a file: no error (overwrite)
    path = tmp_path / "file.pt"
    path.touch()
    model = Mock(spec=FullyShardedDataParallel)
    model.modules.return_value = [model]
    torch_save_mock.reset_mock()
    strategy.save_checkpoint(Mock(), filepath=path)
    torch_save_mock.assert_called_once()

    strategy = FSDPStrategy(state_dict_type="sharded")

    # state_dict_type='sharded', path exists, path is a folder: no error (overwrite)
    path = tmp_path / "not-empty-2"
    path.mkdir()
    (path / "file").touch()
    model = Mock(spec=FullyShardedDataParallel)
    model.modules.return_value = [model]
    strategy.save_checkpoint({"state_dict": {}, "optimizer_states": {"": {}}}, filepath=path)
    assert (path / "file").exists()

    # state_dict_type='sharded', path exists, path is a file: no error (overwrite)
    path = tmp_path / "file-2.pt"
    path.touch()
    model = Mock(spec=FullyShardedDataParallel)
    model.modules.return_value = [model]
    strategy.save_checkpoint({"state_dict": {}, "optimizer_states": {"": {}}}, filepath=path)
    assert path.is_dir()


@mock.patch("lightning.pytorch.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)
def test_fsdp_save_checkpoint_unknown_state_dict_type(tmp_path):
    strategy = FSDPStrategy(state_dict_type="invalid")
    with pytest.raises(ValueError, match="Unknown state_dict_type"):
        strategy.save_checkpoint(checkpoint=Mock(), filepath=tmp_path)


def test_fsdp_load_unknown_checkpoint_type(tmp_path):
    """Test that the strategy validates the contents at the checkpoint path."""
    strategy = FSDPStrategy()
    strategy.model = Mock()
    strategy._lightning_module = Mock()
    path = tmp_path / "empty_dir"  # neither a single file nor a directory with meta file
    path.mkdir()
    with pytest.raises(ValueError, match="does not point to a valid checkpoint"):
        strategy.load_checkpoint(checkpoint_path=path)


class TestFSDPCheckpointModel(BoringModel):
    def __init__(self, params_to_compare=None):
        super().__init__()
        self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))
        self.params_to_compare = params_to_compare

    def configure_optimizers(self):
        # SGD's FSDP optimizer state is fixed in https://github.com/pytorch/pytorch/pull/99214
        return torch.optim.AdamW(self.parameters(), lr=0.1)

    def on_train_start(self):
        if self.params_to_compare is None:
            return
        for p0, p1 in zip(self.params_to_compare, self.trainer.model.parameters()):
            torch.testing.assert_close(p0, p1, atol=0, rtol=0, equal_nan=True)


@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
def test_save_load_sharded_state_dict(tmp_path):
    """Test FSDP saving and loading with the sharded state dict format."""
    strategy = FSDPStrategy(auto_wrap_policy={nn.Linear}, state_dict_type="sharded")
    trainer_kwargs = {
        "default_root_dir": tmp_path,
        "accelerator": "cuda",
        "devices": 2,
        "max_epochs": 1,
        "enable_progress_bar": False,
        "enable_model_summary": False,
        "logger": False,
    }

    # Initial training
    model = TestFSDPCheckpointModel()
    trainer = Trainer(**trainer_kwargs, strategy=strategy)
    trainer.fit(model)
    params_before = deepcopy(list(trainer.model.parameters()))

    checkpoint_path = Path(trainer.strategy.broadcast(trainer.checkpoint_callback.best_model_path))
    assert set(os.listdir(checkpoint_path)) == {"meta.pt", ".metadata", "__0_0.distcp", "__1_0.distcp"}

    metadata = torch.load(checkpoint_path / "meta.pt")
    assert "pytorch-lightning_version" in metadata
    assert len(metadata["callbacks"]) == 1  # model checkpoint callback
    assert "state_dict" not in metadata
    assert "optimizer_states" not in metadata

    # Load checkpoint and continue training
    trainer_kwargs.update(max_epochs=2)
    model = TestFSDPCheckpointModel(params_to_compare=params_before)
    strategy = FSDPStrategy(auto_wrap_policy={nn.Linear}, state_dict_type="sharded")
    trainer = Trainer(**trainer_kwargs, strategy=strategy)
    trainer.fit(model, ckpt_path=checkpoint_path)


@mock.patch("lightning.pytorch.strategies.fsdp.torch.load")
@mock.patch("lightning.pytorch.strategies.fsdp._lazy_load")
@mock.patch("lightning.pytorch.strategies.fsdp._load_raw_module_state")
def test_fsdp_lazy_load_full_state_dict(_, lazy_load_mock, torch_load_mock, tmp_path):
    """Test that loading a single file (full state) is lazy to reduce peak CPU memory usage."""
    model = BoringModel()
    checkpoint = {"state_dict": model.state_dict()}
    lazy_load_mock.return_value = checkpoint

    strategy = FSDPStrategy()
    trainer = Trainer()
    model.trainer = trainer
    strategy._lightning_module = model
    strategy.model = model

    file = tmp_path / "test.ckpt"
    file.touch()

    strategy.load_checkpoint(checkpoint_path=file)
    if _TORCH_GREATER_EQUAL_2_0:
        lazy_load_mock.assert_called_once()
    else:
        torch_load_mock.assert_called_once()


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize(
    ("precision", "expected_dtype"),
    [
        ("32-true", torch.float32),
        ("16-true", torch.float16),
        pytest.param("bf16-true", torch.bfloat16, marks=RunIf(bf16_cuda=True)),
    ],
)
def test_module_init_context(precision, expected_dtype):
    """Test that the module under the init-context gets moved to the right device and dtype."""

    class Model(BoringModel):
        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-2)

        def on_train_start(self):
            # Parameters get sharded in `FSDPStrategy.setup()` and moved to the target device
            assert self.layer.weight.device == torch.device("cuda", self.local_rank)
            assert self.layer.weight.dtype == expected_dtype
            optimizer = self.optimizers(use_pl_optimizer=False)
            assert optimizer.param_groups[0]["params"][0].device.type == "cuda"

    def _run_setup_assertions(empty_init, expected_device):
        trainer = Trainer(
            accelerator="cuda",
            devices=2,
            strategy=FSDPStrategy(auto_wrap_policy={torch.nn.Linear}),
            precision=precision,
            max_steps=1,
            barebones=True,
        )
        with trainer.init_module(empty_init=empty_init):
            model = Model()

        # The model is on the CPU/meta-device until after `FSDPStrategy.setup()`
        assert model.layer.weight.device == expected_device
        assert model.layer.weight.dtype == expected_dtype
        trainer.fit(model)

    # Case 1: No empty init
    _run_setup_assertions(empty_init=False, expected_device=torch.device("cpu"))

    if _TORCH_GREATER_EQUAL_2_1:
        # Case 2: Empty-init with PyTorch >= 2.1 supports meta device
        _run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))
    else:
        # Case 2: Empty-init with PyTorch < 2.1 only supports `torch.empty()`-init
        _run_setup_assertions(empty_init=True, expected_device=torch.device("cpu"))