# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest import mock
from unittest.mock import MagicMock, Mock

import pytest
import torch.nn
import torch.nn as nn
from torch.optim import Adam

from lightning.fabric.accelerators import XLAAccelerator
from lightning.fabric.strategies import XLAFSDPStrategy
from lightning.fabric.strategies.xla_fsdp import _activation_checkpointing_auto_wrapper, _XLAFSDPBackwardSyncControl
from tests_fabric.helpers.runif import RunIf


@RunIf(min_torch="2.0", tpu=True)
@pytest.mark.parametrize("torch_ge_2_0", [False, True])
def test_xla_fsdp_setup_optimizer_validation(torch_ge_2_0):
    """Check `setup_optimizer()` against two optimizers built from a plain, unwrapped module.

    The module-level `_TORCH_GREATER_EQUAL_2_0` flag is patched to the parametrized value. With the flag set to
    `True` both optimizers must be accepted; with `False` each one must be rejected with a `ValueError`.

    """
    linear = nn.Linear(2, 2)
    device_count = XLAAccelerator.auto_device_count()
    strategy = XLAFSDPStrategy(parallel_devices=XLAAccelerator.get_parallel_devices(device_count))

    version_flag = "lightning.fabric.strategies.xla_fsdp._TORCH_GREATER_EQUAL_2_0"
    with mock.patch(version_flag, torch_ge_2_0):
        # one optimizer with two param groups, one with a single group over the raw module parameters
        multi_group_optimizer = Adam([{"params": [linear.weight]}, {"params": [linear.bias], "lr": 1e-3}])
        raw_params_optimizer = Adam(linear.parameters())

        if not torch_ge_2_0:
            multi_group_error = "does not support multiple param groups"
            with pytest.raises(ValueError, match=multi_group_error):
                strategy.setup_optimizer(multi_group_optimizer)

            no_reference_error = "The optimizer does not seem to reference any XLAFSDP parameter"
            with pytest.raises(ValueError, match=no_reference_error):
                strategy.setup_optimizer(raw_params_optimizer)
        else:
            strategy.setup_optimizer(multi_group_optimizer)
            strategy.setup_optimizer(raw_params_optimizer)


@RunIf(min_torch="2.0", tpu=True)
def test_xla_fsdp_no_backward_sync():
    """Check that `no_backward_sync` rejects an arbitrary object and calls `.no_sync()` on an
    `XlaFullyShardedDataParallel`-spec'd module exactly once."""
    from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel

    strategy = XLAFSDPStrategy()
    assert isinstance(strategy._backward_sync_control, _XLAFSDPBackwardSyncControl)

    # anything that is not an `XlaFullyShardedDataParallel` instance must be refused
    not_wrapped_error = "is only possible if the module passed to .* is wrapped in `XlaFullyShardedDataParallel`"
    with pytest.raises(TypeError, match=not_wrapped_error):
        with strategy._backward_sync_control.no_backward_sync(object()):
            pass

    wrapped_module = MagicMock(spec=XlaFullyShardedDataParallel)
    with strategy._backward_sync_control.no_backward_sync(wrapped_module):
        pass

    wrapped_module.no_sync.assert_called_once()


@RunIf(min_torch="2.0", tpu=True)
def test_xla_fsdp_grad_clipping_value_error():
    """Check that `clip_gradients_value` raises `NotImplementedError`."""
    strategy = XLAFSDPStrategy()
    placeholder_args = (Mock(), Mock(), Mock())

    with pytest.raises(NotImplementedError, match="does not support to clip gradients by value"):
        strategy.clip_gradients_value(*placeholder_args)


@RunIf(min_torch="2.0", tpu=True)
def test_xla_fsdp_activation_checkpointing_setup():
    """Check that a user-provided `auto_wrapper_callable` is stored among the strategy's FSDP kwargs."""
    from torch_xla.distributed.fsdp import checkpoint_module
    from torch_xla.distributed.fsdp.xla_fully_sharded_data_parallel import XlaFullyShardedDataParallel

    def auto_wrapper_callable(m, *args, **kwargs):
        return XlaFullyShardedDataParallel(checkpoint_module(m), *args, **kwargs)

    strategy = XLAFSDPStrategy(auto_wrapper_callable=auto_wrapper_callable)

    stored_values = strategy._fsdp_kwargs.values()
    assert auto_wrapper_callable in stored_values


@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
def test_rank_properties_access(xla_available):
    """Check the rank and world-size properties before launch and once `_launched` is set."""
    strategy = XLAFSDPStrategy()
    strategy.cluster_environment = Mock()

    # nothing has been launched yet: fixed values are expected
    assert not strategy._launched
    expected_before_launch = {"global_rank": 0, "local_rank": 0, "node_rank": 0, "world_size": 1}
    for property_name, expected_value in expected_before_launch.items():
        assert getattr(strategy, property_name) == expected_value

    # pretend we are inside a worker process: values must match what the cluster environment returns
    strategy._launched = True
    for property_name in ("global_rank", "local_rank", "node_rank", "world_size"):
        assert getattr(strategy, property_name) == getattr(strategy.cluster_environment, property_name)()


def test_xla_fsdp_policy(xla_available):
    """Check how the constructor arguments are reflected in `_fsdp_kwargs`, and which combinations are rejected."""
    # unknown keyword arguments are stored unchanged
    passthrough_strategy = XLAFSDPStrategy(foo=1)
    assert passthrough_strategy._fsdp_kwargs == {"foo": 1}

    # a set of layer types passed as `auto_wrap_policy`
    # NOTE(review): `_mock_name` suggests `xla_available` replaces torch_xla with mocks - confirm in conftest
    auto_wrap_strategy = XLAFSDPStrategy(auto_wrap_policy={torch.nn.Linear})
    assert "auto_wrap_policy" in auto_wrap_strategy._fsdp_kwargs
    stored_policy = auto_wrap_strategy._fsdp_kwargs["auto_wrap_policy"]
    assert stored_policy.func._mock_name == "transformer_auto_wrap_policy"

    # a set of layer types passed as `activation_checkpointing_policy`
    checkpointing_strategy = XLAFSDPStrategy(activation_checkpointing_policy={torch.nn.Linear})
    assert "auto_wrapper_callable" in checkpointing_strategy._fsdp_kwargs
    stored_wrapper = checkpointing_strategy._fsdp_kwargs["auto_wrapper_callable"]
    assert stored_wrapper.func is _activation_checkpointing_auto_wrapper

    # invalid argument combinations and types
    with pytest.raises(ValueError, match="cannot set both"):
        XLAFSDPStrategy(activation_checkpointing_policy={torch.nn.Linear}, auto_wrapper_callable="foo")

    with pytest.raises(TypeError, match="must be a set"):
        XLAFSDPStrategy(activation_checkpointing_policy="foo")