diff --git a/.azure/gpu-tests-pytorch.yml b/.azure/gpu-tests-pytorch.yml index dd1a3d4abc..15ce2f6ace 100644 --- a/.azure/gpu-tests-pytorch.yml +++ b/.azure/gpu-tests-pytorch.yml @@ -55,6 +55,9 @@ jobs: "Lightning | latest": image: "pytorchlightning/pytorch_lightning:base-cuda-py3.11-torch2.3-cuda12.1.0" PACKAGE_NAME: "lightning" + "Lightning | future": + image: "pytorchlightning/pytorch_lightning:base-cuda-py3.11-torch2.4-cuda12.1.0" + PACKAGE_NAME: "lightning" pool: lit-rtx-3090 variables: DEVICES: $( python -c 'print("$(Agent.Name)".split("_")[-1])' ) @@ -76,9 +79,12 @@ jobs: echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/cu${cuda_ver}/torch_stable.html" scope=$(python -c 'n = "$(PACKAGE_NAME)" ; print(dict(pytorch="pytorch_lightning").get(n, n))') echo "##vso[task.setvariable variable=COVERAGE_SOURCE]$scope" + python_ver=$(python -c "import sys; print(f'{sys.version_info.major}{sys.version_info.minor}')") + echo "##vso[task.setvariable variable=PYTHON_VERSION_MM]$python_ver" displayName: "set env. vars" - bash: | - echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM}/torch_test.html" + echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM}" + echo "##vso[task.setvariable variable=TORCHVISION_URL]https://download.pytorch.org/whl/test/cu124/torchvision-0.19.0%2Bcu124-cp${PYTHON_VERSION_MM}-cp${PYTHON_VERSION_MM}-linux_x86_64.whl" condition: endsWith(variables['Agent.JobName'], 'future') displayName: "set env. vars 4 future" @@ -107,7 +113,7 @@ jobs: - bash: | extra=$(python -c "print({'lightning': 'pytorch-'}.get('$(PACKAGE_NAME)', ''))") - pip install -e ".[${extra}dev]" pytest-timeout -U --find-links="${TORCH_URL}" + pip install -e ".[${extra}dev]" pytest-timeout -U --find-links="${TORCH_URL}" --find-links="${TORCHVISION_URL}" displayName: "Install package & dependencies" - bash: pip uninstall -y lightning diff --git a/.github/checkgroup.yml b/.github/checkgroup.yml index 3774a56e2f..9e0a537e70 100644 --- a/.github/checkgroup.yml +++ b/.github/checkgroup.yml @@ -142,11 +142,15 @@ subprojects: - "build-cuda (3.10, 2.2, 12.1.0)" - "build-cuda (3.11, 2.1, 12.1.0)" - "build-cuda (3.11, 2.2, 12.1.0)" + - "build-cuda (3.11, 2.3, 12.1.0)" + - "build-cuda (3.11, 2.4, 12.1.0)" #- "build-NGC" - "build-pl (3.10, 2.1, 12.1.0)" - "build-pl (3.10, 2.2, 12.1.0)" - "build-pl (3.11, 2.1, 12.1.0)" - "build-pl (3.11, 2.2, 12.1.0)" + - "build-pl (3.11, 2.3, 12.1.0)" + - "build-pl (3.11, 2.4, 12.1.0)" # SECTION: lightning_fabric diff --git a/.github/workflows/ci-tests-pytorch.yml b/.github/workflows/ci-tests-pytorch.yml index b75b6e73d9..6ff9fbf05b 100644 --- a/.github/workflows/ci-tests-pytorch.yml +++ b/.github/workflows/ci-tests-pytorch.yml @@ -53,6 +53,9 @@ jobs: - { os: "macOS-14", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.3" } - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.3" } - { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.3" } + - { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.4" } + - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.4" } + - { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.4" } # only run PyTorch latest with Python latest, use PyTorch scope to limit dependency issues - { os: "macOS-12", pkg-name: "pytorch", python-version: "3.11", pytorch-version: "2.1" } - { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.11", pytorch-version: "2.1" } @@ -82,7 +85,7 @@ jobs: PACKAGE_NAME: ${{ matrix.pkg-name }} TORCH_URL: "https://download.pytorch.org/whl/cpu/torch_stable.html" TORCH_URL_STABLE: "https://download.pytorch.org/whl/cpu/torch_stable.html" - TORCH_URL_TEST: "https://download.pytorch.org/whl/test/cpu/torch_test.html" + TORCH_URL_TEST: "https://download.pytorch.org/whl/test/cpu/torch" FREEZE_REQUIREMENTS: ${{ ! (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }} PYPI_CACHE_DIR: "_pip-wheels" # TODO: Remove this - Enable running MPS tests on this platform @@ -124,11 +127,13 @@ jobs: - name: Env. variables run: | # Switch PyTorch URL - python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.pytorch-version }}' == '2.3' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV + python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.pytorch-version }}' == '2.4' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV # Switch coverage scope python -c "print('COVERAGE_SCOPE=' + str('lightning' if '${{matrix.pkg-name}}' == 'lightning' else 'pytorch_lightning'))" >> $GITHUB_ENV # if you install mono-package set dependency only for this subpackage python -c "print('EXTRA_PREFIX=' + str('' if '${{matrix.pkg-name}}' != 'lightning' else 'pytorch-'))" >> $GITHUB_ENV + # Avoid issue on Windows with PyTorch 2.4: "RuntimeError: use_libuv was requested but PyTorch was build without libuv support" + python -c "print('USE_LIBUV=0' if '${{matrix.os}}' == 'windows-2022' and '${{matrix.pytorch-version}}' == '2.4' else '')" >> $GITHUB_ENV - name: Install package & dependencies timeout-minutes: 20 diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index 0891205421..6fa9d0d64d 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -47,6 +47,8 @@ jobs: - { python_version: "3.10", pytorch_version: "2.2", cuda_version: "12.1.0" } - { python_version: "3.11", pytorch_version: "2.1", cuda_version: "12.1.0" } - { python_version: "3.11", pytorch_version: "2.2", cuda_version: "12.1.0" } + - { python_version: "3.11", pytorch_version: "2.3", cuda_version: "12.1.0" } + - { python_version: "3.11", pytorch_version: "2.4", cuda_version: "12.1.0" } steps: - uses: actions/checkout@v4 with: @@ -74,7 +76,7 @@ jobs: tags = [f"latest-py{py_ver}-torch{pt_ver}-cuda{cuda_ver}"] if ver: tags += [f"{ver}-py{py_ver}-torch{pt_ver}-cuda{cuda_ver}"] - if py_ver == '3.10' and pt_ver == '2.1' and cuda_ver == '12.1.0': + if py_ver == '3.11' and pt_ver == '2.3' and cuda_ver == '12.1.0': tags += ["latest"] tags = [f"{repo}:{tag}" for tag in tags] @@ -108,6 +110,7 @@ jobs: - { python_version: "3.11", pytorch_version: "2.1", cuda_version: "12.1.0" } - { python_version: "3.11", pytorch_version: "2.2", cuda_version: "12.1.0" } - { python_version: "3.11", pytorch_version: "2.3", cuda_version: "12.1.0" } + - { python_version: "3.11", pytorch_version: "2.4", cuda_version: "12.1.0" } # - { python_version: "3.12", pytorch_version: "2.2", cuda_version: "12.1.0" } # todo: pending on `onnxruntime` steps: - uses: actions/checkout@v4 diff --git a/dockers/base-cuda/Dockerfile b/dockers/base-cuda/Dockerfile index b8c29d01b0..9153fd33eb 100644 --- a/dockers/base-cuda/Dockerfile +++ b/dockers/base-cuda/Dockerfile @@ -20,7 +20,7 @@ FROM nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION} ARG PYTHON_VERSION=3.10 ARG PYTORCH_VERSION=2.1 -ARG MAX_ALLOWED_NCCL=2.17.1 +ARG MAX_ALLOWED_NCCL=2.22.3 SHELL ["/bin/bash", "-c"] # https://techoverflow.net/2019/05/18/how-to-fix-configuring-tzdata-interactive-input-when-building-docker-images/ @@ -92,7 +92,8 @@ RUN \ -r requirements/pytorch/test.txt \ -r requirements/pytorch/strategies.txt \ --find-links="https://download.pytorch.org/whl/cu${CUDA_VERSION_MM//'.'/''}/torch_stable.html" \ - --find-links="https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM//'.'/''}/torch_test.html" + --find-links="https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM//'.'/''}/torch" \ + --find-links="https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM//'.'/''}/pytorch-triton" RUN \ # Show what we have diff --git a/docs/source-pytorch/versioning.rst b/docs/source-pytorch/versioning.rst index ebae1f920a..d923b01c7e 100644 --- a/docs/source-pytorch/versioning.rst +++ b/docs/source-pytorch/versioning.rst @@ -79,6 +79,18 @@ The table below indicates the coverage of tested versions in our CI. Versions ou - ``torch`` - ``torchmetrics`` - Python + * - 2.4 + - 2.4 + - 2.4 + - ≥2.1, ≤2.4 + - ≥0.7.0 + - ≥3.9, ≤3.12 + * - 2.3 + - 2.3 + - 2.3 + - ≥2.0, ≤2.3 + - ≥0.7.0 + - ≥3.8, ≤3.11 * - 2.2 - 2.2 - 2.2 diff --git a/requirements/pytorch/base.txt b/requirements/pytorch/base.txt index cd71466551..5aae9ae1cb 100644 --- a/requirements/pytorch/base.txt +++ b/requirements/pytorch/base.txt @@ -2,7 +2,7 @@ # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment numpy >=1.21.0, <1.27.0 -torch >=2.1.0, <2.4.0 +torch >=2.1.0, <2.5.0 tqdm >=4.57.0, <4.67.0 PyYAML >=5.4, <6.1.0 fsspec[http] >=2022.5.0, <2024.4.0 diff --git a/requirements/pytorch/examples.txt b/requirements/pytorch/examples.txt index e4b1bc31e9..9a6ae7e47d 100644 --- a/requirements/pytorch/examples.txt +++ b/requirements/pytorch/examples.txt @@ -2,7 +2,7 @@ # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment requests <2.32.0 -torchvision >=0.16.0, <0.19.0 +torchvision >=0.16.0, <0.20.0 ipython[all] <8.15.0 torchmetrics >=0.10.0, <1.3.0 lightning-utilities >=0.8.0, <0.12.0 diff --git a/src/lightning/fabric/utilities/cloud_io.py b/src/lightning/fabric/utilities/cloud_io.py index 49b02cd095..7ecc9eea50 100644 --- a/src/lightning/fabric/utilities/cloud_io.py +++ b/src/lightning/fabric/utilities/cloud_io.py @@ -33,6 +33,7 @@ log = logging.getLogger(__name__) def _load( path_or_url: Union[IO, _PATH], map_location: _MAP_LOCATION_TYPE = None, + weights_only: bool = False, ) -> Any: """Loads a checkpoint. @@ -46,15 +47,21 @@ def _load( return torch.load( path_or_url, map_location=map_location, # type: ignore[arg-type] # upstream annotation is not correct + weights_only=weights_only, ) if str(path_or_url).startswith("http"): return torch.hub.load_state_dict_from_url( str(path_or_url), map_location=map_location, # type: ignore[arg-type] + weights_only=weights_only, ) fs = get_filesystem(path_or_url) with fs.open(path_or_url, "rb") as f: - return torch.load(f, map_location=map_location) # type: ignore[arg-type] + return torch.load( + f, + map_location=map_location, # type: ignore[arg-type] + weights_only=weights_only, + ) def get_filesystem(path: _PATH, **kwargs: Any) -> AbstractFileSystem: diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py index c0a309f070..f086de6974 100644 --- a/src/lightning/pytorch/plugins/precision/amp.py +++ b/src/lightning/pytorch/plugins/precision/amp.py @@ -19,6 +19,7 @@ from typing_extensions import override import lightning.pytorch as pl from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 from lightning.fabric.utilities.types import Optimizable from lightning.pytorch.plugins.precision.precision import Precision from lightning.pytorch.utilities import GradClipAlgorithmType @@ -39,7 +40,7 @@ class MixedPrecision(Precision): self, precision: Literal["16-mixed", "bf16-mixed"], device: str, - scaler: Optional[torch.cuda.amp.GradScaler] = None, + scaler: Optional["torch.cuda.amp.GradScaler"] = None, ) -> None: if precision not in ("16-mixed", "bf16-mixed"): raise ValueError( @@ -49,7 +50,7 @@ class MixedPrecision(Precision): self.precision = precision if scaler is None and self.precision == "16-mixed": - scaler = torch.cuda.amp.GradScaler() + scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler() if scaler is not None and self.precision == "bf16-mixed": raise MisconfigurationException(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.") self.device = device diff --git a/src/lightning/pytorch/profilers/pytorch.py b/src/lightning/pytorch/profilers/pytorch.py index 9bdffad8de..a26b3d321d 100644 --- a/src/lightning/pytorch/profilers/pytorch.py +++ b/src/lightning/pytorch/profilers/pytorch.py @@ -28,6 +28,7 @@ from torch.utils.hooks import RemovableHandle from typing_extensions import override from lightning.fabric.accelerators.cuda import is_cuda_available +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 from lightning.pytorch.profilers.profiler import Profiler from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_warn @@ -295,7 +296,7 @@ class PyTorchProfiler(Profiler): self._emit_nvtx = emit_nvtx self._export_to_chrome = export_to_chrome self._row_limit = row_limit - self._sort_by_key = sort_by_key or f"{'cuda' if profiler_kwargs.get('use_cuda', False) else 'cpu'}_time_total" + self._sort_by_key = sort_by_key or _default_sort_by_key(profiler_kwargs) self._record_module_names = record_module_names self._profiler_kwargs = profiler_kwargs self._table_kwargs = table_kwargs if table_kwargs is not None else {} @@ -403,10 +404,16 @@ class PyTorchProfiler(Profiler): activities: List[ProfilerActivity] = [] if not _KINETO_AVAILABLE: return activities - if self._profiler_kwargs.get("use_cpu", True): + if _TORCH_GREATER_EQUAL_2_4: activities.append(ProfilerActivity.CPU) - if self._profiler_kwargs.get("use_cuda", is_cuda_available()): - activities.append(ProfilerActivity.CUDA) + if is_cuda_available(): + activities.append(ProfilerActivity.CUDA) + else: + # `use_cpu` and `use_cuda` are deprecated in PyTorch >= 2.4 + if self._profiler_kwargs.get("use_cpu", True): + activities.append(ProfilerActivity.CPU) + if self._profiler_kwargs.get("use_cuda", is_cuda_available()): + activities.append(ProfilerActivity.CUDA) return activities @override @@ -565,3 +572,13 @@ class PyTorchProfiler(Profiler): self._recording_map = {} super().teardown(stage=stage) + + +def _default_sort_by_key(profiler_kwargs: dict) -> str: + activities = profiler_kwargs.get("activities", []) + is_cuda = ( + profiler_kwargs.get("use_cuda", False) # `use_cuda` is deprecated in PyTorch >= 2.4 + or (activities and ProfilerActivity.CUDA in activities) + or (not activities and is_cuda_available()) + ) + return f"{'cuda' if is_cuda else 'cpu'}_time_total" diff --git a/src/lightning/pytorch/strategies/launchers/multiprocessing.py b/src/lightning/pytorch/strategies/launchers/multiprocessing.py index 7431927df2..05e3fed561 100644 --- a/src/lightning/pytorch/strategies/launchers/multiprocessing.py +++ b/src/lightning/pytorch/strategies/launchers/multiprocessing.py @@ -254,7 +254,7 @@ class _MultiProcessingLauncher(_Launcher): """ # NOTE: `get_extra_results` needs to be called before callback_metrics_bytes = extra["callback_metrics_bytes"] - callback_metrics = torch.load(io.BytesIO(callback_metrics_bytes)) + callback_metrics = torch.load(io.BytesIO(callback_metrics_bytes), weights_only=True) trainer.callback_metrics.update(callback_metrics) @override diff --git a/src/lightning/pytorch/strategies/model_parallel.py b/src/lightning/pytorch/strategies/model_parallel.py index a136a1406e..fb45166378 100644 --- a/src/lightning/pytorch/strategies/model_parallel.py +++ b/src/lightning/pytorch/strategies/model_parallel.py @@ -38,7 +38,7 @@ from lightning.fabric.utilities.distributed import ( _sync_ddp_if_available, ) from lightning.fabric.utilities.distributed import group as _group -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3 +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 from lightning.fabric.utilities.init import _materialize_distributed_module from lightning.fabric.utilities.load import _METADATA_FILENAME from lightning.fabric.utilities.optimizer import _optimizers_to_device @@ -64,7 +64,7 @@ class ModelParallelStrategy(ParallelStrategy): Currently supports up to 2D parallelism. Specifically, it supports the combination of Fully Sharded Data-Parallel 2 (FSDP2) with Tensor Parallelism (DTensor). These PyTorch APIs are currently still experimental in PyTorch (see https://pytorch.org/docs/stable/distributed.tensor.parallel.html). - Requires PyTorch 2.3 or newer. + Requires PyTorch 2.4 or newer. Arguments: data_parallel_size: The number of devices within a data-parallel group. Defaults to ``"auto"``, which @@ -86,8 +86,8 @@ class ModelParallelStrategy(ParallelStrategy): timeout: Optional[timedelta] = default_pg_timeout, ) -> None: super().__init__() - if not _TORCH_GREATER_EQUAL_2_3: - raise ImportError(f"{type(self).__name__} requires PyTorch 2.3 or higher.") + if not _TORCH_GREATER_EQUAL_2_4: + raise ImportError(f"{type(self).__name__} requires PyTorch 2.4 or higher.") self._data_parallel_size = data_parallel_size self._tensor_parallel_size = tensor_parallel_size self._save_distributed_checkpoint = save_distributed_checkpoint @@ -170,7 +170,7 @@ class ModelParallelStrategy(ParallelStrategy): if any(isinstance(mod, FullyShardedDataParallel) for mod in self.model.modules()): raise TypeError( "Found modules that are wrapped with `torch.distributed.fsdp.FullyShardedDataParallel`." - f" The `{self.__class__.__name__}` only supports the new FSDP2 APIs in PyTorch >= 2.3." + f" The `{self.__class__.__name__}` only supports the new FSDP2 APIs in PyTorch >= 2.4." ) _materialize_distributed_module(self.model, self.root_device) diff --git a/tests/tests_pytorch/conftest.py b/tests/tests_pytorch/conftest.py index 97c17c4e46..78e81c7c5f 100644 --- a/tests/tests_pytorch/conftest.py +++ b/tests/tests_pytorch/conftest.py @@ -156,6 +156,7 @@ def thread_police_duuu_daaa_duuu_daaa(): elif ( thread.name == "QueueFeederThread" # tensorboardX or thread.name == "QueueManagerThread" # torch.compile + or "(_read_thread)" in thread.name # torch.compile ): thread.join(timeout=20) elif isinstance(thread, TMonitor): diff --git a/tests/tests_pytorch/models/test_torchscript.py b/tests/tests_pytorch/models/test_torchscript.py index a5783f143f..993085729e 100644 --- a/tests/tests_pytorch/models/test_torchscript.py +++ b/tests/tests_pytorch/models/test_torchscript.py @@ -19,6 +19,7 @@ import pytest import torch from fsspec.implementations.local import LocalFileSystem from lightning.fabric.utilities.cloud_io import get_filesystem +from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_4 from lightning.pytorch.core.module import LightningModule from lightning.pytorch.demos.boring_classes import BoringModel @@ -26,6 +27,7 @@ from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleRNN from tests_pytorch.helpers.runif import RunIf +@pytest.mark.skipif(_IS_WINDOWS and _TORCH_GREATER_EQUAL_2_4, reason="not close on Windows + PyTorch 2.4") @pytest.mark.parametrize("modelclass", [BoringModel, ParityModuleRNN, BasicGAN]) def test_torchscript_input_output(modelclass): """Test that scripted LightningModule forward works.""" @@ -45,6 +47,7 @@ def test_torchscript_input_output(modelclass): assert torch.allclose(script_output, model_output) +@pytest.mark.skipif(_IS_WINDOWS and _TORCH_GREATER_EQUAL_2_4, reason="not close on Windows + PyTorch 2.4") @pytest.mark.parametrize("modelclass", [BoringModel, ParityModuleRNN, BasicGAN]) def test_torchscript_example_input_output_trace(modelclass): """Test that traced LightningModule forward works with example_input_array.""" diff --git a/tests/tests_pytorch/plugins/precision/test_amp_integration.py b/tests/tests_pytorch/plugins/precision/test_amp_integration.py index 10257531d5..bc9f779079 100644 --- a/tests/tests_pytorch/plugins/precision/test_amp_integration.py +++ b/tests/tests_pytorch/plugins/precision/test_amp_integration.py @@ -15,6 +15,7 @@ from unittest.mock import Mock import torch from lightning.fabric import seed_everything +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.plugins.precision import MixedPrecision @@ -28,7 +29,8 @@ class FusedOptimizerParityModel(BoringModel): self.fused = fused def configure_optimizers(self): - assert isinstance(self.trainer.precision_plugin.scaler, torch.cuda.amp.GradScaler) + scaler_cls = torch.amp.GradScaler if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler + assert isinstance(self.trainer.precision_plugin.scaler, scaler_cls) return torch.optim.Adam(self.parameters(), lr=1.0, fused=self.fused) diff --git a/tests/tests_pytorch/profilers/test_profiler.py b/tests/tests_pytorch/profilers/test_profiler.py index 4d5c9762cc..44ce7c4a3a 100644 --- a/tests/tests_pytorch/profilers/test_profiler.py +++ b/tests/tests_pytorch/profilers/test_profiler.py @@ -21,6 +21,7 @@ from unittest.mock import patch import numpy as np import pytest import torch +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 from lightning.pytorch import Callback, Trainer from lightning.pytorch.callbacks import EarlyStopping, StochasticWeightAveraging from lightning.pytorch.demos.boring_classes import BoringModel, ManualOptimBoringModel @@ -430,7 +431,8 @@ def test_pytorch_profiler_trainer(fn, step_name, boring_model_cls, tmp_path): def test_pytorch_profiler_nested(tmp_path): """Ensure that the profiler handles nested context.""" - pytorch_profiler = PyTorchProfiler(use_cuda=False, dirpath=tmp_path, filename="profiler", schedule=None) + kwargs = {} if _TORCH_GREATER_EQUAL_2_4 else {"use_cuda": False} + pytorch_profiler = PyTorchProfiler(dirpath=tmp_path, filename="profiler", schedule=None, **kwargs) with pytorch_profiler.profile("a"): a = torch.ones(42) @@ -475,13 +477,14 @@ def test_pytorch_profiler_multiple_loggers(tmp_path): def test_register_record_function(tmp_path): use_cuda = torch.cuda.is_available() + kwargs = {} if _TORCH_GREATER_EQUAL_2_4 else {"use_cuda": torch.cuda.is_available()} pytorch_profiler = PyTorchProfiler( export_to_chrome=False, - use_cuda=use_cuda, dirpath=tmp_path, filename="profiler", schedule=None, on_trace_ready=None, + **kwargs, ) class TestModel(BoringModel): diff --git a/tests/tests_pytorch/strategies/test_deepspeed.py b/tests/tests_pytorch/strategies/test_deepspeed.py index 9d7f01096b..be9428ff75 100644 --- a/tests/tests_pytorch/strategies/test_deepspeed.py +++ b/tests/tests_pytorch/strategies/test_deepspeed.py @@ -28,7 +28,7 @@ from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelChec from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset from lightning.pytorch.loggers import CSVLogger from lightning.pytorch.plugins import DeepSpeedPrecision -from lightning.pytorch.strategies.deepspeed import _DEEPSPEED_AVAILABLE, DeepSpeedStrategy +from lightning.pytorch.strategies.deepspeed import DeepSpeedStrategy from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_11 as _TM_GE_0_11 from torch import Tensor, nn @@ -38,11 +38,6 @@ from torchmetrics import Accuracy from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf -if _DEEPSPEED_AVAILABLE: - import deepspeed - from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer - from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict - class ModelParallelBoringModel(BoringModel): def __init__(self): @@ -245,6 +240,7 @@ def test_deepspeed_auto_batch_size_config_select(_, __, tmp_path, dataset_cls, v def test_deepspeed_run_configure_optimizers(tmp_path): """Test end to end that deepspeed works with defaults (without ZeRO as that requires compilation), whilst using configure_optimizers for optimizers and schedulers.""" + from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer class TestCB(Callback): def on_train_start(self, trainer, pl_module) -> None: @@ -284,6 +280,7 @@ def test_deepspeed_run_configure_optimizers(tmp_path): def test_deepspeed_config(tmp_path, deepspeed_zero_config): """Test to ensure deepspeed works correctly when passed a DeepSpeed config object including optimizers/schedulers and saves the model weights to load correctly.""" + from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer class TestCB(Callback): def on_train_start(self, trainer, pl_module) -> None: @@ -397,6 +394,8 @@ def test_deepspeed_custom_activation_checkpointing_params(): def test_deepspeed_custom_activation_checkpointing_params_forwarded(tmp_path): """Ensure if we modify the activation checkpointing parameters, we pass these to deepspeed.checkpointing.configure correctly.""" + import deepspeed + ds = DeepSpeedStrategy( partition_activations=True, cpu_checkpointing=True, @@ -453,6 +452,8 @@ def test_deepspeed_assert_config_zero_offload_disabled(tmp_path, deepspeed_zero_ @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True) def test_deepspeed_multigpu(tmp_path): """Test to ensure that DeepSpeed with multiple GPUs works and deepspeed distributed is initialized correctly.""" + import deepspeed + model = BoringModel() trainer = Trainer( default_root_dir=tmp_path, @@ -978,6 +979,8 @@ def test_deepspeed_strategy_env_variables(mock_deepspeed_distributed, tmp_path, def _assert_save_model_is_equal(model, tmp_path, trainer): + from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict + checkpoint_path = os.path.join(tmp_path, "model.pt") checkpoint_path = trainer.strategy.broadcast(checkpoint_path) trainer.save_checkpoint(checkpoint_path) diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py index 94e00f35ff..aec01b83e9 100644 --- a/tests/tests_pytorch/strategies/test_fsdp.py +++ b/tests/tests_pytorch/strategies/test_fsdp.py @@ -217,6 +217,7 @@ def test_custom_mixed_precision(): assert strategy.mixed_precision_config == config +@pytest.mark.filterwarnings("ignore::FutureWarning") @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True) def test_strategy_sync_batchnorm(tmp_path): """Test to ensure that sync_batchnorm works when using FSDP and GPU, and all stages can be run.""" @@ -233,6 +234,7 @@ def test_strategy_sync_batchnorm(tmp_path): _run_multiple_stages(trainer, model, os.path.join(tmp_path, "last.ckpt")) +@pytest.mark.filterwarnings("ignore::FutureWarning") @RunIf(min_cuda_gpus=1, skip_windows=True) def test_modules_without_parameters(tmp_path): """Test that TorchMetrics get moved to the device despite not having any parameters.""" @@ -263,6 +265,7 @@ def test_modules_without_parameters(tmp_path): trainer.fit(model) +@pytest.mark.filterwarnings("ignore::FutureWarning") @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True) @pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))]) @pytest.mark.parametrize("state_dict_type", ["sharded", "full"]) @@ -284,6 +287,7 @@ def custom_auto_wrap_policy( return nonwrapped_numel >= 2 +@pytest.mark.filterwarnings("ignore::FutureWarning") @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True) @pytest.mark.parametrize("wrap_min_params", [2, 1024, 100000000]) def test_strategy_full_state_dict(tmp_path, wrap_min_params): @@ -319,6 +323,7 @@ def test_strategy_full_state_dict(tmp_path, wrap_min_params): assert all(_ex == _co for _ex, _co in zip(full_state_dict.keys(), correct_state_dict.keys())) +@pytest.mark.filterwarnings("ignore::FutureWarning") @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True) @pytest.mark.parametrize( ("model", "strategy", "strategy_cfg"), @@ -552,6 +557,7 @@ def test_strategy_load_optimizer_states_multiple(_, tmp_path): strategy.load_checkpoint(tmp_path / "one-state.ckpt") +@pytest.mark.filterwarnings("ignore::FutureWarning") @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True) @pytest.mark.parametrize("wrap_min_params", [2, 1024, 100000000]) def test_strategy_save_optimizer_states(tmp_path, wrap_min_params): @@ -610,6 +616,7 @@ def test_strategy_save_optimizer_states(tmp_path, wrap_min_params): trainer.strategy.barrier() +@pytest.mark.filterwarnings("ignore::FutureWarning") @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True) @pytest.mark.parametrize("wrap_min_params", [2, 1024, 100000000]) def test_strategy_load_optimizer_states(wrap_min_params, tmp_path): @@ -808,6 +815,7 @@ class TestFSDPCheckpointModel(BoringModel): torch.testing.assert_close(p0, p1, atol=0, rtol=0, equal_nan=True) +@pytest.mark.filterwarnings("ignore::FutureWarning") @RunIf(min_cuda_gpus=2, standalone=True) def test_save_load_sharded_state_dict(tmp_path): """Test FSDP saving and loading with the sharded state dict format.""" @@ -917,11 +925,20 @@ def test_module_init_context(precision, expected_dtype, tmp_path): _run_setup_assertions(empty_init=True, expected_device=torch.device("meta")) +@pytest.mark.filterwarnings("ignore::FutureWarning") @RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.3.0") def test_save_sharded_and_consolidate_and_load(tmp_path): """Test the consolidation of a FSDP-sharded checkpoint into a single file.""" - model = BoringModel() + class CustomModel(BoringModel): + def configure_optimizers(self): + # Use Adam instead of SGD for this test because it has state + # In PyTorch >= 2.4, saving an optimizer with empty state would result in a `KeyError: 'state'` + # when loading the optimizer state-dict back. + # TODO: To resolve this, switch to the new `torch.distributed.checkpoint` APIs in FSDPStrategy + return torch.optim.Adam(self.parameters(), lr=0.1) + + model = CustomModel() trainer = Trainer( default_root_dir=tmp_path, accelerator="cuda", @@ -942,7 +959,7 @@ def test_save_sharded_and_consolidate_and_load(tmp_path): torch.save(checkpoint, checkpoint_path_full) trainer.strategy.barrier() - model = BoringModel() + model = CustomModel() trainer = Trainer( default_root_dir=tmp_path, accelerator="cuda", diff --git a/tests/tests_pytorch/strategies/test_model_parallel.py b/tests/tests_pytorch/strategies/test_model_parallel.py index 4f30ae8fef..731da66d4a 100644 --- a/tests/tests_pytorch/strategies/test_model_parallel.py +++ b/tests/tests_pytorch/strategies/test_model_parallel.py @@ -28,20 +28,20 @@ from lightning.pytorch.strategies import ModelParallelStrategy from tests_pytorch.helpers.runif import RunIf -@mock.patch("lightning.pytorch.strategies.model_parallel._TORCH_GREATER_EQUAL_2_3", False) -def test_torch_greater_equal_2_3(): - with pytest.raises(ImportError, match="ModelParallelStrategy requires PyTorch 2.3 or higher"): +@mock.patch("lightning.pytorch.strategies.model_parallel._TORCH_GREATER_EQUAL_2_4", False) +def test_torch_greater_equal_2_4(): + with pytest.raises(ImportError, match="ModelParallelStrategy requires PyTorch 2.4 or higher"): ModelParallelStrategy() -@RunIf(min_torch="2.3") +@RunIf(min_torch="2.4") def test_device_mesh_access(): strategy = ModelParallelStrategy() with pytest.raises(RuntimeError, match="Accessing the device mesh .* not allowed"): _ = strategy.device_mesh -@RunIf(min_torch="2.3") +@RunIf(min_torch="2.4") @pytest.mark.parametrize( ("num_nodes", "devices", "invalid_dp_size", "invalid_tp_size"), [ @@ -69,7 +69,7 @@ def test_validate_device_mesh_dimensions(num_nodes, devices, invalid_dp_size, in strategy.setup_environment() -@RunIf(min_torch="2.3") +@RunIf(min_torch="2.4") def test_fsdp_v1_modules_unsupported(): """Test that the strategy won't allow setting up a module wrapped with the legacy FSDP API.""" from torch.distributed.fsdp import FullyShardedDataParallel @@ -85,11 +85,11 @@ def test_fsdp_v1_modules_unsupported(): strategy._lightning_module = model strategy._accelerator = Mock() - with pytest.raises(TypeError, match="only supports the new FSDP2 APIs in PyTorch >= 2.3"): + with pytest.raises(TypeError, match="only supports the new FSDP2 APIs in PyTorch >= 2.4"): strategy.setup(Mock()) -@RunIf(min_torch="2.3") +@RunIf(min_torch="2.4") def test_configure_model_required(): class Model1(LightningModule): pass @@ -114,7 +114,7 @@ def test_configure_model_required(): strategy.setup(Mock()) -@RunIf(min_torch="2.3") +@RunIf(min_torch="2.4") def test_save_checkpoint_storage_options(tmp_path): """Test that the strategy does not accept storage options for saving checkpoints.""" strategy = ModelParallelStrategy() @@ -124,7 +124,7 @@ def test_save_checkpoint_storage_options(tmp_path): strategy.save_checkpoint(checkpoint=Mock(), filepath=tmp_path, storage_options=Mock()) -@RunIf(min_torch="2.3") +@RunIf(min_torch="2.4") @mock.patch("lightning.pytorch.strategies.model_parallel.ModelParallelStrategy.broadcast", lambda _, x: x) @mock.patch("lightning.fabric.plugins.io.torch_io._atomic_save") @mock.patch("lightning.pytorch.strategies.model_parallel.shutil") @@ -174,7 +174,7 @@ def test_save_checkpoint_path_exists(shutil_mock, torch_save_mock, tmp_path): assert path.is_dir() -@RunIf(min_torch="2.3") +@RunIf(min_torch="2.4") @mock.patch("lightning.fabric.strategies.model_parallel._has_dtensor_modules", return_value=True) def test_load_unknown_checkpoint_type(_, tmp_path): """Test that the strategy validates the contents at the checkpoint path.""" @@ -187,7 +187,7 @@ def test_load_unknown_checkpoint_type(_, tmp_path): strategy.load_checkpoint(checkpoint_path=path) -@RunIf(min_torch="2.3") +@RunIf(min_torch="2.4") @mock.patch("lightning.pytorch.strategies.model_parallel._setup_device_mesh") @mock.patch("torch.distributed.init_process_group") def test_set_timeout(init_process_group_mock, _): @@ -207,7 +207,7 @@ def test_set_timeout(init_process_group_mock, _): ) -@RunIf(min_torch="2.3") +@RunIf(min_torch="2.4") def test_meta_device_materialization(): """Test that the `setup()` method materializes meta-device tensors in the LightningModule.""" diff --git a/tests/tests_pytorch/strategies/test_model_parallel_integration.py b/tests/tests_pytorch/strategies/test_model_parallel_integration.py index 5aff576554..5947277b05 100644 --- a/tests/tests_pytorch/strategies/test_model_parallel_integration.py +++ b/tests/tests_pytorch/strategies/test_model_parallel_integration.py @@ -111,7 +111,7 @@ class FSDP2TensorParallelModel(TemplateModel): _parallelize_feed_forward_fsdp2_tp(self.model, device_mesh=self.device_mesh) -@RunIf(min_torch="2.3", standalone=True, min_cuda_gpus=4) +@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4) def test_setup_device_mesh(): from torch.distributed.device_mesh import DeviceMesh @@ -168,7 +168,7 @@ def test_setup_device_mesh(): trainer.fit(model) -@RunIf(min_torch="2.3", standalone=True, min_cuda_gpus=2) +@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=2) def test_tensor_parallel(): from torch.distributed._tensor import DTensor @@ -209,7 +209,7 @@ def test_tensor_parallel(): trainer.fit(model) -@RunIf(min_torch="2.3", standalone=True, min_cuda_gpus=4) +@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4) def test_fsdp2_tensor_parallel(): from torch.distributed._tensor import DTensor @@ -266,7 +266,7 @@ def test_fsdp2_tensor_parallel(): trainer.fit(model) -@RunIf(min_torch="2.3", min_cuda_gpus=2, standalone=True) +@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True) def test_modules_without_parameters(tmp_path): """Test that TorchMetrics get moved to the device despite not having any parameters.""" @@ -297,7 +297,7 @@ def test_modules_without_parameters(tmp_path): trainer.fit(model) -@RunIf(min_torch="2.3", min_cuda_gpus=2, standalone=True) +@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True) @pytest.mark.parametrize( ("precision", "expected_dtype"), [ @@ -343,7 +343,7 @@ def test_module_init_context(precision, expected_dtype, tmp_path): _run_setup_assertions(empty_init=True, expected_device=torch.device("meta")) -@RunIf(min_torch="2.3", min_cuda_gpus=2, skip_windows=True, standalone=True) +@RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True) @pytest.mark.parametrize("save_distributed_checkpoint", [True, False]) def test_strategy_state_dict(tmp_path, save_distributed_checkpoint): """Test that the strategy returns the correct state dict of the LightningModule.""" @@ -377,7 +377,7 @@ def test_strategy_state_dict(tmp_path, save_distributed_checkpoint): assert list(state_dict.keys()) == list(correct_state_dict.keys()) -@RunIf(min_torch="2.3", min_cuda_gpus=2, skip_windows=True, standalone=True) +@RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True) def test_load_full_state_checkpoint_into_regular_model(tmp_path): """Test that a full-state checkpoint saved from a distributed model can be loaded back into a regular model.""" diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py index 32f16c92f1..65c5777e28 100644 --- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py @@ -1067,7 +1067,7 @@ def test_bitsandbytes_precision_cuda_required(monkeypatch): _AcceleratorConnector(accelerator="cpu", plugins=BitsandbytesPrecision(mode="int8")) -@RunIf(min_torch="2.3") +@RunIf(min_torch="2.4") @pytest.mark.parametrize( ("precision", "raises"), [("32-true", False), ("16-true", False), ("bf16-true", False), ("16-mixed", True), ("bf16-mixed", False)], diff --git a/tests/tests_pytorch/utilities/test_compile.py b/tests/tests_pytorch/utilities/test_compile.py index 9da6c390e5..67f992421f 100644 --- a/tests/tests_pytorch/utilities/test_compile.py +++ b/tests/tests_pytorch/utilities/test_compile.py @@ -11,16 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os import sys +from contextlib import nullcontext from unittest import mock import pytest import torch -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2, _TORCH_GREATER_EQUAL_2_4 from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.utilities.compile import from_compiled, to_uncompiled -from lightning_utilities.core import module_available +from lightning_utilities.core.imports import RequirementCache from tests_pytorch.conftest import mock_cuda_count from tests_pytorch.helpers.runif import RunIf @@ -67,10 +69,20 @@ def test_trainer_compiled_model(_, tmp_path, monkeypatch, mps_count_0): assert trainer.model._compiler_ctx is None # some strategies do not support it - if module_available("deepspeed"): + if RequirementCache("deepspeed"): compiled_model = torch.compile(model) mock_cuda_count(monkeypatch, 2) - trainer = Trainer(strategy="deepspeed", accelerator="cuda", **trainer_kwargs) + + # TODO: Update deepspeed to avoid deprecation warning for `torch.cuda.amp.custom_fwd` on import + warn_context = ( + pytest.warns(FutureWarning, match="torch.cuda.amp.*is deprecated") + if _TORCH_GREATER_EQUAL_2_4 + else nullcontext() + ) + + with warn_context: + trainer = Trainer(strategy="deepspeed", accelerator="cuda", **trainer_kwargs) + with pytest.raises(RuntimeError, match="Using a compiled model is incompatible with the current strategy.*"): trainer.fit(compiled_model) @@ -122,6 +134,7 @@ def test_compile_uncompile(): sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, strict=False, reason="RuntimeError: Failed to import" ) @RunIf(dynamo=True) +@mock.patch.dict(os.environ, {}) def test_trainer_compiled_model_that_logs(tmp_path): class MyModel(BoringModel): def training_step(self, batch, batch_idx): @@ -152,6 +165,7 @@ def test_trainer_compiled_model_that_logs(tmp_path): sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, strict=False, reason="RuntimeError: Failed to import" ) @RunIf(dynamo=True) +@mock.patch.dict(os.environ, {}) def test_trainer_compiled_model_test(tmp_path): model = BoringModel() compiled_model = torch.compile(model)