# Source: lightning/tests/tests_fabric/strategies/test_fsdp.py (458 lines, 20 KiB, Python)

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datetime import timedelta
from re import escape
from unittest import mock
from unittest.mock import ANY, MagicMock, Mock
# (blame annotation: "ruff: replace isort with ruff +TPU (#17684)", 2023-09-26)
import lightning.fabric
import pytest
import torch
import torch.nn as nn
from lightning.fabric.plugins import HalfPrecision
from lightning.fabric.plugins.environments import LightningEnvironment
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.strategies.fsdp import (
_FSDPBackwardSyncControl,
_get_full_state_dict_context,
_has_meta_device_parameters,
_is_sharded_checkpoint,
)
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1, _TORCH_GREATER_EQUAL_2_2
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
# (blame annotation: "ruff: replace isort with ruff +TPU (#17684)", 2023-09-26)
from torch.optim import Adam
from tests_fabric.helpers.runif import RunIf
def test_fsdp_custom_mixed_precision():
    """Check that a user-provided `MixedPrecision` config is exposed as `mixed_precision_config`."""
    custom_config = MixedPrecision()
    fsdp = FSDPStrategy(mixed_precision=custom_config)
    assert fsdp.mixed_precision_config == custom_config
def test_fsdp_cpu_offload():
    """Check both accepted forms of the `cpu_offload` argument: a bool and a `CPUOffload` dataclass."""
    # bool: `True` is expected to compare equal to an explicit parameter-offload config
    from_bool = FSDPStrategy(cpu_offload=True)
    assert from_bool.cpu_offload == CPUOffload(offload_params=True)

    # dataclass: the given config is expected to be returned as-is (by equality)
    offload_config = CPUOffload()
    from_dataclass = FSDPStrategy(cpu_offload=offload_config)
    assert from_dataclass.cpu_offload == offload_config
def test_fsdp_sharding_strategy():
    """Check the default sharding strategy and that it can be set via enum member or string name."""
    from torch.distributed.fsdp import ShardingStrategy

    # default
    assert FSDPStrategy().sharding_strategy == ShardingStrategy.FULL_SHARD

    # enum
    from_enum = FSDPStrategy(sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)
    assert from_enum.sharding_strategy == ShardingStrategy.SHARD_GRAD_OP

    # string: both upper- and lower-case spellings must resolve to the same member
    for spelling in ("NO_SHARD", "no_shard"):
        from_string = FSDPStrategy(sharding_strategy=spelling)
        assert from_string.sharding_strategy == ShardingStrategy.NO_SHARD
@RunIf(min_torch="2.0")
@pytest.mark.parametrize("sharding_strategy", ["HYBRID_SHARD", "_HYBRID_SHARD_ZERO2"])
def test_fsdp_hybrid_shard_configuration(sharding_strategy):
    """Check the extra arguments hybrid sharding needs: an auto-wrap policy, a process group, or a device mesh.

    Passing none of them must raise, and passing both a process group and a device mesh must raise as well.

    """
    # none of the required companions given
    missing_msg = "The hybrid sharding strategy requires you to pass at least one of"
    with pytest.raises(RuntimeError, match=missing_msg):
        FSDPStrategy(sharding_strategy=sharding_strategy)

    # automatic wrapping
    with_policy = FSDPStrategy(auto_wrap_policy={nn.Linear}, sharding_strategy=sharding_strategy)
    assert with_policy.sharding_strategy.name == sharding_strategy

    # manually specified process group (a pair of mocks stands in for the real groups)
    process_group = (Mock(), Mock())
    with_pg = FSDPStrategy(sharding_strategy=sharding_strategy, process_group=process_group)
    assert with_pg.sharding_strategy.name == sharding_strategy
    assert with_pg._fsdp_kwargs["process_group"] is process_group

    # manually specified device mesh
    device_mesh = Mock()
    with_mesh = FSDPStrategy(sharding_strategy=sharding_strategy, device_mesh=device_mesh)
    assert with_mesh.sharding_strategy.name == sharding_strategy
    assert with_mesh._fsdp_kwargs["device_mesh"] is device_mesh

    # process group and device mesh together are rejected
    exclusive_msg = "process_group.* device_mesh=.* are mutually exclusive"
    with pytest.raises(ValueError, match=exclusive_msg):
        FSDPStrategy(sharding_strategy=sharding_strategy, process_group=process_group, device_mesh=device_mesh)
def test_fsdp_checkpoint_io_unsupported():
    """Check that both reading and assigning `checkpoint_io` on the FSDP strategy raise `NotImplementedError`."""
    fsdp = FSDPStrategy()

    # getter
    with pytest.raises(NotImplementedError, match="does not use the `CheckpointIO` plugin"):
        _ = fsdp.checkpoint_io

    # setter
    with pytest.raises(NotImplementedError, match="does not support setting a `CheckpointIO` plugin"):
        fsdp.checkpoint_io = Mock()
@pytest.mark.parametrize("torch_ge_2_0", [False, True])
def test_fsdp_setup_optimizer_validation(torch_ge_2_0):
    """Test that `setup_optimizer()` validates the param groups and reference to FSDP parameters.

    The module-level `_TORCH_GREATER_EQUAL_2_0` flag is patched to both values: with the flag on, an optimizer
    over non-FSDP parameters is accepted; with the flag off, it must be rejected with a `ValueError`.

    """
    module = nn.Linear(2, 2)
    with mock.patch("lightning.fabric.strategies.fsdp._TORCH_GREATER_EQUAL_2_0", torch_ge_2_0):
        strategy = FSDPStrategy(parallel_devices=[torch.device("cpu")])
        # the optimizer references plain (non-FSDP) parameters
        bad_optimizer = Adam(module.parameters())

        if torch_ge_2_0:
            strategy.setup_optimizer(bad_optimizer)
        else:
            with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameter"):
                strategy.setup_optimizer(bad_optimizer)
@RunIf(min_torch="2.0.0")
@mock.patch("lightning.fabric.strategies.fsdp.FSDPStrategy.setup_module")
def test_fsdp_setup_use_orig_params(_):
    """Check that joint module/optimizer setup rejects `use_orig_params=False` and works with the default.

    `setup_module` is patched out, so no real FSDP wrapping happens here.

    """
    cpu_devices = [torch.device("cpu")]
    linear = nn.Linear(2, 2)
    adam = Adam(linear.parameters())

    # explicit opt-out: stored in the kwargs, but joint setup must fail
    opted_out = FSDPStrategy(parallel_devices=cpu_devices, use_orig_params=False)
    assert not opted_out._fsdp_kwargs["use_orig_params"]
    with pytest.raises(ValueError, match=r"`FSDPStrategy\(use_orig_params=False\)` but this is not supported"):
        opted_out.setup_module_and_optimizers(linear, adam)

    # default: enabled before and after the joint setup
    default = FSDPStrategy(parallel_devices=cpu_devices)
    assert default._fsdp_kwargs["use_orig_params"]
    default.setup_module_and_optimizers(linear, adam)
    assert default._fsdp_kwargs["use_orig_params"]
def test_fsdp_no_backward_sync():
    """Test that the backward sync control calls `.no_sync()`, and only on a module wrapped in
    FullyShardedDataParallel."""
    strategy = FSDPStrategy()
    assert isinstance(strategy._backward_sync_control, _FSDPBackwardSyncControl)

    # a plain `Mock()` is not an FSDP-wrapped module, so entering the context must raise
    with pytest.raises(
        TypeError, match="is only possible if the module passed to .* is wrapped in `FullyShardedDataParallel`"
    ), strategy._backward_sync_control.no_backward_sync(Mock()):
        pass

    # a mock with the FSDP spec passes the type check; its `no_sync()` must be called exactly once
    module = MagicMock(spec=FullyShardedDataParallel)
    with strategy._backward_sync_control.no_backward_sync(module):
        pass
    module.no_sync.assert_called_once()
def test_fsdp_activation_checkpointing_support(monkeypatch):
    """Check that `activation_checkpointing_policy` is rejected when the torch >= 2.1 flag is off."""
    # pretend the installed PyTorch is older than 2.1
    monkeypatch.setattr(lightning.fabric.strategies.fsdp, "_TORCH_GREATER_EQUAL_2_1", False)
    expected_msg = "activation_checkpointing_policy` requires torch >= 2.1.0"
    with pytest.raises(ValueError, match=expected_msg):
        FSDPStrategy(activation_checkpointing_policy=Mock())
def test_fsdp_activation_checkpointing():
    """Check how activation-checkpointing arguments are translated and then applied in `setup_module()`.

    With torch >= 2.1 the strategy is expected to store an `auto_wrap_policy` (a `ModuleWrapPolicy`); otherwise
    it is expected to store a `check_fn`. In both cases `apply_activation_checkpointing` must receive the
    stored kwargs.

    """

    class Block1(nn.Linear):
        pass

    class Block2(nn.Linear):
        pass

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.layer0 = nn.Sequential(Block1(4, 4), Block1(5, 5))
            self.layer1 = Block2(2, 2)
            self.layer2 = nn.Linear(3, 3)

    if _TORCH_GREATER_EQUAL_2_1:
        from torch.distributed.fsdp.wrap import ModuleWrapPolicy

        # a set of layer classes, then a ready-made policy object
        for policy in ({Block1}, ModuleWrapPolicy({Block1, Block2})):
            strategy = FSDPStrategy(activation_checkpointing_policy=policy)
            stored_kwargs = strategy._activation_checkpointing_kwargs
            assert set(stored_kwargs) == {"auto_wrap_policy"}
            assert isinstance(stored_kwargs["auto_wrap_policy"], ModuleWrapPolicy)
    else:
        # every spelling must end up as a `check_fn`
        legacy_arguments = (
            {"activation_checkpointing": Block1},
            {"activation_checkpointing": [Block1, Block2]},
            {"activation_checkpointing_policy": {Block1}},
            {"activation_checkpointing_policy": {Block1, Block2}},
        )
        for init_kwargs in legacy_arguments:
            strategy = FSDPStrategy(**init_kwargs)
            assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}

    # the last strategy built above is reused for the setup call
    strategy._parallel_devices = [torch.device("cuda", 0)]
    fsdp_patch = mock.patch("torch.distributed.fsdp.FullyShardedDataParallel", new=MagicMock)
    apply_patch = mock.patch(
        "torch.distributed.algorithms._checkpoint.checkpoint_wrapper.apply_activation_checkpointing"
    )
    with fsdp_patch, apply_patch as apply_mock:
        wrapped = strategy.setup_module(Model())
    apply_mock.assert_called_with(wrapped, checkpoint_wrapper_fn=ANY, **strategy._activation_checkpointing_kwargs)
def test_fsdp_forbidden_precision_raises():
    """Check that `HalfPrecision` is rejected both in the constructor and when assigned to `precision`."""
    expected_msg = "can only work with the `FSDPPrecision"

    # via the constructor
    with pytest.raises(TypeError, match=expected_msg):
        FSDPStrategy(precision=HalfPrecision())

    # via the property setter
    fsdp = FSDPStrategy()
    with pytest.raises(TypeError, match=expected_msg):
        fsdp.precision = HalfPrecision()
def test_fsdp_grad_clipping_norm_error():
    """Check that norm-based gradient clipping raises `TypeError` when the module is not FSDP-wrapped."""
    fsdp = FSDPStrategy()
    expected_msg = "only possible if the module.*is wrapped in `FullyShardedDataParallel`"
    with pytest.raises(TypeError, match=expected_msg):
        fsdp.clip_gradients_norm(Mock(), Mock(), Mock())
@RunIf(min_torch="2.0.0")
def test_fsdp_save_checkpoint_storage_options(tmp_path):
    """Check that `save_checkpoint()` raises `TypeError` when `storage_options` is passed."""
    fsdp = FSDPStrategy()
    # the message contains regex metacharacters, hence `escape`
    expected_msg = escape("FSDPStrategy.save_checkpoint(..., storage_options=...)` is not")
    with pytest.raises(TypeError, match=expected_msg):
        fsdp.save_checkpoint(path=tmp_path, state=Mock(), storage_options=Mock())
@RunIf(min_torch="2.0.0")
@mock.patch("lightning.fabric.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)
@mock.patch("lightning.fabric.strategies.fsdp._get_full_state_dict_context")
@mock.patch("lightning.fabric.strategies.fsdp._get_sharded_state_dict_context")
@mock.patch("lightning.fabric.strategies.fsdp.torch.save")
@mock.patch("lightning.fabric.strategies.fsdp.shutil")
def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___, tmp_path):
    """Check how `save_checkpoint()` treats a target path that already exists.

    Covers both `state_dict_type` settings ("full" and "sharded") against existing folders and files. The
    broadcast, state-dict contexts, `torch.save` and `shutil` are all mocked.

    """

    def make_fsdp_model():
        # a mock that passes for an FSDP module and reports itself as its only submodule
        fsdp_model = Mock(spec=FullyShardedDataParallel)
        fsdp_model.modules.return_value = [fsdp_model]
        return fsdp_model

    strategy = FSDPStrategy(state_dict_type="full")

    # state_dict_type='full', path exists, path is not a sharded checkpoint: error
    plain_dir = tmp_path / "not-empty"
    plain_dir.mkdir()
    (plain_dir / "file").touch()
    assert not _is_sharded_checkpoint(plain_dir)
    with pytest.raises(IsADirectoryError, match="exists and is a directory"):
        strategy.save_checkpoint(path=plain_dir, state=Mock())

    # state_dict_type='full', path exists, path is a sharded checkpoint: no error (overwrite)
    sharded_dir = tmp_path / "sharded-checkpoint"
    sharded_dir.mkdir()
    (sharded_dir / "meta.pt").touch()
    assert _is_sharded_checkpoint(sharded_dir)
    strategy.save_checkpoint(path=sharded_dir, state={"model": make_fsdp_model()})
    shutil_mock.rmtree.assert_called_once_with(sharded_dir)

    # state_dict_type='full', path exists, path is a file: no error (overwrite)
    existing_file = tmp_path / "file.pt"
    existing_file.touch()
    model = make_fsdp_model()
    torch_save_mock.reset_mock()
    strategy.save_checkpoint(path=existing_file, state={"model": model})
    torch_save_mock.assert_called_once()

    strategy = FSDPStrategy(state_dict_type="sharded")
    # the name of the distributed save function depends on the PyTorch version
    if _TORCH_GREATER_EQUAL_2_2:
        save_target = "torch.distributed.checkpoint.save"
    else:
        save_target = "torch.distributed.checkpoint.save_state_dict"
    save_mock = mock.patch(save_target)

    # state_dict_type='sharded', path exists, path is a folder: no error (overwrite)
    plain_dir_2 = tmp_path / "not-empty-2"
    plain_dir_2.mkdir()
    (plain_dir_2 / "file").touch()
    model = make_fsdp_model()
    with save_mock:
        strategy.save_checkpoint(path=plain_dir_2, state={"model": model})
    assert (plain_dir_2 / "file").exists()

    # state_dict_type='sharded', path exists, path is a file: no error (overwrite)
    existing_file_2 = tmp_path / "file-2.pt"
    existing_file_2.touch()
    model = make_fsdp_model()
    with save_mock:
        strategy.save_checkpoint(path=existing_file_2, state={"model": model})
    assert existing_file_2.is_dir()
@RunIf(min_torch="2.0.0")
@mock.patch("lightning.fabric.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)
def test_fsdp_save_checkpoint_one_fsdp_module_required(tmp_path):
    """Test that the FSDP strategy can only save one FSDP model per checkpoint.

    Saving must fail when the state holds no FSDP model (empty state, or only a plain `nn.Module`) and when it
    holds more than one FSDP model.

    """
    strategy = FSDPStrategy()

    # missing FSDP model
    with pytest.raises(ValueError, match="Could not find a FSDP model in the provided checkpoint state."):
        strategy.save_checkpoint(path=tmp_path, state={})
    # a plain module does not count as an FSDP model; this exercises the *save* path, matching the test's purpose
    with pytest.raises(ValueError, match="Could not find a FSDP model in the provided checkpoint state."):
        strategy.save_checkpoint(path=tmp_path, state={"model": torch.nn.Linear(3, 3)})

    # multiple FSDP models
    model1 = Mock(spec=FullyShardedDataParallel)
    model1.modules.return_value = [model1]
    model2 = Mock(spec=FullyShardedDataParallel)
    model2.modules.return_value = [model2]
    with pytest.raises(ValueError, match="Found multiple FSDP models in the given state."):
        strategy.save_checkpoint(path=tmp_path, state={"model1": model1, "model2": model2})
@RunIf(min_torch="2.0.0")
def test_fsdp_load_checkpoint_no_state(tmp_path):
    """Check that `load_checkpoint()` raises `ValueError` when `state` is `None` or an empty dict."""
    fsdp = FSDPStrategy()

    # state=None
    none_msg = escape("Got FSDPStrategy.load_checkpoint(..., state=None")
    with pytest.raises(ValueError, match=none_msg):
        fsdp.load_checkpoint(path=tmp_path, state=None)

    # state={}
    empty_msg = escape("Got FSDPStrategy.load_checkpoint(..., state={})")
    with pytest.raises(ValueError, match=empty_msg):
        fsdp.load_checkpoint(path=tmp_path, state={})
@RunIf(min_torch="2.0.0")
@mock.patch("lightning.fabric.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)
@mock.patch("lightning.fabric.strategies.fsdp._lazy_load", Mock())
def test_fsdp_load_checkpoint_one_fsdp_module_required(tmp_path):
    """Check that loading needs exactly one FSDP model in a state dict, while a bare `nn.Module` is accepted."""
    fsdp = FSDPStrategy()

    # missing FSDP model
    not_found_msg = "Could not find a FSDP model in the provided checkpoint state."
    with pytest.raises(ValueError, match=not_found_msg):
        fsdp.load_checkpoint(path=tmp_path, state={"other": "data"})
    with pytest.raises(ValueError, match=not_found_msg):
        fsdp.load_checkpoint(path=tmp_path, state={"model": torch.nn.Linear(3, 3)})

    # multiple FSDP models
    first = Mock(spec=FullyShardedDataParallel)
    first.modules.return_value = [first]
    second = Mock(spec=FullyShardedDataParallel)
    second.modules.return_value = [second]
    with pytest.raises(ValueError, match="Found multiple FSDP models in the given state."):
        fsdp.load_checkpoint(path=tmp_path, state={"model1": first, "model2": second})

    # A raw nn.Module instead of a dictionary is ok
    raw_module = Mock(spec=nn.Module)
    checkpoint_file = tmp_path / "full.ckpt"
    checkpoint_file.touch()
    fsdp.load_checkpoint(path=checkpoint_file, state=raw_module)
@RunIf(min_torch="2.0.0")
@mock.patch("lightning.fabric.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)
def test_fsdp_save_checkpoint_unknown_state_dict_type(tmp_path):
    """Check that saving with an unrecognized `state_dict_type` raises `ValueError`."""
    fsdp = FSDPStrategy(state_dict_type="invalid")
    # a mock that passes for an FSDP module and reports itself as its only submodule
    fsdp_model = Mock(spec=FullyShardedDataParallel)
    fsdp_model.modules.return_value = [fsdp_model]
    with pytest.raises(ValueError, match="Unknown state_dict_type"):
        fsdp.save_checkpoint(path=tmp_path, state={"model": fsdp_model})
@RunIf(min_torch="2.0.0")
def test_fsdp_load_unknown_checkpoint_type(tmp_path):
    """Check that loading from an empty directory is rejected as an invalid checkpoint."""
    fsdp = FSDPStrategy()
    fsdp_model = Mock(spec=FullyShardedDataParallel)
    fsdp_model.modules.return_value = [fsdp_model]

    # neither a single file nor a directory with meta file
    empty_dir = tmp_path / "empty_dir"
    empty_dir.mkdir()

    with pytest.raises(ValueError, match="does not point to a valid checkpoint"):
        fsdp.load_checkpoint(path=empty_dir, state={"model": fsdp_model})
@RunIf(min_torch="2.0.0")
def test_fsdp_load_raw_checkpoint_validate_single_file(tmp_path):
    """Check that loading into a bare `nn.Module` raises `ValueError` when the path is a directory."""
    fsdp = FSDPStrategy()
    raw_module = Mock(spec=nn.Module)
    folder = tmp_path / "folder"
    folder.mkdir()
    expected_msg = "The given path must be a single file containing the full state dict"
    with pytest.raises(ValueError, match=expected_msg):
        fsdp.load_checkpoint(path=folder, state=raw_module)
@RunIf(min_torch="2.0.0")
def test_fsdp_load_raw_checkpoint_optimizer_unsupported(tmp_path):
    """Check that passing a bare optimizer as `state` to `load_checkpoint()` raises `NotImplementedError`."""
    fsdp = FSDPStrategy()
    bare_optimizer = Mock(spec=torch.optim.Optimizer)
    expected_msg = "Loading a single optimizer object from a checkpoint is not supported"
    with pytest.raises(NotImplementedError, match=expected_msg):
        fsdp.load_checkpoint(path=tmp_path, state=bare_optimizer)
@mock.patch("torch.distributed.init_process_group")
def test_set_timeout(init_process_group_mock):
    """Check that the strategy's `timeout` is forwarded to ``torch.distributed.init_process_group``."""
    thirty_seconds = timedelta(seconds=30)
    fsdp = FSDPStrategy(timeout=thirty_seconds, parallel_devices=[torch.device("cpu")])
    fsdp.cluster_environment = LightningEnvironment()
    fsdp.accelerator = Mock()

    fsdp.setup_environment()

    # read the expected call arguments back from the strategy and its environment
    expected_backend = fsdp._get_process_group_backend()
    expected_rank = fsdp.cluster_environment.global_rank()
    expected_world_size = fsdp.cluster_environment.world_size()
    init_process_group_mock.assert_called_with(
        expected_backend, rank=expected_rank, world_size=expected_world_size, timeout=thirty_seconds
    )
def test_has_meta_device_parameters():
    """Check `_has_meta_device_parameters` on modules, optimizers, and an unsupported input type."""
    regular_linear = nn.Linear(2, 2)
    meta_linear = nn.Linear(2, 2, device="meta")

    # nn.Module: a container with one meta-device submodule counts as well
    mixed_container = nn.Sequential(regular_linear, meta_linear, nn.ReLU())
    assert not _has_meta_device_parameters(regular_linear)
    assert _has_meta_device_parameters(meta_linear)
    assert _has_meta_device_parameters(mixed_container)

    # optim.Optimizer
    regular_sgd = torch.optim.SGD(regular_linear.parameters(), lr=0.1)
    meta_sgd = torch.optim.SGD(meta_linear.parameters(), lr=0.1)
    assert not _has_meta_device_parameters(regular_sgd)
    assert _has_meta_device_parameters(meta_sgd)

    # unsupported objects
    with pytest.raises(TypeError, match="Expected `torch.nn.Module` or `torch.optim.Optimizer`"):
        _has_meta_device_parameters(None)
@RunIf(min_torch="2.0")
@pytest.mark.parametrize("torch_ge_2_1", [True, False])
@mock.patch("torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel.set_state_dict_type")
def test_get_full_state_dict_context_offload(set_type_mock, monkeypatch, torch_ge_2_1):
    """Check the `offload_to_cpu` flag the full-state-dict context passes to `set_state_dict_type`.

    With `world_size=1` the flag is expected to follow the patched torch >= 2.1 flag; with `world_size=4` it is
    expected to be truthy regardless.

    """
    monkeypatch.setattr("lightning.fabric.strategies.fsdp._TORCH_GREATER_EQUAL_2_1", torch_ge_2_1)

    # single process
    with _get_full_state_dict_context(module=Mock(spec=FullyShardedDataParallel), world_size=1):
        positional_args = set_type_mock.call_args_list[0][0]
        assert positional_args[2].offload_to_cpu is torch_ge_2_1  # model config
        assert positional_args[3].offload_to_cpu is torch_ge_2_1  # optim config

    # forget the first call so that index 0 refers to the next one
    set_type_mock.reset_mock()

    # multiple processes
    with _get_full_state_dict_context(module=Mock(spec=FullyShardedDataParallel), world_size=4):
        positional_args = set_type_mock.call_args_list[0][0]
        assert positional_args[2].offload_to_cpu  # model config
        assert positional_args[3].offload_to_cpu  # optim config