diff --git a/.azure/gpu-tests-fabric.yml b/.azure/gpu-tests-fabric.yml
index 1a85460460..576b9c3eb3 100644
--- a/.azure/gpu-tests-fabric.yml
+++ b/.azure/gpu-tests-fabric.yml
@@ -62,6 +62,9 @@ jobs:
       "Lightning | latest":
         image: "pytorchlightning/pytorch_lightning:base-cuda-py3.11-torch2.3-cuda12.1.0"
         PACKAGE_NAME: "lightning"
+      "Lightning | future":
+        image: "pytorchlightning/pytorch_lightning:base-cuda-py3.11-torch2.4-cuda12.1.0"
+        PACKAGE_NAME: "lightning"
   workspace:
     clean: all
   steps:
@@ -72,9 +75,12 @@ jobs:
       echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/cu${cuda_ver}/torch_stable.html"
       scope=$(python -c 'n = "$(PACKAGE_NAME)" ; print(dict(fabric="lightning_fabric").get(n, n))')
       echo "##vso[task.setvariable variable=COVERAGE_SOURCE]$scope"
+      python_ver=$(python -c "import sys; print(f'{sys.version_info.major}{sys.version_info.minor}')")
+      echo "##vso[task.setvariable variable=PYTHON_VERSION_MM]$python_ver"
     displayName: "set env. vars"
   - bash: |
-      echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM}/torch_test.html"
+      echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM}"
+      echo "##vso[task.setvariable variable=TORCHVISION_URL]https://download.pytorch.org/whl/test/cu124/torchvision-0.19.0%2Bcu124-cp${PYTHON_VERSION_MM}-cp${PYTHON_VERSION_MM}-linux_x86_64.whl"
     condition: endsWith(variables['Agent.JobName'], 'future')
     displayName: "set env. vars 4 future"
@@ -103,7 +109,7 @@ jobs:
   - bash: |
       extra=$(python -c "print({'lightning': 'fabric-'}.get('$(PACKAGE_NAME)', ''))")
-      pip install -e ".[${extra}dev]" pytest-timeout -U --find-links="${TORCH_URL}"
+      pip install -e ".[${extra}dev]" pytest-timeout -U --find-links="${TORCH_URL}" --find-links="${TORCHVISION_URL}"
     displayName: "Install package & dependencies"

   - bash: |
diff --git a/.github/workflows/ci-tests-fabric.yml b/.github/workflows/ci-tests-fabric.yml
index 2c0d8d16b8..8d5ed3e9e7 100644
--- a/.github/workflows/ci-tests-fabric.yml
+++ b/.github/workflows/ci-tests-fabric.yml
@@ -49,6 +49,10 @@ jobs:
           - { os: "macOS-14", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.3" }
           - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
           - { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
+          - { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.4" }
+          - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.4" }
+          # TODO: PyTorch 2.4 on Windows not yet working with `torch.distributed` (not compiled with libuv support)
+          # - { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.4" }
           # only run PyTorch latest with Python latest, use Fabric scope to limit dependency issues
           - { os: "macOS-12", pkg-name: "fabric", python-version: "3.11", pytorch-version: "2.1" }
           - { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.11", pytorch-version: "2.1" }
@@ -79,7 +83,7 @@ jobs:
       FREEZE_REQUIREMENTS: ${{ ! (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }}
       PYPI_CACHE_DIR: "_pip-wheels"
       TORCH_URL_STABLE: "https://download.pytorch.org/whl/cpu/torch_stable.html"
-      TORCH_URL_TEST: "https://download.pytorch.org/whl/test/cpu/torch_test.html"
+      TORCH_URL_TEST: "https://download.pytorch.org/whl/test/cpu/torch"
       # TODO: Remove this - Enable running MPS tests on this platform
       DISABLE_MPS: ${{ matrix.os == 'macOS-14' && '1' || '0' }}
     steps:
@@ -118,7 +122,7 @@ jobs:
       - name: Env. variables
        run: |
          # Switch PyTorch URL
-          python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.pytorch-version }}' == '2.3' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV
+          python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.pytorch-version }}' == '2.4' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV
          # Switch coverage scope
          python -c "print('COVERAGE_SCOPE=' + str('lightning' if '${{matrix.pkg-name}}' == 'lightning' else 'lightning_fabric'))" >> $GITHUB_ENV
          # if you install mono-package set dependency only for this subpackage
diff --git a/requirements/fabric/base.txt b/requirements/fabric/base.txt
index 7ca4556821..a4ba3bc9c5 100644
--- a/requirements/fabric/base.txt
+++ b/requirements/fabric/base.txt
@@ -2,7 +2,7 @@
 # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

 numpy >=1.21.0, <1.27.0
-torch >=2.1.0, <2.4.0
+torch >=2.1.0, <2.5.0
 fsspec[http] >=2022.5.0, <2024.4.0
 packaging >=20.0, <=23.1
 typing-extensions >=4.4.0, <4.10.0
diff --git a/requirements/fabric/examples.txt b/requirements/fabric/examples.txt
index 49ffde9d0f..cb4135da24 100644
--- a/requirements/fabric/examples.txt
+++ b/requirements/fabric/examples.txt
@@ -1,6 +1,6 @@
 # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
 # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

-torchvision >=0.16.0, <0.19.0
+torchvision >=0.16.0, <0.20.0
 torchmetrics >=0.10.0, <1.3.0
 lightning-utilities >=0.8.0, <0.12.0
diff --git a/src/lightning/fabric/plugins/precision/amp.py b/src/lightning/fabric/plugins/precision/amp.py
index 75d7932ddb..c3b7fb74c2 100644
--- a/src/lightning/fabric/plugins/precision/amp.py
+++ b/src/lightning/fabric/plugins/precision/amp.py
@@ -22,6 +22,7 @@ from typing_extensions import override

 from lightning.fabric.plugins.precision.precision import Precision
 from lightning.fabric.plugins.precision.utils import _convert_fp_tensor
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
 from lightning.fabric.utilities.types import Optimizable


@@ -39,7 +40,7 @@ class MixedPrecision(Precision):
         self,
         precision: Literal["16-mixed", "bf16-mixed"],
         device: str,
-        scaler: Optional[torch.cuda.amp.GradScaler] = None,
+        scaler: Optional["torch.cuda.amp.GradScaler"] = None,
     ) -> None:
         if precision not in ("16-mixed", "bf16-mixed"):
             raise ValueError(
@@ -49,7 +50,7 @@ class MixedPrecision(Precision):
         self.precision = precision
         if scaler is None and self.precision == "16-mixed":
-            scaler = torch.cuda.amp.GradScaler()
+            scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler()
         if scaler is not None and self.precision == "bf16-mixed":
             raise ValueError(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
         self.device = device
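Note on the `amp.py` hunk above: PyTorch 2.4 deprecates `torch.cuda.amp.GradScaler` in favour of the device-scoped `torch.amp.GradScaler`, which is why `MixedPrecision` now branches on `_TORCH_GREATER_EQUAL_2_4`. A minimal sketch of that selection written out by hand (illustrative only, not part of the patch):

```python
# Illustrative sketch, not part of the patch: pick the scaler class the same way
# `MixedPrecision` now does. `torch.amp.GradScaler` takes the device as an argument,
# while the old CUDA-only class emits a FutureWarning on PyTorch >= 2.4.
import torch
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4

if _TORCH_GREATER_EQUAL_2_4:
    scaler = torch.amp.GradScaler(device="cuda")  # new, device-scoped API
else:
    scaler = torch.cuda.amp.GradScaler()  # pre-2.4 CUDA-only API

# The training-loop usage is unchanged by the rename:
#   scaler.scale(loss).backward()
#   scaler.step(optimizer)
#   scaler.update()
```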
diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py
index b8a3a26847..88a4be4549 100644
--- a/src/lightning/fabric/strategies/fsdp.py
+++ b/src/lightning/fabric/strategies/fsdp.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import shutil
+import warnings
 from contextlib import ExitStack, nullcontext
 from datetime import timedelta
 from functools import partial
@@ -83,6 +84,9 @@ if TYPE_CHECKING:

 _FSDP_ALIASES = ("fsdp", "fsdp_cpu_offload")

+# TODO: Switch to new state-dict APIs
+warnings.filterwarnings("ignore", category=FutureWarning, message=".*FSDP.state_dict_type.*")  # from torch >= 2.4
+

 class FSDPStrategy(ParallelStrategy, _Sharded):
     r"""Strategy for Fully Sharded Data Parallel provided by torch.distributed.
diff --git a/src/lightning/fabric/strategies/model_parallel.py b/src/lightning/fabric/strategies/model_parallel.py
index 629113b291..c727700896 100644
--- a/src/lightning/fabric/strategies/model_parallel.py
+++ b/src/lightning/fabric/strategies/model_parallel.py
@@ -70,7 +70,7 @@ class ModelParallelStrategy(ParallelStrategy):

     Currently supports up to 2D parallelism. Specifically, it supports the combination of
     Fully Sharded Data-Parallel 2 (FSDP2) with Tensor Parallelism (DTensor). These PyTorch APIs are currently still
-    experimental in PyTorch. Requires PyTorch 2.3 or newer.
+    experimental in PyTorch. Requires PyTorch 2.4 or newer.

     Arguments:
         parallelize_fn: A function that applies parallelisms to a module. The strategy will provide the
@@ -95,8 +95,8 @@ class ModelParallelStrategy(ParallelStrategy):
         timeout: Optional[timedelta] = default_pg_timeout,
     ) -> None:
         super().__init__()
-        if not _TORCH_GREATER_EQUAL_2_3:
-            raise ImportError(f"{type(self).__name__} requires PyTorch 2.3 or higher.")
+        if not _TORCH_GREATER_EQUAL_2_4:
+            raise ImportError(f"{type(self).__name__} requires PyTorch 2.4 or higher.")
         self._parallelize_fn = parallelize_fn
         self._data_parallel_size = data_parallel_size
         self._tensor_parallel_size = tensor_parallel_size
@@ -178,7 +178,7 @@ class ModelParallelStrategy(ParallelStrategy):
         if any(isinstance(mod, FullyShardedDataParallel) for mod in module.modules()):
             raise TypeError(
                 "Found modules that are wrapped with `torch.distributed.fsdp.FullyShardedDataParallel`."
-                f" The `{self.__class__.__name__}` only supports the new FSDP2 APIs in PyTorch >= 2.3."
+                f" The `{self.__class__.__name__}` only supports the new FSDP2 APIs in PyTorch >= 2.4."
             )

         module = self._parallelize_fn(module, self.device_mesh)
@@ -329,10 +329,10 @@ class _FSDPNoSync(ContextManager):
         self._enabled = enabled

     def _set_requires_grad_sync(self, requires_grad_sync: bool) -> None:
-        from torch.distributed._composable.fsdp import FSDP
+        from torch.distributed._composable.fsdp import FSDPModule

         for mod in self._module.modules():
-            if isinstance(mod, FSDP):
+            if isinstance(mod, FSDPModule):
                 mod.set_requires_gradient_sync(requires_grad_sync, recurse=False)

     def __enter__(self) -> None:
@@ -458,9 +458,6 @@ def _load_checkpoint(
         return metadata

     if _is_full_checkpoint(path):
-        if not _TORCH_GREATER_EQUAL_2_4:
-            raise ImportError("Loading a non-distributed checkpoint into a distributed model requires PyTorch >= 2.4.")
-
         checkpoint = torch.load(path, mmap=True, map_location="cpu")
         _load_raw_module_state(checkpoint.pop(module_key), module, strict=strict)

@@ -546,9 +543,6 @@ def _load_raw_module_state(
     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

     if _has_dtensor_modules(module):
-        if not _TORCH_GREATER_EQUAL_2_4:
-            raise ImportError("Loading a non-distributed checkpoint into a distributed model requires PyTorch >= 2.4.")
-
         from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict

         state_dict_options = StateDictOptions(
diff --git a/src/lightning/fabric/utilities/imports.py b/src/lightning/fabric/utilities/imports.py
index fc40175ff5..c8fa1ddf1e 100644
--- a/src/lightning/fabric/utilities/imports.py
+++ b/src/lightning/fabric/utilities/imports.py
@@ -28,7 +28,7 @@ _IS_INTERACTIVE = hasattr(sys, "ps1") or bool(sys.flags.interactive)

 _TORCH_GREATER_EQUAL_2_2 = compare_version("torch", operator.ge, "2.2.0")
 _TORCH_GREATER_EQUAL_2_3 = compare_version("torch", operator.ge, "2.3.0")
-_TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0", use_base_version=True)
+_TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0")

 _PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
 _PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
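Side note on the `imports.py` hunk above: dropping `use_base_version=True` means `_TORCH_GREATER_EQUAL_2_4` is no longer satisfied by pre-release builds whose base version is 2.4.0 (e.g. nightlies), only by an actual 2.4.0+ install. A quick illustration of the difference using `packaging` directly (illustrative sketch, not part of the patch; the nightly version string is made up):

```python
# Illustrative sketch, not part of the patch. `lightning_utilities.core.imports.compare_version`
# compares the installed version in the same way; with `use_base_version=True` it strips the
# pre-release suffix before comparing.
from packaging.version import Version

nightly = Version("2.4.0.dev20240601")  # hypothetical nightly/pre-release build of torch
final = Version("2.4.0")

assert nightly < final                           # strict comparison: a nightly is NOT >= 2.4.0
assert Version(nightly.base_version) >= final    # base-version comparison: it counts as 2.4.0
```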
diff --git a/src/lightning/fabric/utilities/testing/_runif.py b/src/lightning/fabric/utilities/testing/_runif.py
index f282651048..6f0513465c 100644
--- a/src/lightning/fabric/utilities/testing/_runif.py
+++ b/src/lightning/fabric/utilities/testing/_runif.py
@@ -24,6 +24,7 @@ from lightning.fabric.accelerators import XLAAccelerator
 from lightning.fabric.accelerators.cuda import num_cuda_devices
 from lightning.fabric.accelerators.mps import MPSAccelerator
 from lightning.fabric.strategies.deepspeed import _DEEPSPEED_AVAILABLE
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4


 def _runif_reasons(
@@ -111,7 +112,9 @@ def _runif_reasons(
         reasons.append("Standalone execution")
         kwargs["standalone"] = True

-    if deepspeed and not (_DEEPSPEED_AVAILABLE and RequirementCache(module="deepspeed.utils")):
+    if deepspeed and not (
+        _DEEPSPEED_AVAILABLE and not _TORCH_GREATER_EQUAL_2_4 and RequirementCache(module="deepspeed.utils")
+    ):
         reasons.append("Deepspeed")

     if dynamo:
diff --git a/tests/tests_fabric/conftest.py b/tests/tests_fabric/conftest.py
index 8b0d83d7f2..446994167d 100644
--- a/tests/tests_fabric/conftest.py
+++ b/tests/tests_fabric/conftest.py
@@ -104,6 +104,7 @@ def thread_police_duuu_daaa_duuu_daaa():
         elif (
             thread.name == "QueueFeederThread"  # tensorboardX
             or thread.name == "QueueManagerThread"  # torch.compile
+            or "(_read_thread)" in thread.name  # torch.compile
         ):
             thread.join(timeout=20)
         elif (
diff --git a/tests/tests_fabric/plugins/precision/test_amp.py b/tests/tests_fabric/plugins/precision/test_amp.py
index 34f14b8871..93d53eb406 100644
--- a/tests/tests_fabric/plugins/precision/test_amp.py
+++ b/tests/tests_fabric/plugins/precision/test_amp.py
@@ -17,11 +17,13 @@ from unittest.mock import Mock
 import pytest
 import torch
 from lightning.fabric.plugins.precision.amp import MixedPrecision
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4


 def test_amp_precision_default_scaler():
     precision = MixedPrecision(precision="16-mixed", device=Mock())
-    assert isinstance(precision.scaler, torch.cuda.amp.GradScaler)
+    scaler_cls = torch.amp.GradScaler if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler
+    assert isinstance(precision.scaler, scaler_cls)


 def test_amp_precision_scaler_with_bf16():
@@ -36,7 +38,8 @@ def test_amp_precision_forward_context():
     """Test to ensure that the context manager correctly is set to bfloat16 on CPU and CUDA."""
     precision = MixedPrecision(precision="16-mixed", device="cuda")
     assert precision.device == "cuda"
-    assert isinstance(precision.scaler, torch.cuda.amp.GradScaler)
+    scaler_cls = torch.amp.GradScaler if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler
+    assert isinstance(precision.scaler, scaler_cls)
     assert torch.get_default_dtype() == torch.float32
     with precision.forward_context():
         assert torch.get_autocast_gpu_dtype() == torch.float16
diff --git a/tests/tests_fabric/plugins/precision/test_amp_integration.py b/tests/tests_fabric/plugins/precision/test_amp_integration.py
index 5d88a7d9ba..aa6c6cfce4 100644
--- a/tests/tests_fabric/plugins/precision/test_amp_integration.py
+++ b/tests/tests_fabric/plugins/precision/test_amp_integration.py
@@ -17,6 +17,7 @@ import pytest
 import torch
 import torch.nn as nn
 from lightning.fabric import Fabric, seed_everything
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4

 from tests_fabric.helpers.runif import RunIf

@@ -82,7 +83,8 @@ def test_amp_fused_optimizer_parity():
         optimizer = torch.optim.Adam(model.parameters(), lr=1.0, fused=fused)
         model, optimizer = fabric.setup(model, optimizer)

-        assert isinstance(fabric._precision.scaler, torch.cuda.amp.GradScaler)
+        scaler_cls = torch.amp.GradScaler if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler
+        assert isinstance(fabric._precision.scaler, scaler_cls)

         data = torch.randn(10, 10, device="cuda")
         target = torch.randn(10, 10, device="cuda")
diff --git a/tests/tests_fabric/plugins/precision/test_bitsandbytes.py b/tests/tests_fabric/plugins/precision/test_bitsandbytes.py
index c45adef192..31d655fc4c 100644
--- a/tests/tests_fabric/plugins/precision/test_bitsandbytes.py
+++ b/tests/tests_fabric/plugins/precision/test_bitsandbytes.py
@@ -93,7 +93,7 @@ def test_bitsandbytes_plugin(monkeypatch):
     precision.convert_module(model)


-@RunIf(min_cuda_gpus=1)
+@RunIf(min_cuda_gpus=1, max_torch="2.4")
 @pytest.mark.skipif(not _BITSANDBYTES_AVAILABLE, reason="bitsandbytes unavailable")
 @pytest.mark.parametrize(
     ("args", "expected"),
@@ -232,7 +232,7 @@ def test_bitsandbytes_layers_meta_device(args, expected, tmp_path):
     assert model.l.weight.dtype == expected


-@RunIf(min_cuda_gpus=1)
+@RunIf(min_cuda_gpus=1, max_torch="2.4")
 @pytest.mark.skipif(not _BITSANDBYTES_AVAILABLE, reason="bitsandbytes unavailable")
 def test_load_quantized_checkpoint(tmp_path):
     """Test that a checkpoint saved from a quantized model can be loaded back into a quantized model."""
diff --git a/tests/tests_fabric/strategies/test_dp.py b/tests/tests_fabric/strategies/test_dp.py
index 572bbd20d3..e50abb1882 100644
--- a/tests/tests_fabric/strategies/test_dp.py
+++ b/tests/tests_fabric/strategies/test_dp.py
@@ -74,6 +74,7 @@ def test_dp_module_state_dict():
     assert strategy.get_module_state_dict(wrapped_module).keys() == original_module.state_dict().keys()


+@pytest.mark.filterwarnings("ignore::FutureWarning")
 @pytest.mark.parametrize(
     "precision",
     [
diff --git a/tests/tests_fabric/strategies/test_fsdp_integration.py b/tests/tests_fabric/strategies/test_fsdp_integration.py
index e324ab5056..77b2f975d2 100644
--- a/tests/tests_fabric/strategies/test_fsdp_integration.py
+++ b/tests/tests_fabric/strategies/test_fsdp_integration.py
@@ -118,7 +118,7 @@ class _TrainerManualWrapping(_Trainer):
         return model


-@RunIf(min_cuda_gpus=2, standalone=True)
+@RunIf(min_cuda_gpus=2, standalone=True, max_torch="2.4")
 @pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
 @pytest.mark.parametrize("manual_wrapping", [True, False])
 def test_train_save_load(tmp_path, manual_wrapping, precision):
@@ -173,6 +173,7 @@ def test_train_save_load(tmp_path, manual_wrapping, precision):
     assert state["coconut"] == 11


+@pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_cuda_gpus=2, standalone=True)
 def test_save_full_state_dict(tmp_path):
     """Test that FSDP saves the full state into a single file with `state_dict_type="full"`."""
@@ -287,6 +288,7 @@ def test_save_full_state_dict(tmp_path):
     trainer.run()


+@pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_cuda_gpus=2, standalone=True)
 def test_load_full_state_dict_into_sharded_model(tmp_path):
     """Test that the strategy can load a full-state checkpoint into a FSDP sharded model."""
@@ -469,6 +471,7 @@ def test_module_init_context(precision, expected_dtype):
     _run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))


+@pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_cuda_gpus=2, standalone=True)
 def test_save_filter(tmp_path):
     fabric = Fabric(accelerator="cuda", strategy=FSDPStrategy(state_dict_type="full"), devices=2)
@@ -602,6 +605,7 @@ def test_clip_gradients(clip_type, precision):
     optimizer.zero_grad()


+@pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.3.0")
 def test_save_sharded_and_consolidate_and_load(tmp_path):
     """Test the consolidation of a FSDP-sharded checkpoint into a single file."""
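The `@pytest.mark.filterwarnings("ignore::FutureWarning")` markers added in the two test files above, like the module-level `warnings.filterwarnings` call in `fsdp.py`, suppress the FSDP state-dict deprecation warnings that PyTorch 2.4 starts emitting. A minimal sketch of what the per-test marker does (illustrative only; the warning message below is made up):

```python
# Illustrative sketch, not part of the patch: the marker installs a warning filter for the
# duration of one test, equivalent to running pytest with `-W ignore::FutureWarning`.
import warnings

import pytest


def _legacy_fsdp_call():
    # stand-in for torch code that warns about a deprecated state-dict API
    warnings.warn("state_dict_type is deprecated, switch to the new state-dict APIs", FutureWarning)


@pytest.mark.filterwarnings("ignore::FutureWarning")
def test_tolerates_deprecation_warning():
    _legacy_fsdp_call()  # the FutureWarning is ignored instead of surfacing in the test report
```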
diff --git a/tests/tests_fabric/strategies/test_model_parallel.py b/tests/tests_fabric/strategies/test_model_parallel.py
index 03b9268b31..1f8b5b783b 100644
--- a/tests/tests_fabric/strategies/test_model_parallel.py
+++ b/tests/tests_fabric/strategies/test_model_parallel.py
@@ -28,20 +28,20 @@ from torch.optim import Adam
 from tests_fabric.helpers.runif import RunIf


-@mock.patch("lightning.fabric.strategies.model_parallel._TORCH_GREATER_EQUAL_2_3", False)
-def test_torch_greater_equal_2_3():
-    with pytest.raises(ImportError, match="ModelParallelStrategy requires PyTorch 2.3 or higher"):
+@mock.patch("lightning.fabric.strategies.model_parallel._TORCH_GREATER_EQUAL_2_4", False)
+def test_torch_greater_equal_2_4():
+    with pytest.raises(ImportError, match="ModelParallelStrategy requires PyTorch 2.4 or higher"):
         ModelParallelStrategy(parallelize_fn=(lambda m, _: m))


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 def test_device_mesh_access():
     strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m))
     with pytest.raises(RuntimeError, match="Accessing the device mesh .* not allowed"):
         _ = strategy.device_mesh


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 @pytest.mark.parametrize(
     ("num_nodes", "devices", "invalid_dp_size", "invalid_tp_size"),
     [
@@ -70,7 +70,7 @@ def test_validate_device_mesh_dimensions(num_nodes, devices, invalid_dp_size, in
         strategy.setup_environment()


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 def test_checkpoint_io_unsupported():
     """Test that the ModelParallel strategy does not support the `CheckpointIO` plugin."""
     strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m))
@@ -81,18 +81,18 @@ def test_checkpoint_io_unsupported():
         strategy.checkpoint_io = Mock()


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 def test_fsdp_v1_modules_unsupported():
     """Test that the strategy won't allow setting up a module wrapped with the legacy FSDP API."""
     from torch.distributed.fsdp import FullyShardedDataParallel

     module = Mock(modules=Mock(return_value=[Mock(spec=FullyShardedDataParallel)]))
     strategy = ModelParallelStrategy(parallelize_fn=(lambda x, _: x))
-    with pytest.raises(TypeError, match="only supports the new FSDP2 APIs in PyTorch >= 2.3"):
+    with pytest.raises(TypeError, match="only supports the new FSDP2 APIs in PyTorch >= 2.4"):
         strategy.setup_module(module)


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 def test_parallelize_fn_call():
     model = nn.Linear(2, 2)
     optimizer = Adam(model.parameters())
@@ -116,15 +116,15 @@ def test_parallelize_fn_call():
         strategy.setup_module_and_optimizers(model, [optimizer])


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 def test_no_backward_sync():
     """Test that the backward sync control disables gradient sync on modules that benefit from it."""
-    from torch.distributed._composable.fsdp import FSDP
+    from torch.distributed._composable.fsdp import FSDPModule

     strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m))
     assert isinstance(strategy._backward_sync_control, _ParallelBackwardSyncControl)

-    fsdp_layer = Mock(spec=FSDP)
+    fsdp_layer = Mock(spec=FSDPModule)
     other_layer = nn.Linear(2, 2)
     module = Mock()
     module.modules = Mock(return_value=[fsdp_layer, other_layer])
@@ -138,7 +138,7 @@ def test_no_backward_sync():
         fsdp_layer.set_requires_gradient_sync.assert_called_with(False, recurse=False)


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 def test_save_checkpoint_storage_options(tmp_path):
     """Test that the strategy does not accept storage options for saving checkpoints."""
     strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m))
@@ -148,7 +148,7 @@ def test_save_checkpoint_storage_options(tmp_path):
     strategy.save_checkpoint(path=tmp_path, state=Mock(), storage_options=Mock())


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 @mock.patch("lightning.fabric.strategies.model_parallel.ModelParallelStrategy.broadcast", lambda _, x: x)
 @mock.patch("lightning.fabric.strategies.model_parallel._has_dtensor_modules", return_value=True)
 @mock.patch("torch.distributed.checkpoint.state_dict.get_model_state_dict", return_value={})
@@ -205,7 +205,7 @@ def test_save_checkpoint_path_exists(shutil_mock, torch_save_mock, _, __, ___, t
     assert path.is_dir()


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 def test_save_checkpoint_one_dist_module_required(tmp_path):
     """Test that the ModelParallelStrategy strategy can only save one distributed model per checkpoint."""
     strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m))
@@ -226,29 +226,7 @@ def test_save_checkpoint_one_dist_module_required(tmp_path):
         strategy.save_checkpoint(path=tmp_path, state={"model1": model1, "model2": model2})


-@RunIf(min_torch="2.3")
-@mock.patch("lightning.fabric.strategies.model_parallel.torch.load", Mock())
-@mock.patch("lightning.fabric.strategies.model_parallel._TORCH_GREATER_EQUAL_2_4", False)
-def test_load_full_checkpoint_support(tmp_path):
-    """Test that loading non-distributed checkpoints into distributed models requires PyTorch >= 2.4."""
-    strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m))
-    model = Mock(spec=nn.Module)
-    model.parameters.return_value = [torch.zeros(2, 1)]
-    path = tmp_path / "full.ckpt"
-    path.touch()
-
-    with pytest.raises(ImportError, match="Loading .* into a distributed model requires PyTorch >= 2.4"), mock.patch(
-        "lightning.fabric.strategies.model_parallel._has_dtensor_modules", return_value=True
-    ):
-        strategy.load_checkpoint(path=path, state={"model": model})
-
-    with pytest.raises(ImportError, match="Loading .* into a distributed model requires PyTorch >= 2.4"), mock.patch(
-        "lightning.fabric.strategies.model_parallel._has_dtensor_modules", return_value=True
-    ):
-        strategy.load_checkpoint(path=path, state=model)
-
-
-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 def test_load_checkpoint_no_state(tmp_path):
     """Test that the ModelParallelStrategy strategy can't load the full state without access to a model instance from
     the user."""
@@ -259,7 +237,7 @@ def test_load_checkpoint_no_state(tmp_path):
         strategy.load_checkpoint(path=tmp_path, state={})


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 @mock.patch("lightning.fabric.strategies.model_parallel.ModelParallelStrategy.broadcast", lambda _, x: x)
 @mock.patch("lightning.fabric.strategies.model_parallel.torch.load", Mock())
 def test_load_checkpoint_one_dist_module_required(tmp_path):
@@ -289,7 +267,7 @@ def test_load_checkpoint_one_dist_module_required(tmp_path):
         strategy.load_checkpoint(path=path, state=model)


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 @mock.patch("lightning.fabric.strategies.model_parallel._has_dtensor_modules", return_value=True)
 def test_load_unknown_checkpoint_type(_, tmp_path):
     """Test that the strategy validates the contents at the checkpoint path."""
@@ -301,7 +279,7 @@ def test_load_unknown_checkpoint_type(_, tmp_path):
         strategy.load_checkpoint(path=path, state={"model": model})


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 def test_load_raw_checkpoint_validate_single_file(tmp_path):
     """Test that we validate the given checkpoint is a single file when loading a raw PyTorch state-dict
     checkpoint."""
     strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m))
@@ -312,7 +290,7 @@ def test_load_raw_checkpoint_validate_single_file(tmp_path):
         strategy.load_checkpoint(path=path, state=model)


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 def test_load_raw_checkpoint_optimizer_unsupported(tmp_path):
     """Validate that the ModelParallelStrategy strategy does not yet support loading the raw PyTorch state-dict for an
     optimizer."""
@@ -324,7 +302,7 @@ def test_load_raw_checkpoint_optimizer_unsupported(tmp_path):
         strategy.load_checkpoint(path=tmp_path, state=optimizer)


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 @mock.patch("lightning.fabric.strategies.model_parallel._setup_device_mesh")
 @mock.patch("torch.distributed.init_process_group")
 def test_set_timeout(init_process_group_mock, _):
@@ -343,7 +321,7 @@ def test_set_timeout(init_process_group_mock, _):
     )


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 def test_meta_device_materialization():
     """Test that the `setup_module()` method materializes meta-device tensors in the module."""
diff --git a/tests/tests_fabric/strategies/test_model_parallel_integration.py b/tests/tests_fabric/strategies/test_model_parallel_integration.py
index 6db31d00f7..e8a8e5b5a4 100644
--- a/tests/tests_fabric/strategies/test_model_parallel_integration.py
+++ b/tests/tests_fabric/strategies/test_model_parallel_integration.py
@@ -80,7 +80,7 @@ def _parallelize_feed_forward_fsdp2_tp(model, device_mesh):
     return model


-@RunIf(min_torch="2.3", standalone=True, min_cuda_gpus=4)
+@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4)
 def test_setup_device_mesh():
     from torch.distributed.device_mesh import DeviceMesh

@@ -116,7 +116,7 @@ def test_setup_device_mesh():
     assert fabric.strategy.device_mesh.size(1) == 4


-@RunIf(min_torch="2.3", standalone=True, min_cuda_gpus=2)
+@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=2)
 def test_tensor_parallel():
     from torch.distributed._tensor import DTensor

@@ -160,7 +160,7 @@ def test_tensor_parallel():
     optimizer.zero_grad()


-@RunIf(min_torch="2.3", standalone=True, min_cuda_gpus=4)
+@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4)
 def test_fsdp2_tensor_parallel():
     from torch.distributed._tensor import DTensor

@@ -237,7 +237,7 @@ def _train(fabric, model=None, optimizer=None):
     return model, optimizer


-@RunIf(min_torch="2.3", min_cuda_gpus=4, standalone=True)
+@RunIf(min_torch="2.4", min_cuda_gpus=4, standalone=True)
 @pytest.mark.parametrize(
     "precision",
     [
@@ -445,7 +445,7 @@ def test_load_full_state_dict_into_sharded_model(tmp_path):
     assert torch.equal(params_before, params_after)


-@RunIf(min_torch="2.3", min_cuda_gpus=2, skip_windows=True, standalone=True)
+@RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True)
 @pytest.mark.parametrize("move_to_device", [True, False])
 @mock.patch("lightning.fabric.wrappers._FabricModule")
 def test_setup_module_move_to_device(fabric_module_mock, move_to_device):
@@ -471,7 +471,7 @@ def test_setup_module_move_to_device(fabric_module_mock, move_to_device):
     assert fabric.device == torch.device("cuda", fabric.local_rank)


-@RunIf(min_torch="2.3", min_cuda_gpus=2, skip_windows=True, standalone=True)
+@RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True)
 @pytest.mark.parametrize(
     ("precision", "expected_dtype"),
     [
@@ -502,7 +502,7 @@ def test_module_init_context(precision, expected_dtype):
     _run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))


-@RunIf(min_torch="2.3", min_cuda_gpus=2, standalone=True)
+@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
 def test_save_filter(tmp_path):
     strategy = ModelParallelStrategy(
         parallelize_fn=_parallelize_feed_forward_fsdp2,
@@ -541,7 +541,7 @@ def _parallelize_single_linear_tp_fsdp2(model, device_mesh):
     return model


-@RunIf(min_torch="2.3", min_cuda_gpus=2, standalone=True)
+@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
 @pytest.mark.parametrize(
     "precision",
     [
@@ -597,7 +597,7 @@ def test_clip_gradients(clip_type, precision):
     optimizer.zero_grad()


-@RunIf(min_torch="2.3", min_cuda_gpus=4, standalone=True)
+@RunIf(min_torch="2.4", min_cuda_gpus=4, standalone=True)
 def test_save_sharded_and_consolidate_and_load(tmp_path):
     """Test the consolidation of a distributed (DTensor) checkpoint into a single file."""
     strategy = ModelParallelStrategy(
diff --git a/tests/tests_fabric/test_connector.py b/tests/tests_fabric/test_connector.py
index b6f6b03b37..08d6dbb45e 100644
--- a/tests/tests_fabric/test_connector.py
+++ b/tests/tests_fabric/test_connector.py
@@ -868,7 +868,7 @@ def test_precision_selection_amp_ddp(strategy, devices, is_custom_plugin, plugin
     assert isinstance(connector.precision, plugin_cls)


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 @pytest.mark.parametrize(
     ("precision", "raises"),
     [("32-true", False), ("16-true", False), ("bf16-true", False), ("16-mixed", True), ("bf16-mixed", False)],
diff --git a/tests/tests_fabric/test_wrappers.py b/tests/tests_fabric/test_wrappers.py
index b89a536aff..26223b47f8 100644
--- a/tests/tests_fabric/test_wrappers.py
+++ b/tests/tests_fabric/test_wrappers.py
@@ -685,7 +685,7 @@ def test_unwrap_compiled():
     assert unwrapped is compiled._orig_mod
     assert compile_kwargs == {"fullgraph": True, "dynamic": True, "disable": False}

-    del compiled._compile_kwargs
+    compiled._compile_kwargs = None
     with pytest.raises(RuntimeError, match="Failed to determine the arguments that were used to compile the module"):
         _unwrap_compiled(compiled)

diff --git a/tests/tests_pytorch/strategies/test_model_parallel.py b/tests/tests_pytorch/strategies/test_model_parallel.py
index e22593f391..4f30ae8fef 100644
--- a/tests/tests_pytorch/strategies/test_model_parallel.py
+++ b/tests/tests_pytorch/strategies/test_model_parallel.py
@@ -174,27 +174,6 @@ def test_save_checkpoint_path_exists(shutil_mock, torch_save_mock, tmp_path):
     assert path.is_dir()


-@RunIf(min_torch="2.3")
-@mock.patch("lightning.fabric.strategies.model_parallel._TORCH_GREATER_EQUAL_2_4", False)
-def test_load_full_checkpoint_support(tmp_path):
-    """Test that loading non-distributed checkpoints into distributed models requires PyTorch >= 2.4."""
-    strategy = ModelParallelStrategy()
-    strategy.model = Mock()
-    strategy._lightning_module = Mock(strict_loading=True)
-    path = tmp_path / "full.ckpt"
-    path.touch()
-
-    with pytest.raises(ImportError, match="Loading .* into a distributed model requires PyTorch >= 2.4"), mock.patch(
-        "lightning.fabric.strategies.model_parallel._has_dtensor_modules", return_value=True
-    ):
-        strategy.load_checkpoint(checkpoint_path=path)
-
-    with pytest.raises(ImportError, match="Loading .* into a distributed model requires PyTorch >= 2.4"), mock.patch(
-        "lightning.fabric.strategies.model_parallel._has_dtensor_modules", return_value=True
-    ):
-        strategy.load_checkpoint(checkpoint_path=path)
-
-
 @RunIf(min_torch="2.3")
 @mock.patch("lightning.fabric.strategies.model_parallel._has_dtensor_modules", return_value=True)
 def test_load_unknown_checkpoint_type(_, tmp_path):