Add testing for PyTorch 2.4 (Fabric) (#20028)

This commit is contained in:
awaelchli 2024-07-03 00:01:03 +02:00 committed by GitHub
parent 37e04d075a
commit 693c21ac1b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 84 additions and 104 deletions

View File

@ -62,6 +62,9 @@ jobs:
"Lightning | latest":
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.11-torch2.3-cuda12.1.0"
PACKAGE_NAME: "lightning"
"Lightning | future":
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.11-torch2.4-cuda12.1.0"
PACKAGE_NAME: "lightning"
workspace:
clean: all
steps:
@ -72,9 +75,12 @@ jobs:
echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/cu${cuda_ver}/torch_stable.html"
scope=$(python -c 'n = "$(PACKAGE_NAME)" ; print(dict(fabric="lightning_fabric").get(n, n))')
echo "##vso[task.setvariable variable=COVERAGE_SOURCE]$scope"
python_ver=$(python -c "import sys; print(f'{sys.version_info.major}{sys.version_info.minor}')")
echo "##vso[task.setvariable variable=PYTHON_VERSION_MM]$python_ver"
displayName: "set env. vars"
- bash: |
echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM}/torch_test.html"
echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM}"
echo "##vso[task.setvariable variable=TORCHVISION_URL]https://download.pytorch.org/whl/test/cu124/torchvision-0.19.0%2Bcu124-cp${PYTHON_VERSION_MM}-cp${PYTHON_VERSION_MM}-linux_x86_64.whl"
condition: endsWith(variables['Agent.JobName'], 'future')
displayName: "set env. vars 4 future"
@ -103,7 +109,7 @@ jobs:
- bash: |
extra=$(python -c "print({'lightning': 'fabric-'}.get('$(PACKAGE_NAME)', ''))")
pip install -e ".[${extra}dev]" pytest-timeout -U --find-links="${TORCH_URL}"
pip install -e ".[${extra}dev]" pytest-timeout -U --find-links="${TORCH_URL}" --find-links="${TORCHVISION_URL}"
displayName: "Install package & dependencies"
- bash: |

View File

@ -49,6 +49,10 @@ jobs:
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.3" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.4" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.4" }
# TODO: PyTorch 2.4 on Windows not yet working with `torch.distributed` (not compiled with libuv support)
# - { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.4" }
# only run PyTorch latest with Python latest, use Fabric scope to limit dependency issues
- { os: "macOS-12", pkg-name: "fabric", python-version: "3.11", pytorch-version: "2.1" }
- { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.11", pytorch-version: "2.1" }
@ -79,7 +83,7 @@ jobs:
FREEZE_REQUIREMENTS: ${{ ! (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }}
PYPI_CACHE_DIR: "_pip-wheels"
TORCH_URL_STABLE: "https://download.pytorch.org/whl/cpu/torch_stable.html"
TORCH_URL_TEST: "https://download.pytorch.org/whl/test/cpu/torch_test.html"
TORCH_URL_TEST: "https://download.pytorch.org/whl/test/cpu/torch"
# TODO: Remove this - Enable running MPS tests on this platform
DISABLE_MPS: ${{ matrix.os == 'macOS-14' && '1' || '0' }}
steps:
@ -118,7 +122,7 @@ jobs:
- name: Env. variables
run: |
# Switch PyTorch URL
python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.pytorch-version }}' == '2.3' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV
python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.pytorch-version }}' == '2.4' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV
# Switch coverage scope
python -c "print('COVERAGE_SCOPE=' + str('lightning' if '${{matrix.pkg-name}}' == 'lightning' else 'lightning_fabric'))" >> $GITHUB_ENV
# if you install mono-package set dependency only for this subpackage

View File

@ -2,7 +2,7 @@
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
numpy >=1.21.0, <1.27.0
torch >=2.1.0, <2.4.0
torch >=2.1.0, <2.5.0
fsspec[http] >=2022.5.0, <2024.4.0
packaging >=20.0, <=23.1
typing-extensions >=4.4.0, <4.10.0

View File

@ -1,6 +1,6 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
torchvision >=0.16.0, <0.19.0
torchvision >=0.16.0, <0.20.0
torchmetrics >=0.10.0, <1.3.0
lightning-utilities >=0.8.0, <0.12.0

View File

@ -22,6 +22,7 @@ from typing_extensions import override
from lightning.fabric.plugins.precision.precision import Precision
from lightning.fabric.plugins.precision.utils import _convert_fp_tensor
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.types import Optimizable
@ -39,7 +40,7 @@ class MixedPrecision(Precision):
self,
precision: Literal["16-mixed", "bf16-mixed"],
device: str,
scaler: Optional[torch.cuda.amp.GradScaler] = None,
scaler: Optional["torch.cuda.amp.GradScaler"] = None,
) -> None:
if precision not in ("16-mixed", "bf16-mixed"):
raise ValueError(
@ -49,7 +50,7 @@ class MixedPrecision(Precision):
self.precision = precision
if scaler is None and self.precision == "16-mixed":
scaler = torch.cuda.amp.GradScaler()
scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler()
if scaler is not None and self.precision == "bf16-mixed":
raise ValueError(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
self.device = device

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import warnings
from contextlib import ExitStack, nullcontext
from datetime import timedelta
from functools import partial
@ -83,6 +84,9 @@ if TYPE_CHECKING:
_FSDP_ALIASES = ("fsdp", "fsdp_cpu_offload")
# TODO: Switch to new state-dict APIs
warnings.filterwarnings("ignore", category=FutureWarning, message=".*FSDP.state_dict_type.*") # from torch >= 2.4
class FSDPStrategy(ParallelStrategy, _Sharded):
r"""Strategy for Fully Sharded Data Parallel provided by torch.distributed.

View File

@ -70,7 +70,7 @@ class ModelParallelStrategy(ParallelStrategy):
Currently supports up to 2D parallelism. Specifically, it supports the combination of
Fully Sharded Data-Parallel 2 (FSDP2) with Tensor Parallelism (DTensor). These PyTorch APIs are currently still
experimental in PyTorch. Requires PyTorch 2.3 or newer.
experimental in PyTorch. Requires PyTorch 2.4 or newer.
Arguments:
parallelize_fn: A function that applies parallelisms to a module. The strategy will provide the
@ -95,8 +95,8 @@ class ModelParallelStrategy(ParallelStrategy):
timeout: Optional[timedelta] = default_pg_timeout,
) -> None:
super().__init__()
if not _TORCH_GREATER_EQUAL_2_3:
raise ImportError(f"{type(self).__name__} requires PyTorch 2.3 or higher.")
if not _TORCH_GREATER_EQUAL_2_4:
raise ImportError(f"{type(self).__name__} requires PyTorch 2.4 or higher.")
self._parallelize_fn = parallelize_fn
self._data_parallel_size = data_parallel_size
self._tensor_parallel_size = tensor_parallel_size
@ -178,7 +178,7 @@ class ModelParallelStrategy(ParallelStrategy):
if any(isinstance(mod, FullyShardedDataParallel) for mod in module.modules()):
raise TypeError(
"Found modules that are wrapped with `torch.distributed.fsdp.FullyShardedDataParallel`."
f" The `{self.__class__.__name__}` only supports the new FSDP2 APIs in PyTorch >= 2.3."
f" The `{self.__class__.__name__}` only supports the new FSDP2 APIs in PyTorch >= 2.4."
)
module = self._parallelize_fn(module, self.device_mesh)
@ -329,10 +329,10 @@ class _FSDPNoSync(ContextManager):
self._enabled = enabled
def _set_requires_grad_sync(self, requires_grad_sync: bool) -> None:
from torch.distributed._composable.fsdp import FSDP
from torch.distributed._composable.fsdp import FSDPModule
for mod in self._module.modules():
if isinstance(mod, FSDP):
if isinstance(mod, FSDPModule):
mod.set_requires_gradient_sync(requires_grad_sync, recurse=False)
def __enter__(self) -> None:
@ -458,9 +458,6 @@ def _load_checkpoint(
return metadata
if _is_full_checkpoint(path):
if not _TORCH_GREATER_EQUAL_2_4:
raise ImportError("Loading a non-distributed checkpoint into a distributed model requires PyTorch >= 2.4.")
checkpoint = torch.load(path, mmap=True, map_location="cpu")
_load_raw_module_state(checkpoint.pop(module_key), module, strict=strict)
@ -546,9 +543,6 @@ def _load_raw_module_state(
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
if _has_dtensor_modules(module):
if not _TORCH_GREATER_EQUAL_2_4:
raise ImportError("Loading a non-distributed checkpoint into a distributed model requires PyTorch >= 2.4.")
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
state_dict_options = StateDictOptions(

View File

@ -28,7 +28,7 @@ _IS_INTERACTIVE = hasattr(sys, "ps1") or bool(sys.flags.interactive)
_TORCH_GREATER_EQUAL_2_2 = compare_version("torch", operator.ge, "2.2.0")
_TORCH_GREATER_EQUAL_2_3 = compare_version("torch", operator.ge, "2.3.0")
_TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0", use_base_version=True)
_TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0")
_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)

View File

@ -24,6 +24,7 @@ from lightning.fabric.accelerators import XLAAccelerator
from lightning.fabric.accelerators.cuda import num_cuda_devices
from lightning.fabric.accelerators.mps import MPSAccelerator
from lightning.fabric.strategies.deepspeed import _DEEPSPEED_AVAILABLE
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
def _runif_reasons(
@ -111,7 +112,9 @@ def _runif_reasons(
reasons.append("Standalone execution")
kwargs["standalone"] = True
if deepspeed and not (_DEEPSPEED_AVAILABLE and RequirementCache(module="deepspeed.utils")):
if deepspeed and not (
_DEEPSPEED_AVAILABLE and not _TORCH_GREATER_EQUAL_2_4 and RequirementCache(module="deepspeed.utils")
):
reasons.append("Deepspeed")
if dynamo:

View File

@ -104,6 +104,7 @@ def thread_police_duuu_daaa_duuu_daaa():
elif (
thread.name == "QueueFeederThread" # tensorboardX
or thread.name == "QueueManagerThread" # torch.compile
or "(_read_thread)" in thread.name # torch.compile
):
thread.join(timeout=20)
elif (

View File

@ -17,11 +17,13 @@ from unittest.mock import Mock
import pytest
import torch
from lightning.fabric.plugins.precision.amp import MixedPrecision
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
def test_amp_precision_default_scaler():
precision = MixedPrecision(precision="16-mixed", device=Mock())
assert isinstance(precision.scaler, torch.cuda.amp.GradScaler)
scaler_cls = torch.amp.GradScaler if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler
assert isinstance(precision.scaler, scaler_cls)
def test_amp_precision_scaler_with_bf16():
@ -36,7 +38,8 @@ def test_amp_precision_forward_context():
"""Test to ensure that the context manager correctly is set to bfloat16 on CPU and CUDA."""
precision = MixedPrecision(precision="16-mixed", device="cuda")
assert precision.device == "cuda"
assert isinstance(precision.scaler, torch.cuda.amp.GradScaler)
scaler_cls = torch.amp.GradScaler if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler
assert isinstance(precision.scaler, scaler_cls)
assert torch.get_default_dtype() == torch.float32
with precision.forward_context():
assert torch.get_autocast_gpu_dtype() == torch.float16

View File

@ -17,6 +17,7 @@ import pytest
import torch
import torch.nn as nn
from lightning.fabric import Fabric, seed_everything
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from tests_fabric.helpers.runif import RunIf
@ -82,7 +83,8 @@ def test_amp_fused_optimizer_parity():
optimizer = torch.optim.Adam(model.parameters(), lr=1.0, fused=fused)
model, optimizer = fabric.setup(model, optimizer)
assert isinstance(fabric._precision.scaler, torch.cuda.amp.GradScaler)
scaler_cls = torch.amp.GradScaler if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler
assert isinstance(fabric._precision.scaler, scaler_cls)
data = torch.randn(10, 10, device="cuda")
target = torch.randn(10, 10, device="cuda")

View File

@ -93,7 +93,7 @@ def test_bitsandbytes_plugin(monkeypatch):
precision.convert_module(model)
@RunIf(min_cuda_gpus=1)
@RunIf(min_cuda_gpus=1, max_torch="2.4")
@pytest.mark.skipif(not _BITSANDBYTES_AVAILABLE, reason="bitsandbytes unavailable")
@pytest.mark.parametrize(
("args", "expected"),
@ -232,7 +232,7 @@ def test_bitsandbytes_layers_meta_device(args, expected, tmp_path):
assert model.l.weight.dtype == expected
@RunIf(min_cuda_gpus=1)
@RunIf(min_cuda_gpus=1, max_torch="2.4")
@pytest.mark.skipif(not _BITSANDBYTES_AVAILABLE, reason="bitsandbytes unavailable")
def test_load_quantized_checkpoint(tmp_path):
"""Test that a checkpoint saved from a quantized model can be loaded back into a quantized model."""

View File

@ -74,6 +74,7 @@ def test_dp_module_state_dict():
assert strategy.get_module_state_dict(wrapped_module).keys() == original_module.state_dict().keys()
@pytest.mark.filterwarnings("ignore::FutureWarning")
@pytest.mark.parametrize(
"precision",
[

View File

@ -118,7 +118,7 @@ class _TrainerManualWrapping(_Trainer):
return model
@RunIf(min_cuda_gpus=2, standalone=True)
@RunIf(min_cuda_gpus=2, standalone=True, max_torch="2.4")
@pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
@pytest.mark.parametrize("manual_wrapping", [True, False])
def test_train_save_load(tmp_path, manual_wrapping, precision):
@ -173,6 +173,7 @@ def test_train_save_load(tmp_path, manual_wrapping, precision):
assert state["coconut"] == 11
@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_cuda_gpus=2, standalone=True)
def test_save_full_state_dict(tmp_path):
"""Test that FSDP saves the full state into a single file with `state_dict_type="full"`."""
@ -287,6 +288,7 @@ def test_save_full_state_dict(tmp_path):
trainer.run()
@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_cuda_gpus=2, standalone=True)
def test_load_full_state_dict_into_sharded_model(tmp_path):
"""Test that the strategy can load a full-state checkpoint into a FSDP sharded model."""
@ -469,6 +471,7 @@ def test_module_init_context(precision, expected_dtype):
_run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))
@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_cuda_gpus=2, standalone=True)
def test_save_filter(tmp_path):
fabric = Fabric(accelerator="cuda", strategy=FSDPStrategy(state_dict_type="full"), devices=2)
@ -602,6 +605,7 @@ def test_clip_gradients(clip_type, precision):
optimizer.zero_grad()
@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.3.0")
def test_save_sharded_and_consolidate_and_load(tmp_path):
"""Test the consolidation of a FSDP-sharded checkpoint into a single file."""

View File

@ -28,20 +28,20 @@ from torch.optim import Adam
from tests_fabric.helpers.runif import RunIf
@mock.patch("lightning.fabric.strategies.model_parallel._TORCH_GREATER_EQUAL_2_3", False)
def test_torch_greater_equal_2_3():
with pytest.raises(ImportError, match="ModelParallelStrategy requires PyTorch 2.3 or higher"):
@mock.patch("lightning.fabric.strategies.model_parallel._TORCH_GREATER_EQUAL_2_4", False)
def test_torch_greater_equal_2_4():
with pytest.raises(ImportError, match="ModelParallelStrategy requires PyTorch 2.4 or higher"):
ModelParallelStrategy(parallelize_fn=(lambda m, _: m))
@RunIf(min_torch="2.3")
@RunIf(min_torch="2.4")
def test_device_mesh_access():
strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m))
with pytest.raises(RuntimeError, match="Accessing the device mesh .* not allowed"):
_ = strategy.device_mesh
@RunIf(min_torch="2.3")
@RunIf(min_torch="2.4")
@pytest.mark.parametrize(
("num_nodes", "devices", "invalid_dp_size", "invalid_tp_size"),
[
@ -70,7 +70,7 @@ def test_validate_device_mesh_dimensions(num_nodes, devices, invalid_dp_size, in
strategy.setup_environment()
@RunIf(min_torch="2.3")
@RunIf(min_torch="2.4")
def test_checkpoint_io_unsupported():
"""Test that the ModelParallel strategy does not support the `CheckpointIO` plugin."""
strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m))
@ -81,18 +81,18 @@ def test_checkpoint_io_unsupported():
strategy.checkpoint_io = Mock()
@RunIf(min_torch="2.3")
@RunIf(min_torch="2.4")
def test_fsdp_v1_modules_unsupported():
"""Test that the strategy won't allow setting up a module wrapped with the legacy FSDP API."""
from torch.distributed.fsdp import FullyShardedDataParallel
module = Mock(modules=Mock(return_value=[Mock(spec=FullyShardedDataParallel)]))
strategy = ModelParallelStrategy(parallelize_fn=(lambda x, _: x))
with pytest.raises(TypeError, match="only supports the new FSDP2 APIs in PyTorch >= 2.3"):
with pytest.raises(TypeError, match="only supports the new FSDP2 APIs in PyTorch >= 2.4"):
strategy.setup_module(module)
@RunIf(min_torch="2.3")
@RunIf(min_torch="2.4")
def test_parallelize_fn_call():
model = nn.Linear(2, 2)
optimizer = Adam(model.parameters())
@ -116,15 +116,15 @@ def test_parallelize_fn_call():
strategy.setup_module_and_optimizers(model, [optimizer])
@RunIf(min_torch="2.3")
@RunIf(min_torch="2.4")
def test_no_backward_sync():
"""Test that the backward sync control disables gradient sync on modules that benefit from it."""
from torch.distributed._composable.fsdp import FSDP
from torch.distributed._composable.fsdp import FSDPModule
strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m))
assert isinstance(strategy._backward_sync_control, _ParallelBackwardSyncControl)
fsdp_layer = Mock(spec=FSDP)
fsdp_layer = Mock(spec=FSDPModule)
other_layer = nn.Linear(2, 2)
module = Mock()
module.modules = Mock(return_value=[fsdp_layer, other_layer])
@ -138,7 +138,7 @@ def test_no_backward_sync():
fsdp_layer.set_requires_gradient_sync.assert_called_with(False, recurse=False)
@RunIf(min_torch="2.3")
@RunIf(min_torch="2.4")
def test_save_checkpoint_storage_options(tmp_path):
"""Test that the strategy does not accept storage options for saving checkpoints."""
strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m))
@ -148,7 +148,7 @@ def test_save_checkpoint_storage_options(tmp_path):
strategy.save_checkpoint(path=tmp_path, state=Mock(), storage_options=Mock())
@RunIf(min_torch="2.3")
@RunIf(min_torch="2.4")
@mock.patch("lightning.fabric.strategies.model_parallel.ModelParallelStrategy.broadcast", lambda _, x: x)
@mock.patch("lightning.fabric.strategies.model_parallel._has_dtensor_modules", return_value=True)
@mock.patch("torch.distributed.checkpoint.state_dict.get_model_state_dict", return_value={})
@ -205,7 +205,7 @@ def test_save_checkpoint_path_exists(shutil_mock, torch_save_mock, _, __, ___, t
assert path.is_dir()
@RunIf(min_torch="2.3")
@RunIf(min_torch="2.4")
def test_save_checkpoint_one_dist_module_required(tmp_path):
"""Test that the ModelParallelStrategy strategy can only save one distributed model per checkpoint."""
strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m))
@ -226,29 +226,7 @@ def test_save_checkpoint_one_dist_module_required(tmp_path):
strategy.save_checkpoint(path=tmp_path, state={"model1": model1, "model2": model2})
@RunIf(min_torch="2.3")
@mock.patch("lightning.fabric.strategies.model_parallel.torch.load", Mock())
@mock.patch("lightning.fabric.strategies.model_parallel._TORCH_GREATER_EQUAL_2_4", False)
def test_load_full_checkpoint_support(tmp_path):
"""Test that loading non-distributed checkpoints into distributed models requires PyTorch >= 2.4."""
strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m))
model = Mock(spec=nn.Module)
model.parameters.return_value = [torch.zeros(2, 1)]
path = tmp_path / "full.ckpt"
path.touch()
with pytest.raises(ImportError, match="Loading .* into a distributed model requires PyTorch >= 2.4"), mock.patch(
"lightning.fabric.strategies.model_parallel._has_dtensor_modules", return_value=True
):
strategy.load_checkpoint(path=path, state={"model": model})
with pytest.raises(ImportError, match="Loading .* into a distributed model requires PyTorch >= 2.4"), mock.patch(
"lightning.fabric.strategies.model_parallel._has_dtensor_modules", return_value=True
):
strategy.load_checkpoint(path=path, state=model)
@RunIf(min_torch="2.3")
@RunIf(min_torch="2.4")
def test_load_checkpoint_no_state(tmp_path):
"""Test that the ModelParallelStrategy strategy can't load the full state without access to a model instance from
the user."""
@ -259,7 +237,7 @@ def test_load_checkpoint_no_state(tmp_path):
strategy.load_checkpoint(path=tmp_path, state={})
@RunIf(min_torch="2.3")
@RunIf(min_torch="2.4")
@mock.patch("lightning.fabric.strategies.model_parallel.ModelParallelStrategy.broadcast", lambda _, x: x)
@mock.patch("lightning.fabric.strategies.model_parallel.torch.load", Mock())
def test_load_checkpoint_one_dist_module_required(tmp_path):
@ -289,7 +267,7 @@ def test_load_checkpoint_one_dist_module_required(tmp_path):
strategy.load_checkpoint(path=path, state=model)
@RunIf(min_torch="2.3")
@RunIf(min_torch="2.4")
@mock.patch("lightning.fabric.strategies.model_parallel._has_dtensor_modules", return_value=True)
def test_load_unknown_checkpoint_type(_, tmp_path):
"""Test that the strategy validates the contents at the checkpoint path."""
@ -301,7 +279,7 @@ def test_load_unknown_checkpoint_type(_, tmp_path):
strategy.load_checkpoint(path=path, state={"model": model})
@RunIf(min_torch="2.3")
@RunIf(min_torch="2.4")
def test_load_raw_checkpoint_validate_single_file(tmp_path):
"""Test that we validate the given checkpoint is a single file when loading a raw PyTorch state-dict checkpoint."""
strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m))
@ -312,7 +290,7 @@ def test_load_raw_checkpoint_validate_single_file(tmp_path):
strategy.load_checkpoint(path=path, state=model)
@RunIf(min_torch="2.3")
@RunIf(min_torch="2.4")
def test_load_raw_checkpoint_optimizer_unsupported(tmp_path):
"""Validate that the ModelParallelStrategy strategy does not yet support loading the raw PyTorch state-dict for an
optimizer."""
@ -324,7 +302,7 @@ def test_load_raw_checkpoint_optimizer_unsupported(tmp_path):
strategy.load_checkpoint(path=tmp_path, state=optimizer)
@RunIf(min_torch="2.3")
@RunIf(min_torch="2.4")
@mock.patch("lightning.fabric.strategies.model_parallel._setup_device_mesh")
@mock.patch("torch.distributed.init_process_group")
def test_set_timeout(init_process_group_mock, _):
@ -343,7 +321,7 @@ def test_set_timeout(init_process_group_mock, _):
)
@RunIf(min_torch="2.3")
@RunIf(min_torch="2.4")
def test_meta_device_materialization():
"""Test that the `setup_module()` method materializes meta-device tensors in the module."""

View File

@ -80,7 +80,7 @@ def _parallelize_feed_forward_fsdp2_tp(model, device_mesh):
return model
@RunIf(min_torch="2.3", standalone=True, min_cuda_gpus=4)
@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4)
def test_setup_device_mesh():
from torch.distributed.device_mesh import DeviceMesh
@ -116,7 +116,7 @@ def test_setup_device_mesh():
assert fabric.strategy.device_mesh.size(1) == 4
@RunIf(min_torch="2.3", standalone=True, min_cuda_gpus=2)
@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=2)
def test_tensor_parallel():
from torch.distributed._tensor import DTensor
@ -160,7 +160,7 @@ def test_tensor_parallel():
optimizer.zero_grad()
@RunIf(min_torch="2.3", standalone=True, min_cuda_gpus=4)
@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4)
def test_fsdp2_tensor_parallel():
from torch.distributed._tensor import DTensor
@ -237,7 +237,7 @@ def _train(fabric, model=None, optimizer=None):
return model, optimizer
@RunIf(min_torch="2.3", min_cuda_gpus=4, standalone=True)
@RunIf(min_torch="2.4", min_cuda_gpus=4, standalone=True)
@pytest.mark.parametrize(
"precision",
[
@ -445,7 +445,7 @@ def test_load_full_state_dict_into_sharded_model(tmp_path):
assert torch.equal(params_before, params_after)
@RunIf(min_torch="2.3", min_cuda_gpus=2, skip_windows=True, standalone=True)
@RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize("move_to_device", [True, False])
@mock.patch("lightning.fabric.wrappers._FabricModule")
def test_setup_module_move_to_device(fabric_module_mock, move_to_device):
@ -471,7 +471,7 @@ def test_setup_module_move_to_device(fabric_module_mock, move_to_device):
assert fabric.device == torch.device("cuda", fabric.local_rank)
@RunIf(min_torch="2.3", min_cuda_gpus=2, skip_windows=True, standalone=True)
@RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize(
("precision", "expected_dtype"),
[
@ -502,7 +502,7 @@ def test_module_init_context(precision, expected_dtype):
_run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))
@RunIf(min_torch="2.3", min_cuda_gpus=2, standalone=True)
@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
def test_save_filter(tmp_path):
strategy = ModelParallelStrategy(
parallelize_fn=_parallelize_feed_forward_fsdp2,
@ -541,7 +541,7 @@ def _parallelize_single_linear_tp_fsdp2(model, device_mesh):
return model
@RunIf(min_torch="2.3", min_cuda_gpus=2, standalone=True)
@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
@pytest.mark.parametrize(
"precision",
[
@ -597,7 +597,7 @@ def test_clip_gradients(clip_type, precision):
optimizer.zero_grad()
@RunIf(min_torch="2.3", min_cuda_gpus=4, standalone=True)
@RunIf(min_torch="2.4", min_cuda_gpus=4, standalone=True)
def test_save_sharded_and_consolidate_and_load(tmp_path):
"""Test the consolidation of a distributed (DTensor) checkpoint into a single file."""
strategy = ModelParallelStrategy(

View File

@ -868,7 +868,7 @@ def test_precision_selection_amp_ddp(strategy, devices, is_custom_plugin, plugin
assert isinstance(connector.precision, plugin_cls)
@RunIf(min_torch="2.3")
@RunIf(min_torch="2.4")
@pytest.mark.parametrize(
("precision", "raises"),
[("32-true", False), ("16-true", False), ("bf16-true", False), ("16-mixed", True), ("bf16-mixed", False)],

View File

@ -685,7 +685,7 @@ def test_unwrap_compiled():
assert unwrapped is compiled._orig_mod
assert compile_kwargs == {"fullgraph": True, "dynamic": True, "disable": False}
del compiled._compile_kwargs
compiled._compile_kwargs = None
with pytest.raises(RuntimeError, match="Failed to determine the arguments that were used to compile the module"):
_unwrap_compiled(compiled)

View File

@ -174,27 +174,6 @@ def test_save_checkpoint_path_exists(shutil_mock, torch_save_mock, tmp_path):
assert path.is_dir()
@RunIf(min_torch="2.3")
@mock.patch("lightning.fabric.strategies.model_parallel._TORCH_GREATER_EQUAL_2_4", False)
def test_load_full_checkpoint_support(tmp_path):
"""Test that loading non-distributed checkpoints into distributed models requires PyTorch >= 2.4."""
strategy = ModelParallelStrategy()
strategy.model = Mock()
strategy._lightning_module = Mock(strict_loading=True)
path = tmp_path / "full.ckpt"
path.touch()
with pytest.raises(ImportError, match="Loading .* into a distributed model requires PyTorch >= 2.4"), mock.patch(
"lightning.fabric.strategies.model_parallel._has_dtensor_modules", return_value=True
):
strategy.load_checkpoint(checkpoint_path=path)
with pytest.raises(ImportError, match="Loading .* into a distributed model requires PyTorch >= 2.4"), mock.patch(
"lightning.fabric.strategies.model_parallel._has_dtensor_modules", return_value=True
):
strategy.load_checkpoint(checkpoint_path=path)
@RunIf(min_torch="2.3")
@mock.patch("lightning.fabric.strategies.model_parallel._has_dtensor_modules", return_value=True)
def test_load_unknown_checkpoint_type(_, tmp_path):