# lightning/tests/tests_fabric/strategies/test_xla_fsdp_integration.py
# (source listing: 209 lines, 8.0 KiB, Python)
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
from pathlib import Path
from unittest import mock
import pytest
import torch
from lightning.fabric import Fabric
from lightning.fabric.strategies import XLAFSDPStrategy
from torch.utils.data import DataLoader
from tests_fabric.helpers.datasets import RandomDataset
from tests_fabric.helpers.runif import RunIf
def _xla_fsdp_rewrap_warning(fabric: Fabric):
    """Fabric launch function for test_xla_fsdp_rewrap_warning.

    Builds a ``Sequential`` whose last layer is already wrapped in ``XlaFullyShardedDataParallel`` and sets it up with
    a strategy that has an auto-wrap policy configured. Checks that setup warns about the already-wrapped submodule
    and that only the user's pre-existing wrapping remains on the inner layers.

    Args:
        fabric: The launched ``Fabric`` instance for the current process.
    """
    from torch_xla.distributed.fsdp.xla_fully_sharded_data_parallel import XlaFullyShardedDataParallel

    with fabric.init_module():
        model = torch.nn.Sequential(
            torch.nn.Linear(1, 1), torch.nn.ReLU(), XlaFullyShardedDataParallel(torch.nn.Linear(1, 1))
        )

    # The rewrap warning is presumably emitted only on global rank zero (rank-zero-only warning helper), so that is
    # the process that asserts on it. Gating on `fabric.node_rank` would skip the assertion entirely on a single host
    # (every process has node rank 0) and would expect a warning on non-zero nodes that never emit it.
    if fabric.is_global_zero:
        with pytest.warns(match="submodule is already wrapped"):
            model = fabric.setup_module(model)
    else:
        model = fabric.setup_module(model)
    fabric.barrier("warning_check")

    # Despite the auto-wrap policy, the first two layers stay unwrapped; the third keeps the user's own wrapping.
    assert not isinstance(model._forward_module[0], XlaFullyShardedDataParallel)
    assert not isinstance(model._forward_module[1], XlaFullyShardedDataParallel)
    assert isinstance(model._forward_module[2], XlaFullyShardedDataParallel)
@RunIf(min_torch="2.0", tpu=True, standalone=True)
def test_xla_fsdp_rewrap_warning():
    """Check that XLAFSDP warns when auto-wrapping a model that already contains wrapped submodules."""
    from torch_xla.distributed.fsdp.wrap import always_wrap_policy

    fabric = Fabric(
        accelerator="tpu",
        strategy=XLAFSDPStrategy(auto_wrap_policy=always_wrap_policy),
    )
    fabric.launch(_xla_fsdp_rewrap_warning)
def xla_fsdp_train_save_load(fabric: Fabric, tmp_path, state_dict_type):
    """Fabric launch function for test_xla_fsdp_train_save_load.

    Trains a small model for one step, saves a checkpoint and loads it back into a freshly created model.

    Args:
        fabric: The launched ``Fabric`` instance for the current process.
        tmp_path: Directory for the checkpoint; it is broadcast so every process uses the same path.
        state_dict_type: ``"sharded"`` or ``"full"``; expected to match the strategy's ``state_dict_type``.
    """
    tmp_path = Path(fabric.broadcast(tmp_path))

    with fabric.init_module():
        model_1 = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))
    model_1 = fabric.setup_module(model_1)

    optimizer_1 = torch.optim.Adam(model_1.parameters(), lr=0.1)
    optimizer_1 = fabric.setup_optimizers(optimizer_1)

    dataloader = DataLoader(RandomDataset(32, 64))
    dataloader = fabric.setup_dataloaders(dataloader)

    def step(model, batch):
        output = model(batch)
        return torch.nn.functional.mse_loss(output, torch.ones_like(output))

    # a single optimization step so the saved weights differ from a fresh initialization
    model_1.train()
    data_iter = iter(dataloader)
    batch = next(data_iter)
    loss = step(model_1, batch)
    fabric.backward(loss)
    optimizer_1.step()
    optimizer_1.zero_grad()

    state = {
        "model": model_1,
        "optimizer": optimizer_1,  # not needed in ckpt consolidation
        "step_count": 1,
    }
    checkpoint_path = tmp_path / "foo.pth"

    world_size = fabric.world_size
    local_process_count = len(fabric.strategy.parallel_devices)
    # fewer local devices than total processes means other hosts participate in the run
    is_multihost = local_process_count < world_size
    if state_dict_type == "full" and is_multihost:
        # saving a full checkpoint is rejected when hosts don't share a filesystem
        with pytest.raises(OSError, match="Multihost setups do not have a shared filesystem"):
            fabric.save(checkpoint_path, state)
        return

    fabric.save(checkpoint_path, state)

    if state_dict_type == "sharded":
        # One shard file per local process. Global ranks on non-first hosts can be >= 10, so accept any
        # 8-digit rank and require the whole file name to match.
        pattern = re.compile(rf"checkpoint_rank-\d{{8}}-of-{world_size:08d}\.pth")
        shards = os.listdir(checkpoint_path)
        assert len(shards) == local_process_count
        for name in shards:
            assert pattern.fullmatch(name), name

        # define a second set of model and optimizer
        with fabric.init_module():
            model_2 = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))
        model_2 = fabric.setup_module(model_2)

        optimizer_2 = torch.optim.Adam(model_2.parameters(), lr=0.1)
        optimizer_2 = fabric.setup_optimizers(optimizer_2)

        # load sharded checkpoints into the second set of model and optimizer
        state = {
            "model": model_2,
            "optimizer": optimizer_2,
            "step_count": 0,
        }
        metadata = fabric.load(checkpoint_path, state)

        # check user data in loaded state: nothing is left over and the step count was restored in place
        assert not metadata
        assert state["step_count"] == 1

        # check correctness with loaded state
        for p0, p1 in zip(model_1._forward_module.parameters(), model_2.parameters()):
            torch.testing.assert_close(p0, p1, atol=0, rtol=0, equal_nan=True)

        # attempt to load a key not in the metadata checkpoint
        state = {"model": model_2, "coconut": 11}
        with pytest.raises(KeyError, match="The requested state contains a key 'coconut' that does not exist"):
            fabric.load(checkpoint_path, state)

        # `strict=False` ignores the missing key
        state = {"model": model_2, "coconut": 11}
        fabric.load(checkpoint_path, state, strict=False)
        assert state["coconut"] == 11

    if state_dict_type == "full":
        # a full checkpoint is written as a single file
        assert set(os.listdir(tmp_path)) == {"foo.pth"}

        # define a second, plain (not set up through Fabric) model on the XLA device
        with fabric.init_module():
            model_2 = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))
        import torch_xla.core.xla_model as xm

        device = xm.xla_device()
        model_2.to(device)

        # load the full checkpoint into the plain model
        state = {"model": model_2}
        fabric.load(checkpoint_path, state)

        # `model_1` is the XLAFSDP-wrapped model whereas `model_2` is a plain module, so their parameters are not
        # directly comparable and the comparison is expected to fail (presumably on the shapes of the
        # sharded/flattened parameters — confirm)
        with pytest.raises(AssertionError, match="do not match"):
            for p0, p1 in zip(model_1.parameters(), model_2.parameters()):
                torch.testing.assert_close(p0, p1, atol=0, rtol=0, equal_nan=True)
@RunIf(min_torch="2.0", tpu=True, standalone=True)
@pytest.mark.parametrize(
    ("use_auto_wrap_policy", "state_dict_type", "sequential_save"),
    [
        (False, "sharded", False),
        (False, "full", False),
        (False, "full", True),
        (True, "sharded", False),
        (True, "full", False),
    ],
)
def test_xla_fsdp_train_save_load(tmp_path, use_auto_wrap_policy, state_dict_type, sequential_save):
    """Train with XLAFSDP, then save and reload a checkpoint in either the full or the sharded format."""
    from torch_xla.distributed.fsdp.wrap import always_wrap_policy

    strategy_kwargs = {
        "auto_wrap_policy": always_wrap_policy if use_auto_wrap_policy else None,
        "state_dict_type": state_dict_type,
        "sequential_save": sequential_save,
    }
    fabric = Fabric(accelerator="tpu", strategy=XLAFSDPStrategy(**strategy_kwargs))
    fabric.launch(xla_fsdp_train_save_load, tmp_path, state_dict_type)
def _test_setup_module_move_to_device(fabric, move_to_device):
    """Fabric launch function for test_setup_module_move_to_device.

    Sets up a bias-free ``Linear`` layer with the given ``move_to_device`` flag. Checks that both the resulting module
    and Fabric report an XLA device type.

    Args:
        fabric: The launched ``Fabric`` instance for the current process.
        move_to_device: Forwarded unchanged to ``fabric.setup_module``.
    """
    layer = torch.nn.Linear(10, 10, bias=False)

    # NOTE(review): this patches the name inside `lightning.fabric.wrappers`. If `Fabric.setup_module` resolves
    # `_FabricModule` through another module's namespace, the patch does not intercept that call and
    # `assert_not_called` below is vacuous — confirm the intended patch target.
    with mock.patch("lightning.fabric.wrappers._FabricModule") as wrapper_cls_mock:
        wrapped_layer = fabric.setup_module(layer, move_to_device=move_to_device)
    wrapper_cls_mock.assert_not_called()

    # Only the device *type* is checked: the _DeviceDtypeModuleMixin cannot meaningfully report a single device for a
    # model whose pieces live on different devices.
    assert wrapped_layer.device.type == "xla"
    assert fabric.device.type == "xla"
@RunIf(min_torch="2.0", tpu=True, standalone=True)
@pytest.mark.parametrize("move_to_device", [True, False])
def test_setup_module_move_to_device(move_to_device):
    """Check that ``move_to_device`` has no effect under XLAFSDP.

    FSDP itself decides which device each (sharded) parameter is placed on.
    """
    from torch_xla.distributed.fsdp.wrap import always_wrap_policy

    fabric = Fabric(
        accelerator="tpu",
        strategy=XLAFSDPStrategy(auto_wrap_policy=always_wrap_policy),
    )
    fabric.launch(_test_setup_module_move_to_device, move_to_device=move_to_device)