lightning/tests/tests_fabric/strategies/test_xla_fsdp_integration.py

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest import mock

import pytest
import torch
from torch.utils.data import DataLoader

from lightning.fabric import Fabric
from lightning.fabric.strategies import XLAFSDPStrategy
from tests_fabric.helpers.models import RandomDataset
from tests_fabric.helpers.runif import RunIf


def _xla_fsdp_rewrap_warning(fabric: Fabric):
    """Fabric launch function for test_xla_fsdp_rewrap_warning."""
    from torch_xla.distributed.fsdp.xla_fully_sharded_data_parallel import XlaFullyShardedDataParallel

    with fabric.init_module():
        model = torch.nn.Sequential(
            torch.nn.Linear(1, 1), torch.nn.ReLU(), XlaFullyShardedDataParallel(torch.nn.Linear(1, 1))
        )
    if fabric.node_rank:
        with pytest.warns(match="submodule is already wrapped"):
            model = fabric.setup_module(model)
    else:
        model = fabric.setup_module(model)
    fabric.barrier("warning_check")
    # only the manually pre-wrapped submodule should remain an XLAFSDP instance
    assert not isinstance(model._forward_module[0], XlaFullyShardedDataParallel)
    assert not isinstance(model._forward_module[1], XlaFullyShardedDataParallel)
    assert isinstance(model._forward_module[2], XlaFullyShardedDataParallel)


@RunIf(min_torch="2.0", tpu=True, standalone=True)
def test_xla_fsdp_rewrap_warning():
    """Test that XLAFSDP warns about rewrapping the modules."""
    from torch_xla.distributed.fsdp.wrap import always_wrap_policy

    strategy = XLAFSDPStrategy(auto_wrap_policy=always_wrap_policy)
    fabric = Fabric(accelerator="tpu", strategy=strategy)
    fabric.launch(_xla_fsdp_rewrap_warning)


def xla_fsdp_train_save_load(fabric: Fabric, tmp_path, state_dict_type):
    """Fabric launch function for test_xla_fsdp_train_save_load."""
    # check if multihost
    if fabric.strategy.all_reduce(fabric.node_rank, reduce_op="sum").item() > 0:
        return  # pytest.skip() is not pickleable

    checkpoint_path = fabric.broadcast(str(tmp_path))

    with fabric.init_module():
        model_1 = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))
    model_1 = fabric.setup_module(model_1)

    optimizer_1 = torch.optim.Adam(model_1.parameters(), lr=0.1)
    optimizer_1 = fabric.setup_optimizers(optimizer_1)

    dataloader = DataLoader(RandomDataset(32, 64))
    dataloader = fabric.setup_dataloaders(dataloader)
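
    # run a single optimization step so there is non-trivial model and optimizer state to checkpoint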
    def step(model, batch):
        output = model(batch)
        return torch.nn.functional.mse_loss(output, torch.ones_like(output))

    model_1.train()
    data_iter = iter(dataloader)
    batch = next(data_iter)
    loss = step(model_1, batch)
    fabric.backward(loss)
    optimizer_1.step()
    optimizer_1.zero_grad()

    state = {
        "model": model_1,
        "optimizer": optimizer_1,  # not needed in ckpt consolidation
        "step_count": 1,
    }
    fabric.save(checkpoint_path, state)

    world_size = fabric.world_size
    if state_dict_type == "sharded":
        expected_files = {f"checkpoint_rank-{i:08d}-of-{world_size:08d}.pth" for i in range(world_size)}
        assert set(os.listdir(checkpoint_path)) == expected_files

        # define a second set of model and optimizer
        with fabric.init_module():
            model_2 = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))
        model_2 = fabric.setup_module(model_2)
        optimizer_2 = torch.optim.Adam(model_2.parameters(), lr=0.1)
        optimizer_2 = fabric.setup_optimizers(optimizer_2)

        # load sharded checkpoints into the second set of model and optimizer
        state = {
            "model": model_2,
            "optimizer": optimizer_2,
            "step_count": 0,
        }
        metadata = fabric.load(checkpoint_path, state)

        # check user data in loaded state
        assert not metadata
        assert state["step_count"] == 1

        # check correctness with loaded state
        for p0, p1 in zip(model_1._forward_module.parameters(), model_2.parameters()):
            torch.testing.assert_close(p0, p1, atol=0, rtol=0, equal_nan=True)

        # attempt to load a key not in the metadata checkpoint
        state = {"model": model_2, "coconut": 11}
        with pytest.raises(KeyError, match="The requested state contains a key 'coconut' that does not exist"):
            fabric.load(checkpoint_path, state)

        # `strict=False` ignores the missing key
        state = {"model": model_2, "coconut": 11}
        fabric.load(checkpoint_path, state, strict=False)
        assert state["coconut"] == 11
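
    # with `state_dict_type="full"`, the shards are consolidated into a single checkpoint file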
    if state_dict_type == "full":
        assert set(os.listdir(checkpoint_path)) == {"checkpoint_consolidated.pth"}

        # define a second model (no optimizer; it is not needed for checkpoint consolidation)
        with fabric.init_module():
            model_2 = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))

        import torch_xla.core.xla_model as xm

        device = xm.xla_device()
        model_2.to(device)

        # load the consolidated checkpoint into the second (unwrapped) model
        state = {"model": model_2}
        fabric.load(checkpoint_path, state)
        # the parameters of the sharded model_1 and the unwrapped model_2 cannot be compared directly
        with pytest.raises(AssertionError, match="do not match"):
            for p0, p1 in zip(model_1.parameters(), model_2.parameters()):
                torch.testing.assert_close(p0, p1, atol=0, rtol=0, equal_nan=True)


@RunIf(min_torch="2.0", tpu=True, standalone=True)
@pytest.mark.parametrize("use_auto_wrap_policy", [False, True])
@pytest.mark.parametrize("state_dict_type", ["sharded", "full"])
def test_xla_fsdp_train_save_load(tmp_path, use_auto_wrap_policy, state_dict_type):
    """Test XLAFSDP training, saving and loading a checkpoint (both full and sharded)."""
    from torch_xla.distributed.fsdp.wrap import always_wrap_policy

    strategy = XLAFSDPStrategy(
        auto_wrap_policy=always_wrap_policy if use_auto_wrap_policy else None, state_dict_type=state_dict_type
    )
    fabric = Fabric(accelerator="tpu", strategy=strategy)
    fabric.launch(xla_fsdp_train_save_load, tmp_path, state_dict_type)


def _test_setup_module_move_to_device(fabric, move_to_device):
    model = torch.nn.Linear(10, 10, bias=False)
    with mock.patch("lightning.fabric.wrappers._FabricModule") as fabric_module_mock:
        fabric_model = fabric.setup_module(model, move_to_device=move_to_device)
    fabric_module_mock.assert_not_called()

    # The _DeviceDtypeModuleMixin currently can't represent the device in a meaningful way for sharded models
    assert fabric_model.device == torch.device("cpu")
    assert fabric.device.type == "xla"


@RunIf(min_torch="2.0", tpu=True, standalone=True)
@pytest.mark.parametrize("move_to_device", [True, False])
def test_setup_module_move_to_device(move_to_device):
    """Test that `move_to_device` does nothing; FSDP decides which device the parameters get moved to
    (sharding)."""
    from torch_xla.distributed.fsdp.wrap import always_wrap_policy

    strategy = XLAFSDPStrategy(auto_wrap_policy=always_wrap_policy)
    fabric = Fabric(accelerator="tpu", strategy=strategy)
    fabric.launch(_test_setup_module_move_to_device, move_to_device=move_to_device)