# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from copy import deepcopy
from pathlib import Path
from unittest import mock

import pytest
import torch
from torch.nn import Parameter

from lightning.fabric import Fabric
from lightning.fabric.plugins import FSDPPrecision
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities.imports import (
    _TORCH_GREATER_EQUAL_1_12,
    _TORCH_GREATER_EQUAL_2_0,
    _TORCH_GREATER_EQUAL_2_1,
)
from lightning.fabric.wrappers import _FabricOptimizer
from tests_fabric.helpers.models import BoringFabric
from tests_fabric.helpers.runif import RunIf
from tests_fabric.test_fabric import BoringModel

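# these FSDP imports require PyTorch 1.12 or newer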
if _TORCH_GREATER_EQUAL_1_12:
    from torch.distributed.fsdp import FlatParameter, FullyShardedDataParallel, OptimStateKeyType
    from torch.distributed.fsdp.wrap import always_wrap_policy, wrap


class _MyFabric(BoringFabric):
    def get_model(self):
        model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))
        self.num_wrapped = 4
        return model

    def step(self, model, batch):
        wrapped_layers = [m for m in model.modules() if isinstance(m, FullyShardedDataParallel)]
        assert len(wrapped_layers) == self.num_wrapped
        assert (self.num_wrapped == 4) == isinstance(model._forward_module, FullyShardedDataParallel)

        precision = self._precision
        assert isinstance(precision, FSDPPrecision)
        if precision.precision == "16-mixed":
            param_dtype = torch.float32
            reduce_dtype = buffer_dtype = torch.float16
        elif precision.precision == "bf16-mixed":
            param_dtype = torch.float32
            reduce_dtype = buffer_dtype = torch.bfloat16
        elif precision.precision == "16-true":
            param_dtype = reduce_dtype = buffer_dtype = torch.float16
        elif precision.precision == "bf16-true":
            param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
        else:
            raise ValueError(f"Unknown precision {precision.precision}")

        for layer in wrapped_layers:
            assert layer.mixed_precision.param_dtype == param_dtype
            assert layer.mixed_precision.reduce_dtype == reduce_dtype
            assert layer.mixed_precision.buffer_dtype == buffer_dtype

        output = model(batch)
        return torch.nn.functional.mse_loss(output, torch.ones_like(output))


class _MyFabricManualWrapping(_MyFabric):
    def get_model(self):
        model = super().get_model()
        for i, layer in enumerate(model):
            if i % 2 == 0:
                model[i] = wrap(layer)
        self.num_wrapped = 2
        return model


@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
@pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
@pytest.mark.parametrize("manual_wrapping", [True, False])
def test_fsdp_train_save_load(tmp_path, manual_wrapping, precision):
    """Test FSDP training, saving and loading with different wrapping and precision settings."""
    fabric_cls = _MyFabricManualWrapping if manual_wrapping else _MyFabric
    fabric = fabric_cls(
        accelerator="cuda", strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy), devices=2, precision=precision
    )
    fabric.run()

    checkpoint_path = fabric.broadcast(str(tmp_path / "fsdp-checkpoint"))

    params_before = deepcopy(list(fabric.model.parameters()))
    state = {"model": fabric.model, "optimizer": fabric.optimizer, "steps": 1}
    fabric.save(checkpoint_path, state)
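    # the sharded checkpoint is a directory with one data file per rank plus metadata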
    assert set(os.listdir(checkpoint_path)) == {"meta.pt", ".metadata", "__0_0.distcp", "__1_0.distcp"}

    # re-init all objects and resume
    fabric = fabric_cls(
        accelerator="cuda", strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy), devices=2, precision=precision
    )
    fabric.run()

    # check correctness with loaded state
    state = {"model": fabric.model, "optimizer": fabric.optimizer, "steps": 0}
    metadata = fabric.load(checkpoint_path, state)
    for p0, p1 in zip(params_before, fabric.model.parameters()):
        torch.testing.assert_close(p0, p1, atol=0, rtol=0, equal_nan=True)

    # check that the user data in the state was reloaded
    assert state["steps"] == 1
    assert not metadata

    # attempt to load a key not in the metadata checkpoint
    state = {"model": fabric.model, "coconut": 11}
    with pytest.raises(KeyError, match="The requested state contains a key 'coconut' that does not exist"):
        fabric.load(checkpoint_path, state)

    # `strict=False` ignores the missing key
    state = {"model": fabric.model, "coconut": 11}
    fabric.load(checkpoint_path, state, strict=False)
    assert state["coconut"] == 11


@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
def test_fsdp_save_full_state_dict(tmp_path):
    """Test that FSDP saves the full state into a single file with `state_dict_type="full"`."""
    fabric = BoringFabric(
        accelerator="cuda",
        strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy, state_dict_type="full"),
        devices=2,
    )
    fabric.run()

    checkpoint_path = Path(fabric.broadcast(str(tmp_path / "fsdp-checkpoint.pt")))

    state = {"model": fabric.model, "optimizer": fabric.optimizer, "steps": 1}
    fabric.save(checkpoint_path, state)

    checkpoint = torch.load(checkpoint_path)
    assert checkpoint["steps"] == 1
    loaded_state_dict = checkpoint["model"]

    # assert the correct model state was saved
    with FullyShardedDataParallel.summon_full_params(fabric.model):
        state_dict = fabric.model.state_dict()
        assert set(loaded_state_dict.keys()) == set(state_dict.keys())
        for param_name in state_dict:
            assert torch.equal(loaded_state_dict[param_name], state_dict[param_name].cpu())
        params_before = [p.cpu() for p in fabric.model.parameters()]

    # assert the correct optimizer state was saved
    optimizer_state_before = FullyShardedDataParallel.full_optim_state_dict(
        fabric.model, fabric.optimizer, rank0_only=False
    )
    assert set(checkpoint["optimizer"].keys()) == set(optimizer_state_before.keys()) == {"state", "param_groups"}

    # 1. verify the FSDP state can be loaded back into an FSDP model/strategy directly
    fabric = BoringFabric(
        accelerator="cuda",
        strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
        devices=2,
    )
    fabric.run()
    metadata = fabric.load(checkpoint_path, {"model": fabric.model, "optimizer": fabric.optimizer})
    assert metadata == {"steps": 1}

    with FullyShardedDataParallel.summon_full_params(fabric.model):
        params_after = list(fabric.model.parameters())
        assert all(torch.equal(p0.cpu(), p1.cpu()) for p0, p1 in zip(params_before, params_after))

    # assert the correct optimizer state was loaded
    optimizer_state_after = FullyShardedDataParallel.full_optim_state_dict(
        fabric.model, fabric.optimizer, rank0_only=False
    )
    assert set(optimizer_state_after.keys()) == set(optimizer_state_before.keys()) == {"state", "param_groups"}
    torch.testing.assert_close(optimizer_state_after["state"], optimizer_state_before["state"], atol=0, rtol=0)
    assert optimizer_state_after["param_groups"] == optimizer_state_before["param_groups"]

    # run a step to verify the optimizer state is correct
    fabric.run()

    # 2. verify the FSDP state can be loaded back into a single-device model/strategy
    fabric = BoringFabric(accelerator="cpu", devices=1)
    fabric.run()
    metadata = fabric.load(checkpoint_path, {"model": fabric.model, "optimizer": fabric.optimizer})
    assert metadata == {"steps": 1}
    params_after = list(fabric.model.parameters())
    assert all(torch.equal(p0, p1) for p0, p1 in zip(params_before, params_after))

    # get the optimizer state after loading
    normal_checkpoint_path = Path(fabric.broadcast(str(tmp_path / "normal-checkpoint.pt")))
    fabric.save(normal_checkpoint_path, {"model": fabric.model, "optimizer": fabric.optimizer, "steps": 2})
    optimizer_state_after = torch.load(normal_checkpoint_path)["optimizer"]
    optimizer_state_after = FullyShardedDataParallel.rekey_optim_state_dict(
        optimizer_state_after, optim_state_key_type=OptimStateKeyType.PARAM_NAME, model=fabric.model
    )

    # assert the correct optimizer state was loaded
    assert set(optimizer_state_after.keys()) == set(optimizer_state_before.keys()) == {"state", "param_groups"}
    torch.testing.assert_close(optimizer_state_after["state"], optimizer_state_before["state"], atol=0, rtol=0)

    # run a step to verify the optimizer state is correct
    fabric.run()

    # 3. verify that the single-device model/strategy state can be loaded into an FSDP model/strategy
    fabric = BoringFabric(
        accelerator="cuda",
        strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
        devices=2,
    )
    fabric.run()
    metadata = fabric.load(normal_checkpoint_path, {"model": fabric.model, "optimizer": fabric.optimizer})
    assert metadata == {"steps": 2}

    with FullyShardedDataParallel.summon_full_params(fabric.model):
        params_after = list(fabric.model.parameters())
        assert all(torch.equal(p0.cpu(), p1.cpu()) for p0, p1 in zip(params_before, params_after))

    # assert the correct optimizer state was loaded
    optimizer_state_after = FullyShardedDataParallel.full_optim_state_dict(
        fabric.model, fabric.optimizer, rank0_only=False
    )
    assert set(optimizer_state_after.keys()) == set(optimizer_state_before.keys()) == {"state", "param_groups"}
    torch.testing.assert_close(optimizer_state_after["state"], optimizer_state_before["state"], atol=0, rtol=0)
    assert optimizer_state_after["param_groups"] == optimizer_state_before["param_groups"]

    # run a step to verify the optimizer state is correct
    fabric.run()


@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
def test_fsdp_load_full_state_dict_into_sharded_model(tmp_path):
    """Test that the strategy can load a full-state checkpoint into an FSDP-sharded model."""
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    fabric = BoringFabric(accelerator="cuda", devices=1)
    fabric.seed_everything(0)
    fabric.run()

    # Save a full-state-dict checkpoint
    checkpoint_path = Path(fabric.broadcast(str(tmp_path / "full-checkpoint.pt")))
    state = {"model": fabric.model, "optimizer": fabric.optimizer, "steps": 1}
    fabric.save(checkpoint_path, state)

    # Gather all weights and store a copy manually
    with FSDP.summon_full_params(fabric.model, writeback=False, rank0_only=False):
        params_before = torch.cat([p.cpu().view(-1) for p in fabric.model.parameters()])

    # Create an FSDP-sharded model
    fabric = BoringFabric(
        accelerator="cuda",
        strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
        devices=2,
    )
    fabric.run()

    state = {"model": fabric.model, "optimizer": fabric.optimizer, "steps": 44}
    fabric.load(checkpoint_path, state)
    assert state["steps"] == 1

    # Gather all weights and compare
    with FSDP.summon_full_params(fabric.model, writeback=False, rank0_only=False):
        params_after = torch.cat([p.cpu().view(-1) for p in fabric.model.parameters()])
        assert torch.equal(params_before, params_after)

    # Create a raw state-dict checkpoint to test `Fabric.load_raw` too
    raw_checkpoint_path = checkpoint_path.with_name("model-state-dict")
    if fabric.global_rank == 0:
        checkpoint = torch.load(checkpoint_path)
        torch.save(checkpoint["model"], raw_checkpoint_path)
    fabric.barrier()

    fabric.run()
    fabric.load_raw(raw_checkpoint_path, fabric.model)

    # Gather all weights and compare
    with FSDP.summon_full_params(fabric.model, writeback=False, rank0_only=False):
        params_after = torch.cat([p.cpu().view(-1) for p in fabric.model.parameters()])
        assert torch.equal(params_before, params_after)


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
@pytest.mark.parametrize("move_to_device", [True, False])
@mock.patch("lightning.fabric.wrappers._FabricModule")
def test_setup_module_move_to_device(fabric_module_mock, move_to_device):
    """Test that `move_to_device` has no effect with FSDP: the strategy decides onto which device each parameter
    shard gets placed (sharding)."""
    strategy = FSDPStrategy(auto_wrap_policy=always_wrap_policy)
    fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy)
    fabric.launch()

    model = torch.nn.Linear(10, 10, bias=False)  # total params: 10 * 10 = 100
    fabric_model = fabric.setup_module(model, move_to_device=move_to_device)
    fabric_module_mock.assert_not_called()

    assert len(list(fabric_model.parameters())) == 1
    # the linear layer got sharded and each shard is on the expected device
    assert next(fabric_model.parameters()).device == torch.device("cuda", fabric.local_rank)
    assert next(fabric_model.parameters()).numel() == 50
    if _TORCH_GREATER_EQUAL_2_0:
        # In PyTorch >= 2.0 we set `use_orig_params=True` and don't see flattened parameters
        assert isinstance(next(fabric_model.parameters()), Parameter)
    else:
        assert isinstance(next(fabric_model.parameters()), FlatParameter)

    # The _DeviceDtypeModuleMixin currently can't represent the device in a meaningful way for sharded models
    assert fabric_model.device == torch.device("cpu")
    assert fabric.device == torch.device("cuda", fabric.local_rank)


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="2.0.0")
def test_setup_with_orig_params_and_multiple_param_groups():
    """Test that Fabric sets `use_orig_params` for the user when jointly setting up model and optimizer."""
    strategy = FSDPStrategy(auto_wrap_policy=always_wrap_policy)
    fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy)
    fabric.launch()

    model = torch.nn.Sequential(
        torch.nn.Linear(10, 10, bias=False),
        torch.nn.Linear(5, 2, bias=False),
    )
    optimizer = torch.optim.Adam(
        [
            {"params": model[0].parameters(), "lr": 1e-2},
            {"params": model[1].parameters(), "lr": 1e-6},
        ]
    )

    # set up model and optimizer jointly
    wrapped_model, wrapped_optimizer = fabric.setup(model, optimizer)

    assert fabric.strategy._fsdp_kwargs["use_orig_params"]
    assert isinstance(wrapped_optimizer, _FabricOptimizer)
    assert len(wrapped_optimizer.param_groups) == 2
    for i in range(2):
        layer = wrapped_model._forward_module.module[i]
        assert isinstance(layer, FullyShardedDataParallel)
        assert torch.equal(wrapped_optimizer.param_groups[i]["params"][0], layer.weight)

        # A regular parameter as a view into the flattened parameters
        assert isinstance(layer.weight, torch.nn.Parameter)
        assert not isinstance(layer.weight, FlatParameter)


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, dynamo=True)
@mock.patch.dict(os.environ, {})
@pytest.mark.parametrize(
    "compile_after_setup",
    [
        False,
        # https://github.com/pytorch/pytorch/issues/97811
        pytest.param(True, marks=RunIf(min_python="3.9")),
    ],
)
def test_compile(compile_after_setup):
    """Test that the model can be compiled before and after the model is wrapped in FSDP."""
    model = BoringModel()
    strategy = FSDPStrategy(auto_wrap_policy=always_wrap_policy)
    fabric = Fabric(accelerator="cuda", devices=2, strategy=strategy)
    fabric.launch()

    if not compile_after_setup:
        model = torch.compile(model)

    model = fabric.setup(model)

    if compile_after_setup:
        model = torch.compile(model)

    for _ in range(3):
        model(torch.rand(2, 32, device=fabric.device)).sum().backward()


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize(
    ("precision", "expected_dtype"),
    [
        ("32-true", torch.float32),
        ("16-true", torch.float16),
        pytest.param("bf16-true", torch.bfloat16, marks=RunIf(bf16_cuda=True)),
    ],
)
def test_module_init_context(precision, expected_dtype):
    """Test that a module created under the init-context gets the expected device and dtype."""
    fabric = Fabric(
        accelerator="cuda",
        devices=2,
        strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
        precision=precision,
    )
    fabric.launch()

    def _run_setup_assertions(empty_init, expected_device):
        with fabric.init_module(empty_init=empty_init):
            model = torch.nn.Linear(100, 100, bias=False)

        # The model is on the CPU/meta-device until after `.setup()`
        assert model.weight.device == expected_device
        assert model.weight.dtype == expected_dtype
        model = fabric.setup(model)
        # Parameters get sharded in `.setup()` and moved to the target device
        assert model.weight.device == torch.device("cuda", fabric.local_rank)
        assert model.weight.dtype == expected_dtype

    # Case 1: No empty init
    _run_setup_assertions(empty_init=False, expected_device=torch.device("cpu"))

    if _TORCH_GREATER_EQUAL_2_1:
        # Case 2: Empty-init with PyTorch >= 2.1 supports meta device
        _run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))
    else:
        # Case 2: Empty-init with PyTorch < 2.1 only supports `torch.empty()`-init
        _run_setup_assertions(empty_init=True, expected_device=torch.device("cpu"))


@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.0.0")
def test_fsdp_save_filter(tmp_path):
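    """Test that `fabric.save` with a `filter` saves only the selected keys for full checkpoints and raises an
    error for sharded checkpoints."""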
    fabric = BoringFabric(accelerator="cuda", strategy=FSDPStrategy(state_dict_type="full"), devices=2)
    fabric.launch()
    model = fabric.get_model()
    model = fabric.setup_module(model)

    tmp_path = Path(fabric.broadcast(str(tmp_path)))
    state = {"model": model}
    filter = {"model": lambda k, v: "bias" in k}

    checkpoint_path = tmp_path / "full.pth"
    fabric.save(checkpoint_path, state, filter=filter)
    checkpoint = torch.load(checkpoint_path)["model"]
    assert set(checkpoint) == {"bias"}
    assert isinstance(checkpoint["bias"], torch.Tensor)

    fabric.strategy._state_dict_type = "sharded"
    checkpoint_path = tmp_path / "sharded"
    with pytest.raises(NotImplementedError, match="doesn't support loading sharded filtered"):
        fabric.save(checkpoint_path, state, filter=filter)


@RunIf(min_torch="1.13", min_cuda_gpus=1)
def test_fsdp_manual_activation_checkpointing():
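    """Test that Fabric warns and does not re-apply activation checkpointing when the user has already applied it
    manually to the model."""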
    model = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.Linear(1, 1))
    strategy = FSDPStrategy(activation_checkpointing_policy={torch.nn.Linear})
    fabric = Fabric(devices=1, accelerator="cuda", strategy=strategy)
    fabric.launch()

    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        apply_activation_checkpointing,
        CheckpointWrapper,
    )

    # manually apply activation checkpointing
    apply_activation_checkpointing(model)

    wrappers = {name for name, mod in model.named_modules() if isinstance(mod, CheckpointWrapper)}
    assert wrappers == {"0", "1"}

    # let fabric set up the model, it shouldn't apply activation checkpointing again
    with pytest.warns(match="is configured, but the model already contains checkpointed"):
        model = fabric.setup(model)

    wrappers = {name for name, mod in model._forward_module.named_modules() if isinstance(mod, CheckpointWrapper)}
    assert wrappers == {"_fsdp_wrapped_module.0", "_fsdp_wrapped_module.1"}


@RunIf(min_torch="1.12", min_cuda_gpus=1)
def test_rewrap_warnings():
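    """Test that Fabric warns when some submodules are already wrapped in FSDP and, on PyTorch 2.1+, when
    parameters are still on the meta device after setup."""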
    from torch.distributed.fsdp import FullyShardedDataParallel
    from torch.distributed.fsdp.wrap import wrap

    strategy = FSDPStrategy(auto_wrap_policy={torch.nn.Linear})
    fabric = Fabric(devices=1, accelerator="cuda", strategy=strategy)
    fabric.launch()
    with fabric.init_module():
        model = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.ReLU(), wrap(torch.nn.Linear(1, 1)))
    with pytest.warns(match="the model is already wrapped"):
        model = fabric.setup(model)
    assert not isinstance(model._forward_module, FullyShardedDataParallel)
    assert isinstance(model._forward_module[2], FullyShardedDataParallel)

    if not _TORCH_GREATER_EQUAL_2_1:
        return

    with fabric.init_module(empty_init=True):
        model = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.ReLU(), wrap(torch.nn.Linear(1, 1)))
    assert model[0].weight.is_meta
    with pytest.warns(match="there are still parameters on the meta device"):
        fabric_model = fabric.setup(model)
    assert next(fabric_model.parameters()).is_meta