2022-10-08 15:42:21 +00:00
|
|
|
# Copyright The Lightning AI team.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
2022-10-04 22:54:14 +00:00
|
|
|
import os
|
2022-09-30 07:57:18 +00:00
|
|
|
from functools import partial
|
2022-10-04 22:54:14 +00:00
|
|
|
from unittest import mock
|
2023-04-11 22:04:17 +00:00
|
|
|
from unittest.mock import MagicMock
|
2022-09-30 07:57:18 +00:00
|
|
|
|
|
|
|
import pytest
|
2022-12-08 07:08:04 +00:00
|
|
|
import torch
|
2022-10-08 15:42:21 +00:00
|
|
|
from torch.utils.data import DataLoader
|
2022-09-30 07:57:18 +00:00
|
|
|
|
2023-04-28 16:36:22 +00:00
|
|
|
from lightning.fabric.accelerators import XLAAccelerator
|
2023-02-01 20:34:38 +00:00
|
|
|
from lightning.fabric.strategies import XLAStrategy
|
|
|
|
from lightning.fabric.strategies.launchers.xla import _XLALauncher
|
|
|
|
from lightning.fabric.utilities.distributed import ReduceOp
|
2023-05-05 17:57:24 +00:00
|
|
|
from lightning.fabric.utilities.seed import seed_everything
|
2023-04-11 22:04:17 +00:00
|
|
|
from tests_fabric.helpers.models import RandomDataset
|
2023-03-03 16:55:48 +00:00
|
|
|
from tests_fabric.helpers.runif import RunIf
|
2022-09-30 07:57:18 +00:00
|
|
|
|
|
|
|
|
|
|
|
def wrap_launch_function(fn, strategy, *args, **kwargs):
    """Prepare the strategy's environment in the launched process, then call ``fn(*args, **kwargs)``.

    The launcher does not call ``strategy.setup_environment()`` by itself, so this wrapper does it before
    delegating. Background: https://github.com/Lightning-AI/lightning/pull/14926#discussion_r982976718

    Returns:
        Whatever ``fn`` returns.
    """
    strategy.setup_environment()
    outcome = fn(*args, **kwargs)
    return outcome
|
|
|
|
|
|
|
|
|
2023-05-05 17:57:24 +00:00
|
|
|
def xla_launch(fn, strategy=None):
    """Run ``fn`` in processes started by an ``_XLALauncher``.

    Args:
        fn: Callable executed in each launched process. It is wrapped with ``wrap_launch_function`` so that
            ``strategy.setup_environment()`` runs first. ``strategy`` is also passed as the extra argument to
            ``launcher.launch``; the ``*_fn`` helpers in this file accept it as their (last) argument.
        strategy: Strategy to launch with. When ``None``, an ``XLAStrategy`` using all auto-detected XLA devices
            is created.

    Returns:
        Whatever ``launcher.launch`` returns.
    """
    # TODO: the accelerator should be optional to just launch processes, but this requires lazy initialization
    # Compare against ``None`` explicitly: a caller-supplied strategy object must never be replaced just because
    # it happens to be falsy.
    if strategy is None:
        accelerator = XLAAccelerator()
        strategy = XLAStrategy(
            accelerator=accelerator,
            parallel_devices=XLAAccelerator.get_parallel_devices(XLAAccelerator.auto_device_count()),
        )
    launcher = _XLALauncher(strategy=strategy)
    wrapped = partial(wrap_launch_function, fn, strategy)
    return launcher.launch(wrapped, strategy)
|
|
|
|
|
|
|
|
|
|
|
|
def broadcast_on_tpu_fn(strategy):
    """Check ``strategy.broadcast`` for a tensor payload and for an arbitrary Python object.

    Each process sends a payload built from its own ``local_rank``. After broadcasting from ``src`` (0), every
    process is expected to receive the value ``0`` (the source process's local rank).
    """
    # In PjRT, the local rank and global rank have no solid relation.
    # global rank may not even be contiguous on a host, because it depends on the 3D mesh structure that is formed by
    # the TPUs on all hosts in a pod. So checking a different src is not reliable
    # https://github.com/pytorch/xla/blob/v2.0.0/torch_xla/experimental/pjrt.py#L161-L163
    src = 0

    # tensor payload: the received value must come from ``src`` and live on an XLA device
    received_tensor = strategy.broadcast(torch.tensor(strategy.local_rank), src)
    assert received_tensor.item() == src
    assert received_tensor.device.type == "xla"

    # non-tensor payload
    payload = ("ver_0.5", "logger_name", strategy.local_rank)
    expected = ("ver_0.5", "logger_name", src)
    assert strategy.broadcast(payload, src=src) == expected
|
2022-09-30 07:57:18 +00:00
|
|
|
|
|
|
|
|
|
|
|
@RunIf(tpu=True)
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_broadcast_on_tpu():
    """Broadcasting from the main process must deliver the same object to every XLA process.

    ``os.environ`` is patched with a snapshot of itself, so any variables changed while launching are restored
    once the test finishes.
    """
    xla_launch(fn=broadcast_on_tpu_fn)
|
|
|
|
|
|
|
|
|
|
|
|
def tpu_reduce_fn(strategy):
    """Check ``strategy.all_reduce`` in every launched XLA process.

    Unsupported reduce ops (an unknown string and ``ReduceOp.MAX``) must raise ``ValueError``. Each process
    contributes ``1``, so a mean must be ``1`` and a sum must equal the number of participating processes.
    """
    with pytest.raises(ValueError, match="XLAStrategy only supports"):
        strategy.all_reduce(1, reduce_op="undefined")

    with pytest.raises(ValueError, match="XLAStrategy only supports"):
        strategy.all_reduce(1, reduce_op=ReduceOp.MAX)

    # Every process contributes 1, so a sum equals the number of processes. Use the real world size instead of
    # assuming a single 8-device host.
    expected_sum = strategy.world_size

    # it is faster to loop over here than to parameterize the test
    for reduce_op in ("mean", "AVG", "sum", ReduceOp.SUM):
        result = strategy.all_reduce(1, reduce_op=reduce_op)
        if isinstance(reduce_op, str) and reduce_op.lower() in ("mean", "avg"):
            assert result.item() == 1
        else:
            assert result.item() == expected_sum
|
|
|
|
|
|
|
|
|
|
|
|
@RunIf(tpu=True)
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_tpu_reduce():
    """Exercise ``XLAStrategy.all_reduce`` across processes spawned on the TPU.

    ``os.environ`` is patched with a snapshot of itself so launch-time environment changes are undone afterwards.
    """
    xla_launch(fn=tpu_reduce_fn)
|
2022-10-08 15:42:21 +00:00
|
|
|
|
|
|
|
|
|
|
|
@RunIf(tpu=True)
@mock.patch("lightning.fabric.strategies.xla.XLAStrategy.root_device")
def test_xla_mp_device_dataloader_attribute(_, monkeypatch):
    """``process_dataloader`` must leave an ``MpDeviceLoader`` untouched and wrap anything else.

    When wrapping, the returned loader must expose ``dataset`` and ``batch_sampler`` equal to those of the
    wrapped ``_loader``. ``MpDeviceLoader`` is replaced by a mock whose ``isinstance`` answer is driven by
    ``pretend_already_wrapped``.
    """
    dataloader = DataLoader(RandomDataset(32, 64))
    strategy = XLAStrategy()
    # Read lazily by ``__instancecheck__`` below, so reassigning it changes the isinstance() outcome.
    pretend_already_wrapped = True

    import torch_xla.distributed.parallel_loader as parallel_loader

    class MpDeviceLoaderMock(MagicMock):
        def __instancecheck__(self, instance):
            # lets `isinstance(dataloader, MpDeviceLoader)` work when the "class" is a mock
            return pretend_already_wrapped

    mp_loader_mock = MpDeviceLoaderMock()
    monkeypatch.setattr(parallel_loader, "MpDeviceLoader", mp_loader_mock)

    # already an MpDeviceLoader: returned as-is, nothing gets wrapped
    assert strategy.process_dataloader(dataloader) is dataloader
    mp_loader_mock.assert_not_called()

    # plain DataLoader: wrapped and given DataLoader-like attributes
    pretend_already_wrapped = False
    wrapped_loader = strategy.process_dataloader(dataloader)
    mp_loader_mock.assert_called_with(dataloader, strategy.root_device)
    assert wrapped_loader.dataset == wrapped_loader._loader.dataset
    assert wrapped_loader.batch_sampler == wrapped_loader._loader.batch_sampler
|
2022-10-08 15:42:21 +00:00
|
|
|
|
|
|
|
|
2022-12-08 07:08:04 +00:00
|
|
|
def tpu_all_gather_fn(strategy):
    """Check ``strategy.all_gather`` in every launched XLA process, with and without ``sync_grads``.

    Non-tensor inputs must raise ``NotImplementedError``. Each process gathers a scalar ``1.0``, so the gathered
    values must sum to the number of participating processes. Gradients must reach the local input only when
    ``sync_grads=True``.
    """
    with pytest.raises(NotImplementedError, match="only implemented for tensors"):
        strategy.all_gather([1])

    # Every process contributes 1.0, so the gathered sum equals the number of processes. Use the world size
    # rather than the local device count so this also holds when processes span several hosts.
    world_size = strategy.world_size

    for sync_grads in (True, False):
        tensor = torch.tensor(1.0, requires_grad=True)
        result = strategy.all_gather(tensor, sync_grads=sync_grads)
        summed = result.sum()
        assert summed.device.type == "xla"
        assert torch.equal(summed, torch.tensor(world_size, dtype=torch.float32))
        summed.backward()
        if sync_grads:
            assert torch.equal(tensor.grad, torch.tensor(1.0))
        else:
            # As gradients are not synced, the original tensor will not have gradients.
            assert tensor.grad is None
|
|
|
|
|
|
|
|
|
|
|
|
@RunIf(tpu=True)
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_tpu_all_gather():
    """Exercise ``XLAStrategy.all_gather`` across processes spawned on the TPU.

    ``os.environ`` is patched with a snapshot of itself so launch-time environment changes are undone afterwards.
    """
    xla_launch(fn=tpu_all_gather_fn)
|
2023-05-05 17:57:24 +00:00
|
|
|
|
|
|
|
|
|
|
|
def tpu_sync_module_states_fn(sync_module_states, strategy):
    """Check whether ``setup_module`` leaves the module's weights identical across processes.

    Args:
        sync_module_states: The value the strategy was built with. When ``True``, every gathered weight must
            match the first one. When ``False``, at least one mismatch is expected (``seed_everything()`` is
            called without an explicit seed, presumably yielding different initial weights per process —
            NOTE(review): confirm).
        strategy: The launched ``XLAStrategy``.
    """
    seed_everything()
    layer = torch.nn.Linear(1, 1).to(strategy.root_device)
    layer = strategy.setup_module(layer)
    all_weights = strategy.all_gather(layer.weight)
    reference = all_weights[0]

    def _assert_all_match():
        for other in all_weights:
            assert reference == other

    if not sync_module_states:
        with pytest.raises(AssertionError):
            _assert_all_match()
    else:
        _assert_all_match()
|
|
|
|
|
|
|
|
|
|
|
|
@RunIf(tpu=True)
@pytest.mark.parametrize("sync_module_states", [True, False])
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_tpu_sync_module_states(sync_module_states):
    """Launch ``tpu_sync_module_states_fn`` with an ``XLAStrategy`` built with the given ``sync_module_states``.

    The strategy is built here (instead of by ``xla_launch``) so that ``sync_module_states`` can be forwarded.
    """
    accelerator = XLAAccelerator()
    devices = XLAAccelerator.get_parallel_devices(XLAAccelerator.auto_device_count())
    strategy = XLAStrategy(
        accelerator=accelerator,
        parallel_devices=devices,
        sync_module_states=sync_module_states,
    )
    xla_launch(partial(tpu_sync_module_states_fn, sync_module_states), strategy)
|