2022-11-21 13:58:37 +00:00
|
|
|
# Copyright The Lightning AI team.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
2023-06-06 19:41:26 +00:00
|
|
|
import contextlib
|
|
|
|
import datetime
|
|
|
|
import os
|
2023-04-28 00:27:06 +00:00
|
|
|
from datetime import timedelta
|
2023-04-16 18:11:49 +00:00
|
|
|
from re import escape
|
2022-11-21 13:58:37 +00:00
|
|
|
from unittest import mock
|
2022-12-06 15:45:33 +00:00
|
|
|
from unittest.mock import ANY, MagicMock, Mock
|
2022-11-21 13:58:37 +00:00
|
|
|
|
|
|
|
import pytest
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
2023-06-06 19:41:26 +00:00
|
|
|
from lightning_utilities.core.imports import RequirementCache
|
2022-11-21 13:58:37 +00:00
|
|
|
from torch.optim import Adam
|
|
|
|
|
2023-07-14 18:00:52 +00:00
|
|
|
import lightning.fabric
|
2023-06-06 19:41:26 +00:00
|
|
|
from lightning.fabric import Fabric
|
2023-04-28 00:27:06 +00:00
|
|
|
from lightning.fabric.plugins.environments import LightningEnvironment
|
2023-02-01 20:34:38 +00:00
|
|
|
from lightning.fabric.strategies import FSDPStrategy
|
2023-07-14 18:00:52 +00:00
|
|
|
from lightning.fabric.strategies.fsdp import _FSDPBackwardSyncControl, fsdp_overlap_step_with_backward
|
|
|
|
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_2_1
|
2023-03-03 16:55:48 +00:00
|
|
|
from tests_fabric.helpers.runif import RunIf
|
2023-04-16 18:11:49 +00:00
|
|
|
from tests_fabric.strategies.test_single_device import _MyFabricGradNorm
|
2022-11-21 13:58:37 +00:00
|
|
|
|
|
|
|
if _TORCH_GREATER_EQUAL_1_12:
|
2022-12-07 02:55:47 +00:00
|
|
|
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
|
2022-11-21 13:58:37 +00:00
|
|
|
|
|
|
|
|
2023-02-01 20:34:38 +00:00
|
|
|
@mock.patch("lightning.fabric.strategies.fsdp._TORCH_GREATER_EQUAL_1_12", False)
|
2022-11-21 13:58:37 +00:00
|
|
|
def test_fsdp_support(*_):
|
|
|
|
with pytest.raises(NotImplementedError, match="`FSDPStrategy` is supported from PyTorch v1.12.0"):
|
|
|
|
FSDPStrategy()
|
|
|
|
|
|
|
|
|
|
|
|
@RunIf(min_torch="1.12")
|
2022-12-07 02:55:47 +00:00
|
|
|
def test_fsdp_custom_mixed_precision():
|
2022-11-21 13:58:37 +00:00
|
|
|
"""Test that passing a custom mixed precision config works."""
|
|
|
|
config = MixedPrecision()
|
|
|
|
strategy = FSDPStrategy(mixed_precision=config)
|
|
|
|
assert strategy.mixed_precision_config == config
|
|
|
|
|
|
|
|
|
2022-12-07 02:55:47 +00:00
|
|
|
@RunIf(min_torch="1.12")
|
|
|
|
def test_fsdp_cpu_offload():
|
|
|
|
"""Test the different ways cpu offloading can be enabled."""
|
|
|
|
# bool
|
|
|
|
strategy = FSDPStrategy(cpu_offload=True)
|
|
|
|
assert strategy.cpu_offload == CPUOffload(offload_params=True)
|
|
|
|
|
|
|
|
# dataclass
|
|
|
|
config = CPUOffload()
|
|
|
|
strategy = FSDPStrategy(cpu_offload=config)
|
|
|
|
assert strategy.cpu_offload == config
|
|
|
|
|
|
|
|
|
2023-07-15 16:07:09 +00:00
|
|
|
@RunIf(min_torch="1.12")
|
|
|
|
def test_fsdp_sharding_strategy():
|
|
|
|
"""Test the different ways the sharding strategy can be set."""
|
|
|
|
from torch.distributed.fsdp import ShardingStrategy
|
|
|
|
|
|
|
|
# default
|
|
|
|
strategy = FSDPStrategy()
|
|
|
|
assert strategy.sharding_strategy == ShardingStrategy.FULL_SHARD
|
|
|
|
|
|
|
|
# enum
|
|
|
|
strategy = FSDPStrategy(sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)
|
|
|
|
assert strategy.sharding_strategy == ShardingStrategy.SHARD_GRAD_OP
|
|
|
|
|
|
|
|
# string
|
|
|
|
strategy = FSDPStrategy(sharding_strategy="NO_SHARD")
|
|
|
|
assert strategy.sharding_strategy == ShardingStrategy.NO_SHARD
|
|
|
|
strategy = FSDPStrategy(sharding_strategy="no_shard")
|
|
|
|
assert strategy.sharding_strategy == ShardingStrategy.NO_SHARD
|
|
|
|
|
|
|
|
|
2022-11-21 13:58:37 +00:00
|
|
|
@RunIf(min_torch="1.12")
|
2023-04-11 19:58:53 +00:00
|
|
|
@pytest.mark.parametrize("torch_ge_2_0", [False, True])
|
|
|
|
def test_fsdp_setup_optimizer_validation(torch_ge_2_0):
|
2022-11-21 13:58:37 +00:00
|
|
|
"""Test that `setup_optimizer()` validates the param groups and reference to FSDP parameters."""
|
|
|
|
module = nn.Linear(2, 2)
|
|
|
|
strategy = FSDPStrategy(parallel_devices=[torch.device("cpu")])
|
|
|
|
|
2023-04-11 19:58:53 +00:00
|
|
|
with mock.patch("lightning.fabric.strategies.fsdp._TORCH_GREATER_EQUAL_2_0", torch_ge_2_0):
|
|
|
|
bad_optimizer_1 = Adam([{"params": [module.weight]}, {"params": [module.bias], "lr": 1e-3}])
|
|
|
|
bad_optimizer_2 = Adam(module.parameters())
|
2022-11-21 13:58:37 +00:00
|
|
|
|
2023-04-11 19:58:53 +00:00
|
|
|
if torch_ge_2_0:
|
|
|
|
strategy.setup_optimizer(bad_optimizer_1)
|
|
|
|
strategy.setup_optimizer(bad_optimizer_2)
|
|
|
|
else:
|
|
|
|
with pytest.raises(ValueError, match="does not support multiple param groups"):
|
|
|
|
strategy.setup_optimizer(bad_optimizer_1)
|
|
|
|
with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameter"):
|
|
|
|
strategy.setup_optimizer(bad_optimizer_2)
|
|
|
|
|
|
|
|
|
|
|
|
@RunIf(min_torch="2.0.0")
|
|
|
|
@mock.patch("lightning.fabric.strategies.fsdp.FSDPStrategy.setup_module")
|
|
|
|
def test_fsdp_setup_use_orig_params(_):
|
|
|
|
module = nn.Linear(2, 2)
|
|
|
|
optimizer = Adam(module.parameters())
|
|
|
|
|
|
|
|
strategy = FSDPStrategy(parallel_devices=[torch.device("cpu")], use_orig_params=False)
|
|
|
|
assert not strategy._fsdp_kwargs["use_orig_params"]
|
|
|
|
|
|
|
|
with pytest.raises(ValueError, match=r"`FSDPStrategy\(use_orig_params=False\)` but this is not supported"):
|
|
|
|
strategy.setup_module_and_optimizers(module, optimizer)
|
|
|
|
|
|
|
|
strategy = FSDPStrategy(parallel_devices=[torch.device("cpu")])
|
|
|
|
assert strategy._fsdp_kwargs["use_orig_params"]
|
|
|
|
strategy.setup_module_and_optimizers(module, optimizer)
|
|
|
|
assert strategy._fsdp_kwargs["use_orig_params"]
|
2022-11-21 13:58:37 +00:00
|
|
|
|
|
|
|
|
|
|
|
@RunIf(min_torch="1.12")
|
|
|
|
def test_fsdp_no_backward_sync():
|
|
|
|
"""Test that the backward sync control calls `.no_sync()`, and only on a module wrapped in
|
|
|
|
FullyShardedDataParallel."""
|
|
|
|
|
|
|
|
strategy = FSDPStrategy()
|
|
|
|
assert isinstance(strategy._backward_sync_control, _FSDPBackwardSyncControl)
|
|
|
|
|
|
|
|
with pytest.raises(
|
|
|
|
TypeError, match="is only possible if the module passed to .* is wrapped in `FullyShardedDataParallel`"
|
2023-04-24 21:57:08 +00:00
|
|
|
), strategy._backward_sync_control.no_backward_sync(Mock()):
|
|
|
|
pass
|
2022-11-21 13:58:37 +00:00
|
|
|
|
|
|
|
module = MagicMock(spec=FullyShardedDataParallel)
|
|
|
|
with strategy._backward_sync_control.no_backward_sync(module):
|
|
|
|
pass
|
|
|
|
|
|
|
|
module.no_sync.assert_called_once()
|
2022-12-06 15:45:33 +00:00
|
|
|
|
|
|
|
|
|
|
|
@RunIf(min_torch="1.12")
|
2023-07-14 18:00:52 +00:00
|
|
|
def test_fsdp_activation_checkpointing_support(monkeypatch):
|
2022-12-06 15:45:33 +00:00
|
|
|
"""Test that we error out if activation checkpointing requires a newer PyTorch version."""
|
2023-07-14 18:00:52 +00:00
|
|
|
monkeypatch.setattr(lightning.fabric.strategies.fsdp, "_TORCH_GREATER_EQUAL_1_13", False)
|
|
|
|
with pytest.raises(ValueError, match="activation_checkpointing` requires torch >= 1.13.0"):
|
2022-12-06 15:45:33 +00:00
|
|
|
FSDPStrategy(activation_checkpointing=Mock())
|
|
|
|
|
2023-07-14 18:00:52 +00:00
|
|
|
monkeypatch.setattr(lightning.fabric.strategies.fsdp, "_TORCH_GREATER_EQUAL_2_1", False)
|
|
|
|
with pytest.raises(ValueError, match="activation_checkpointing_policy` requires torch >= 2.1.0"):
|
|
|
|
FSDPStrategy(activation_checkpointing_policy=Mock())
|
|
|
|
|
2022-12-06 15:45:33 +00:00
|
|
|
|
|
|
|
@RunIf(min_torch="1.13")
|
|
|
|
def test_fsdp_activation_checkpointing():
|
|
|
|
"""Test that the FSDP strategy can apply activation checkpointing to the given layers."""
|
|
|
|
|
|
|
|
class Block1(nn.Linear):
|
|
|
|
pass
|
|
|
|
|
|
|
|
class Block2(nn.Linear):
|
|
|
|
pass
|
|
|
|
|
|
|
|
class Model(nn.Module):
|
|
|
|
def __init__(self):
|
|
|
|
super().__init__()
|
|
|
|
self.layer0 = nn.Sequential(Block1(4, 4), Block1(5, 5))
|
|
|
|
self.layer1 = Block2(2, 2)
|
|
|
|
self.layer2 = nn.Linear(3, 3)
|
|
|
|
|
2023-07-14 18:00:52 +00:00
|
|
|
if _TORCH_GREATER_EQUAL_2_1:
|
|
|
|
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
|
|
|
|
|
2023-07-15 15:39:28 +00:00
|
|
|
strategy = FSDPStrategy(activation_checkpointing_policy={Block1})
|
2023-07-14 18:00:52 +00:00
|
|
|
assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
|
2023-07-15 15:39:28 +00:00
|
|
|
assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)
|
2022-12-06 15:45:33 +00:00
|
|
|
|
2023-07-14 18:00:52 +00:00
|
|
|
strategy = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({Block1, Block2}))
|
|
|
|
assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
|
2023-07-15 15:39:28 +00:00
|
|
|
assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)
|
2023-07-14 18:00:52 +00:00
|
|
|
else:
|
|
|
|
strategy = FSDPStrategy(activation_checkpointing=Block1)
|
|
|
|
assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}
|
|
|
|
|
|
|
|
strategy = FSDPStrategy(activation_checkpointing=[Block1, Block2])
|
|
|
|
assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}
|
2022-12-06 15:45:33 +00:00
|
|
|
|
2023-07-15 15:39:28 +00:00
|
|
|
strategy = FSDPStrategy(activation_checkpointing_policy={Block1})
|
|
|
|
assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}
|
|
|
|
|
|
|
|
strategy = FSDPStrategy(activation_checkpointing_policy={Block1, Block2})
|
|
|
|
assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}
|
|
|
|
|
2022-12-06 15:45:33 +00:00
|
|
|
strategy._parallel_devices = [torch.device("cuda", 0)]
|
|
|
|
with mock.patch(
|
|
|
|
"torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel"
|
|
|
|
) as fsdp_mock, mock.patch(
|
|
|
|
"torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing"
|
|
|
|
) as ckpt_mock:
|
|
|
|
strategy.setup_module(Model())
|
2023-07-14 18:00:52 +00:00
|
|
|
ckpt_mock.assert_called_with(
|
|
|
|
fsdp_mock(), checkpoint_wrapper_fn=ANY, **strategy._activation_checkpointing_kwargs
|
|
|
|
)
|
2023-02-27 23:44:13 +00:00
|
|
|
|
|
|
|
|
|
|
|
@RunIf(min_torch="1.13")
|
|
|
|
def test_fsdp_grad_clipping_value_error():
|
|
|
|
strategy = FSDPStrategy()
|
|
|
|
with pytest.raises(
|
|
|
|
NotImplementedError,
|
|
|
|
match=(
|
|
|
|
"FSDP currently does not support to clip gradients by value. "
|
|
|
|
"Consider clipping by norm instead or choose another strategy!"
|
|
|
|
),
|
|
|
|
):
|
|
|
|
strategy.clip_gradients_value(Mock(), Mock(), Mock())
|
|
|
|
|
|
|
|
|
|
|
|
class _MyFSDPFabricGradientNorm(_MyFabricGradNorm):
|
|
|
|
def after_backward(self, model, optimizer):
|
|
|
|
self.clip_gradients(model, optimizer, max_norm=0.05, error_if_nonfinite=True)
|
|
|
|
|
|
|
|
with model._forward_module.summon_full_params(model._forward_module):
|
|
|
|
parameters = model.parameters()
|
|
|
|
grad_norm = torch.linalg.vector_norm(
|
|
|
|
torch.stack([torch.linalg.vector_norm(p.grad.detach(), 2, dtype=torch.float32) for p in parameters]),
|
|
|
|
2,
|
|
|
|
)
|
|
|
|
torch.testing.assert_close(grad_norm, torch.tensor(0.05, device=self.device))
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
|
|
"precision",
|
|
|
|
["32-true", "16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))],
|
|
|
|
)
|
|
|
|
@RunIf(min_cuda_gpus=2, standalone=True)
|
|
|
|
@pytest.mark.xfail(reason="Testing with FSDP is not yet correct") # TODO: Investigate testing with fsdp
|
|
|
|
def test_fsdp_grad_clipping_norm(precision):
|
|
|
|
fabric = _MyFSDPFabricGradientNorm(accelerator="cuda", devices=2, precision=precision, strategy="fsdp")
|
|
|
|
fabric.run()
|
2023-04-16 18:11:49 +00:00
|
|
|
|
|
|
|
|
|
|
|
@RunIf(min_torch="2.0.0")
|
|
|
|
def test_fsdp_save_checkpoint_storage_options(tmp_path):
|
|
|
|
"""Test that the FSDP strategy does not accept storage options for saving checkpoints."""
|
|
|
|
strategy = FSDPStrategy()
|
|
|
|
with pytest.raises(TypeError, match=escape("FSDPStrategy.save_checkpoint(..., storage_options=...)` is not")):
|
|
|
|
strategy.save_checkpoint(path=tmp_path, state=Mock(), storage_options=Mock())
|
|
|
|
|
|
|
|
|
|
|
|
@RunIf(min_torch="2.0.0")
|
|
|
|
@mock.patch("lightning.fabric.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)
|
|
|
|
def test_fsdp_save_checkpoint_folder_exists(tmp_path):
|
|
|
|
path = tmp_path / "exists"
|
|
|
|
path.mkdir()
|
|
|
|
(path / "file").touch()
|
|
|
|
strategy = FSDPStrategy()
|
|
|
|
with pytest.raises(FileExistsError, match="exists and is not empty"):
|
|
|
|
strategy.save_checkpoint(path=path, state=Mock())
|
|
|
|
|
|
|
|
|
|
|
|
@RunIf(min_torch="2.0.0")
|
|
|
|
@mock.patch("lightning.fabric.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)
|
|
|
|
def test_fsdp_save_checkpoint_one_fsdp_module_required(tmp_path):
|
|
|
|
"""Test that the FSDP strategy can only save one FSDP model per checkpoint."""
|
|
|
|
strategy = FSDPStrategy()
|
|
|
|
|
|
|
|
# missing FSDP model
|
|
|
|
with pytest.raises(ValueError, match="Could not find a FSDP model in the provided checkpoint state."):
|
|
|
|
strategy.save_checkpoint(path=tmp_path, state={})
|
|
|
|
with pytest.raises(ValueError, match="Could not find a FSDP model in the provided checkpoint state."):
|
|
|
|
strategy.save_checkpoint(path=tmp_path, state={"model": torch.nn.Linear(3, 3)})
|
|
|
|
|
|
|
|
# multiple FSDP models
|
|
|
|
model1 = Mock(spec=FullyShardedDataParallel)
|
|
|
|
model2 = Mock(spec=FullyShardedDataParallel)
|
|
|
|
with pytest.raises(ValueError, match="Found multiple FSDP modules in the given state."):
|
|
|
|
strategy.save_checkpoint(path=tmp_path, state={"model1": model1, "model2": model2})
|
|
|
|
|
|
|
|
|
|
|
|
@RunIf(min_torch="2.0.0")
|
|
|
|
def test_fsdp_load_checkpoint_no_state(tmp_path):
|
|
|
|
"""Test that the FSDP strategy can't load the full state without access to a model instance from the user."""
|
|
|
|
strategy = FSDPStrategy()
|
|
|
|
with pytest.raises(ValueError, match=escape("Got FSDPStrategy.load_checkpoint(..., state=None")):
|
|
|
|
strategy.load_checkpoint(path=tmp_path, state=None)
|
|
|
|
with pytest.raises(ValueError, match=escape("Got FSDPStrategy.load_checkpoint(..., state={})")):
|
|
|
|
strategy.load_checkpoint(path=tmp_path, state={})
|
|
|
|
|
|
|
|
|
|
|
|
@RunIf(min_torch="2.0.0")
|
|
|
|
@mock.patch("lightning.fabric.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)
|
|
|
|
def test_fsdp_load_checkpoint_one_fsdp_module_required(tmp_path):
|
|
|
|
"""Test that the FSDP strategy can only load one FSDP model per checkpoint."""
|
|
|
|
strategy = FSDPStrategy()
|
|
|
|
|
|
|
|
# missing FSDP model
|
|
|
|
with pytest.raises(ValueError, match="Could not find a FSDP model in the provided checkpoint state."):
|
|
|
|
strategy.load_checkpoint(path=tmp_path, state={"other": "data"})
|
|
|
|
with pytest.raises(ValueError, match="Could not find a FSDP model in the provided checkpoint state."):
|
|
|
|
strategy.load_checkpoint(path=tmp_path, state={"model": torch.nn.Linear(3, 3)})
|
|
|
|
|
|
|
|
# multiple FSDP models
|
|
|
|
model1 = Mock(spec=FullyShardedDataParallel)
|
|
|
|
model2 = Mock(spec=FullyShardedDataParallel)
|
|
|
|
with pytest.raises(ValueError, match="Found multiple FSDP modules in the given state."):
|
|
|
|
strategy.load_checkpoint(path=tmp_path, state={"model1": model1, "model2": model2})
|
2023-04-28 00:27:06 +00:00
|
|
|
|
|
|
|
|
2023-05-11 17:02:30 +00:00
|
|
|
@RunIf(min_torch="2.0.0")
|
|
|
|
@mock.patch("lightning.fabric.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)
|
|
|
|
def test_fsdp_save_checkpoint_unknown_state_dict_type(tmp_path):
|
|
|
|
strategy = FSDPStrategy(state_dict_type="invalid")
|
|
|
|
model = Mock(spec=FullyShardedDataParallel)
|
|
|
|
with pytest.raises(ValueError, match="Unknown state_dict_type"):
|
|
|
|
strategy.save_checkpoint(path=tmp_path, state={"model": model})
|
|
|
|
|
|
|
|
|
2023-05-31 15:30:07 +00:00
|
|
|
@RunIf(min_torch="2.0.0")
|
2023-07-10 13:24:09 +00:00
|
|
|
def test_fsdp_load_unknown_checkpoint_type(tmp_path):
|
2023-05-31 15:30:07 +00:00
|
|
|
"""Test that the strategy validates the contents at the checkpoint path."""
|
|
|
|
strategy = FSDPStrategy()
|
|
|
|
model = Mock(spec=FullyShardedDataParallel)
|
|
|
|
path = tmp_path / "empty_dir" # neither a single file nor a directory with meta file
|
|
|
|
path.mkdir()
|
|
|
|
with pytest.raises(ValueError, match="does not point to a valid checkpoint"):
|
|
|
|
strategy.load_checkpoint(path=path, state={"model": model})
|
|
|
|
|
|
|
|
|
2023-04-28 00:27:06 +00:00
|
|
|
@RunIf(min_torch="1.12")
|
|
|
|
@mock.patch("torch.distributed.init_process_group")
|
|
|
|
def test_set_timeout(init_process_group_mock):
|
|
|
|
"""Test that the timeout gets passed to the ``torch.distributed.init_process_group`` function."""
|
|
|
|
test_timedelta = timedelta(seconds=30)
|
|
|
|
strategy = FSDPStrategy(timeout=test_timedelta, parallel_devices=[torch.device("cpu")])
|
|
|
|
strategy.cluster_environment = LightningEnvironment()
|
|
|
|
strategy.accelerator = Mock()
|
|
|
|
strategy.setup_environment()
|
|
|
|
process_group_backend = strategy._get_process_group_backend()
|
|
|
|
global_rank = strategy.cluster_environment.global_rank()
|
|
|
|
world_size = strategy.cluster_environment.world_size()
|
|
|
|
init_process_group_mock.assert_called_with(
|
|
|
|
process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
|
|
|
|
)
|
2023-06-06 19:41:26 +00:00
|
|
|
|
|
|
|
|
|
|
|
class SubBlock(nn.Sequential):
|
|
|
|
def __init__(self, feature_dim: int) -> None:
|
|
|
|
super().__init__(
|
|
|
|
nn.Linear(feature_dim, feature_dim, bias=False),
|
|
|
|
torch.nn.LayerNorm([feature_dim]),
|
|
|
|
nn.ReLU(),
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
class Block(nn.Module):
|
|
|
|
def __init__(self, feature_dim: int) -> None:
|
|
|
|
super().__init__()
|
|
|
|
self.left = SubBlock(feature_dim)
|
|
|
|
self.right = SubBlock(feature_dim)
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
return self.left(x) + self.right(x)
|
|
|
|
|
|
|
|
|
|
|
|
class StatusChecker:
|
|
|
|
def __init__(self, fabric: Fabric) -> None:
|
|
|
|
self._fabric = fabric
|
|
|
|
self.is_rank_zero = fabric.is_global_zero
|
|
|
|
self.pids = tuple(int(pid) for pid in fabric.all_gather(os.getpid()).cpu().numpy())
|
|
|
|
|
|
|
|
@contextlib.contextmanager
|
|
|
|
def guard_region(self, name: str):
|
|
|
|
"""Handle errors and graceful shutdown.
|
|
|
|
|
|
|
|
`pytest` interprets SystemExit as a faiure, so it will interpret shutdown of non-zero ranks as a test failure.
|
|
|
|
This is confusing (since it logs "FAILED"), but more importantly the orphan rank will continue trying to execute
|
|
|
|
the rest of the test suite. So instead we add calls to `os._exit` which actually forces the process to shut
|
|
|
|
down.
|
|
|
|
"""
|
|
|
|
success = False
|
|
|
|
try:
|
|
|
|
yield
|
|
|
|
success = True
|
|
|
|
|
|
|
|
except BaseException:
|
|
|
|
if self.is_rank_zero:
|
|
|
|
raise
|
|
|
|
|
|
|
|
finally:
|
|
|
|
# All reduce will wait for all workers to enter. This means that if a
|
|
|
|
# worker dies the status check will deadlock.
|
|
|
|
import psutil
|
|
|
|
|
|
|
|
worker_status = tuple(psutil.Process(pid).status() for pid in self.pids)
|
|
|
|
if any(
|
|
|
|
status in (psutil.STATUS_DEAD, psutil.STATUS_STOPPED, psutil.STATUS_ZOMBIE) for status in worker_status
|
|
|
|
):
|
|
|
|
if self.is_rank_zero:
|
|
|
|
raise RuntimeError(f"({name}) Dead workers: [{', '.join(worker_status)}]")
|
|
|
|
else:
|
|
|
|
os._exit(1)
|
|
|
|
|
|
|
|
rank_success = self._fabric.all_gather(success).cpu()
|
|
|
|
if not rank_success.all():
|
|
|
|
if self.is_rank_zero > 0:
|
|
|
|
os._exit(1)
|
|
|
|
elif success:
|
|
|
|
raise RuntimeError(f"({name}) Failure on different rank: {rank_success}")
|
|
|
|
|
|
|
|
def finalize(self) -> None:
|
|
|
|
if not self.is_rank_zero:
|
|
|
|
os._exit(0)
|
|
|
|
|
|
|
|
def __del__(self) -> None:
|
|
|
|
self.finalize()
|
|
|
|
|
|
|
|
|
2023-06-07 18:33:53 +00:00
|
|
|
@pytest.mark.skip(reason="Flaky test") # See also: https://github.com/Lightning-AI/lightning/pull/17774
|
2023-06-06 19:41:26 +00:00
|
|
|
@RunIf(min_torch="2.0.0", min_cuda_gpus=2, skip_windows=True, standalone=True)
|
|
|
|
@pytest.mark.skipif(not RequirementCache("psutil"), reason="psutil is needed to help prevent deadlocks.")
|
|
|
|
@pytest.mark.parametrize(
|
|
|
|
"checkpoint",
|
|
|
|
[(Block,), (SubBlock,), (Block, SubBlock, nn.Linear), None],
|
|
|
|
)
|
|
|
|
def test_apply_optimizer_in_backward(checkpoint):
|
|
|
|
from torch.distributed.fsdp._traversal_utils import _get_fsdp_handles
|
|
|
|
|
|
|
|
num_gpus = 2
|
|
|
|
num_blocks = 8
|
|
|
|
feature_dim = 256
|
|
|
|
|
|
|
|
# This bound is dependent on the topology of the model. The grads for each
|
|
|
|
# Block are two `feature_dim ** 2` Tensors (`left` and `right` Linear layers)
|
|
|
|
# times four. (FP32 = 4 bytes / element)
|
|
|
|
#
|
|
|
|
# In the baseline case grads for all Blocks are materialized at once, whereas
|
|
|
|
# in the test case only one Block should have grads in memory which adds a
|
|
|
|
# `(num_blocks - 1)` factor.
|
|
|
|
#
|
|
|
|
# However, there is one final correction to be made. In the baseline case peak
|
|
|
|
# memory occurs at the end of the backward pass; at that time activations will
|
|
|
|
# have been freed and will offset the memory relative to the base case. (Which
|
|
|
|
# reaches peak memory after the first block when most activations are still
|
|
|
|
# in memory.) It's difficult to estimate the exact correction factor
|
|
|
|
# (particularly since it varies with activation checkpointing strategy), but
|
|
|
|
# three is close enough for our purposes.
|
|
|
|
upper_savings_bound = 4 * feature_dim**2 * 2 * (num_blocks - 1)
|
|
|
|
lower_savings_bound = upper_savings_bound / 3
|
|
|
|
|
|
|
|
strategy = FSDPStrategy(
|
2023-07-15 15:39:28 +00:00
|
|
|
auto_wrap_policy={Block},
|
2023-06-06 19:41:26 +00:00
|
|
|
activation_checkpointing=checkpoint,
|
|
|
|
timeout=datetime.timedelta(seconds=10),
|
|
|
|
)
|
|
|
|
fabric = Fabric(accelerator="cuda", devices=num_gpus, strategy=strategy)
|
|
|
|
fabric.launch()
|
|
|
|
status_checker = StatusChecker(fabric)
|
|
|
|
|
|
|
|
def make_model_and_optimizers():
|
|
|
|
torch.manual_seed(0)
|
|
|
|
|
|
|
|
with fabric.init_module():
|
|
|
|
backbone = [Block(feature_dim) for _ in range(num_blocks)]
|
|
|
|
model = nn.Sequential(*backbone, nn.Linear(feature_dim, 1, bias=False))
|
|
|
|
optimizers = [torch.optim.SGD(layer.parameters(), lr=0.1, momentum=0.9) for layer in model]
|
|
|
|
|
|
|
|
return fabric.setup_module(model), fabric.setup_optimizers(*optimizers)
|
|
|
|
|
|
|
|
with status_checker.guard_region("Instantiate model."):
|
|
|
|
baseline_model, baseline_optimizers = make_model_and_optimizers()
|
|
|
|
test_model, test_optimizers = make_model_and_optimizers()
|
|
|
|
fabric.seed_everything(1337 + fabric.global_rank)
|
|
|
|
|
|
|
|
# Check that initialization is identical.
|
|
|
|
for p0, p1 in zip(baseline_model.parameters(), test_model.parameters()):
|
|
|
|
assert (p0 == p1).all()
|
|
|
|
|
|
|
|
num_steps = 50
|
|
|
|
for step in range(num_steps):
|
|
|
|
# Normal pattern: `.backward()` followed by `.step()`
|
|
|
|
with status_checker.guard_region(f"({step + 1} / {num_steps}) Baseline"):
|
|
|
|
x = torch.randn((4, feature_dim), device=fabric.device)
|
|
|
|
y_baseline = baseline_model(x)
|
|
|
|
|
|
|
|
torch.cuda.empty_cache()
|
|
|
|
torch.cuda.reset_peak_memory_stats(fabric.device)
|
|
|
|
baseline_start_memory = torch.cuda.max_memory_allocated(fabric.device)
|
|
|
|
fabric.backward(y_baseline.mean().abs())
|
|
|
|
del y_baseline
|
|
|
|
for optimizer in baseline_optimizers:
|
|
|
|
optimizer.step()
|
|
|
|
optimizer.zero_grad(set_to_none=True)
|
|
|
|
|
|
|
|
# FSDP sometimes holds onto grad memory until the next forward
|
|
|
|
# pass. In order to provide a fair comparison (and thus an
|
|
|
|
# accurate check that moving the step call into backward actually
|
|
|
|
# delivers the expected memory savings) we need to "help" the
|
|
|
|
# baseline case a bit.
|
|
|
|
param_handles = _get_fsdp_handles(baseline_model._forward_module)
|
|
|
|
for h in param_handles:
|
|
|
|
h._clear_grads_if_needed()
|
|
|
|
|
|
|
|
baseline_peak_memory = torch.cuda.max_memory_allocated(fabric.device)
|
|
|
|
|
|
|
|
# `.step()` interleaved with `.backward()`
|
|
|
|
with status_checker.guard_region(f"({step + 1} / {num_steps}) Optimizer in backward"):
|
|
|
|
y_test = test_model(x)
|
|
|
|
torch.cuda.empty_cache()
|
|
|
|
torch.cuda.reset_peak_memory_stats(fabric.device)
|
|
|
|
test_start_memory = torch.cuda.memory_allocated(fabric.device)
|
|
|
|
with fsdp_overlap_step_with_backward(test_optimizers, test_model):
|
|
|
|
fabric.backward(y_test.mean().abs())
|
|
|
|
del y_test
|
|
|
|
|
|
|
|
test_peak_memory = torch.cuda.max_memory_allocated(fabric.device)
|
|
|
|
|
|
|
|
# Make sure the parameter updates match.
|
|
|
|
with status_checker.guard_region(f"({step + 1} / {num_steps}) Check equality"):
|
|
|
|
for idx, (p0, p1) in enumerate(zip(baseline_model.parameters(), test_model.parameters())):
|
|
|
|
assert (p0 == p1).all(), (step, idx, p0, p1)
|
|
|
|
|
|
|
|
# The first step is going to be odd due to lazy initialization of optimizer state.
|
|
|
|
if not step:
|
|
|
|
continue
|
|
|
|
|
|
|
|
with status_checker.guard_region(f"({step + 1} / {num_steps}) Confirm memory reduction"):
|
|
|
|
baseline_delta = baseline_peak_memory - baseline_start_memory
|
|
|
|
test_delta = test_peak_memory - test_start_memory
|
|
|
|
assert (baseline_delta - test_delta) >= lower_savings_bound, (baseline_delta, test_delta)
|
|
|
|
assert (baseline_delta - test_delta) <= upper_savings_bound, (baseline_delta, test_delta)
|
|
|
|
|
|
|
|
status_checker.finalize()
|
|
|
|
assert (pid := os.getpid()) == status_checker.pids[0], f"Orphan worker: {pid}"
|