Add testing for PyTorch 2.4 (Trainer) (#20010)

This commit is contained in:
awaelchli 2024-07-11 12:52:56 +02:00 committed by GitHub
parent 96b75df41a
commit bf25167bbf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 157 additions and 58 deletions

View File

@ -55,6 +55,9 @@ jobs:
"Lightning | latest":
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.11-torch2.3-cuda12.1.0"
PACKAGE_NAME: "lightning"
"Lightning | future":
image: "pytorchlightning/pytorch_lightning:base-cuda-py3.11-torch2.4-cuda12.1.0"
PACKAGE_NAME: "lightning"
pool: lit-rtx-3090
variables:
DEVICES: $( python -c 'print("$(Agent.Name)".split("_")[-1])' )
@ -76,9 +79,12 @@ jobs:
echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/cu${cuda_ver}/torch_stable.html"
scope=$(python -c 'n = "$(PACKAGE_NAME)" ; print(dict(pytorch="pytorch_lightning").get(n, n))')
echo "##vso[task.setvariable variable=COVERAGE_SOURCE]$scope"
python_ver=$(python -c "import sys; print(f'{sys.version_info.major}{sys.version_info.minor}')")
echo "##vso[task.setvariable variable=PYTHON_VERSION_MM]$python_ver"
displayName: "set env. vars"
- bash: |
echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM}/torch_test.html"
echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM}"
echo "##vso[task.setvariable variable=TORCHVISION_URL]https://download.pytorch.org/whl/test/cu124/torchvision-0.19.0%2Bcu124-cp${PYTHON_VERSION_MM}-cp${PYTHON_VERSION_MM}-linux_x86_64.whl"
condition: endsWith(variables['Agent.JobName'], 'future')
displayName: "set env. vars 4 future"
@ -107,7 +113,7 @@ jobs:
- bash: |
extra=$(python -c "print({'lightning': 'pytorch-'}.get('$(PACKAGE_NAME)', ''))")
pip install -e ".[${extra}dev]" pytest-timeout -U --find-links="${TORCH_URL}"
pip install -e ".[${extra}dev]" pytest-timeout -U --find-links="${TORCH_URL}" --find-links="${TORCHVISION_URL}"
displayName: "Install package & dependencies"
- bash: pip uninstall -y lightning

View File

@ -142,11 +142,15 @@ subprojects:
- "build-cuda (3.10, 2.2, 12.1.0)"
- "build-cuda (3.11, 2.1, 12.1.0)"
- "build-cuda (3.11, 2.2, 12.1.0)"
- "build-cuda (3.11, 2.3, 12.1.0)"
- "build-cuda (3.11, 2.4, 12.1.0)"
#- "build-NGC"
- "build-pl (3.10, 2.1, 12.1.0)"
- "build-pl (3.10, 2.2, 12.1.0)"
- "build-pl (3.11, 2.1, 12.1.0)"
- "build-pl (3.11, 2.2, 12.1.0)"
- "build-pl (3.11, 2.3, 12.1.0)"
- "build-pl (3.11, 2.4, 12.1.0)"
# SECTION: lightning_fabric

View File

@ -53,6 +53,9 @@ jobs:
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.3" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.3" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.3" }
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.4" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.4" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.4" }
# only run PyTorch latest with Python latest, use PyTorch scope to limit dependency issues
- { os: "macOS-12", pkg-name: "pytorch", python-version: "3.11", pytorch-version: "2.1" }
- { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.11", pytorch-version: "2.1" }
@ -82,7 +85,7 @@ jobs:
PACKAGE_NAME: ${{ matrix.pkg-name }}
TORCH_URL: "https://download.pytorch.org/whl/cpu/torch_stable.html"
TORCH_URL_STABLE: "https://download.pytorch.org/whl/cpu/torch_stable.html"
TORCH_URL_TEST: "https://download.pytorch.org/whl/test/cpu/torch_test.html"
TORCH_URL_TEST: "https://download.pytorch.org/whl/test/cpu/torch"
FREEZE_REQUIREMENTS: ${{ ! (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }}
PYPI_CACHE_DIR: "_pip-wheels"
# TODO: Remove this - Enable running MPS tests on this platform
@ -124,11 +127,13 @@ jobs:
- name: Env. variables
run: |
# Switch PyTorch URL
python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.pytorch-version }}' == '2.3' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV
python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.pytorch-version }}' == '2.4' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV
# Switch coverage scope
python -c "print('COVERAGE_SCOPE=' + str('lightning' if '${{matrix.pkg-name}}' == 'lightning' else 'pytorch_lightning'))" >> $GITHUB_ENV
# if you install mono-package set dependency only for this subpackage
python -c "print('EXTRA_PREFIX=' + str('' if '${{matrix.pkg-name}}' != 'lightning' else 'pytorch-'))" >> $GITHUB_ENV
# Avoid issue on Windows with PyTorch 2.4: "RuntimeError: use_libuv was requested but PyTorch was build without libuv support"
python -c "print('USE_LIBUV=0' if '${{matrix.os}}' == 'windows-2022' and '${{matrix.pytorch-version}}' == '2.4' else '')" >> $GITHUB_ENV
- name: Install package & dependencies
timeout-minutes: 20

View File

@ -47,6 +47,8 @@ jobs:
- { python_version: "3.10", pytorch_version: "2.2", cuda_version: "12.1.0" }
- { python_version: "3.11", pytorch_version: "2.1", cuda_version: "12.1.0" }
- { python_version: "3.11", pytorch_version: "2.2", cuda_version: "12.1.0" }
- { python_version: "3.11", pytorch_version: "2.3", cuda_version: "12.1.0" }
- { python_version: "3.11", pytorch_version: "2.4", cuda_version: "12.1.0" }
steps:
- uses: actions/checkout@v4
with:
@ -74,7 +76,7 @@ jobs:
tags = [f"latest-py{py_ver}-torch{pt_ver}-cuda{cuda_ver}"]
if ver:
tags += [f"{ver}-py{py_ver}-torch{pt_ver}-cuda{cuda_ver}"]
if py_ver == '3.10' and pt_ver == '2.1' and cuda_ver == '12.1.0':
if py_ver == '3.11' and pt_ver == '2.3' and cuda_ver == '12.1.0':
tags += ["latest"]
tags = [f"{repo}:{tag}" for tag in tags]
@ -108,6 +110,7 @@ jobs:
- { python_version: "3.11", pytorch_version: "2.1", cuda_version: "12.1.0" }
- { python_version: "3.11", pytorch_version: "2.2", cuda_version: "12.1.0" }
- { python_version: "3.11", pytorch_version: "2.3", cuda_version: "12.1.0" }
- { python_version: "3.11", pytorch_version: "2.4", cuda_version: "12.1.0" }
# - { python_version: "3.12", pytorch_version: "2.2", cuda_version: "12.1.0" } # todo: pending on `onnxruntime`
steps:
- uses: actions/checkout@v4

View File

@ -20,7 +20,7 @@ FROM nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}
ARG PYTHON_VERSION=3.10
ARG PYTORCH_VERSION=2.1
ARG MAX_ALLOWED_NCCL=2.17.1
ARG MAX_ALLOWED_NCCL=2.22.3
SHELL ["/bin/bash", "-c"]
# https://techoverflow.net/2019/05/18/how-to-fix-configuring-tzdata-interactive-input-when-building-docker-images/
@ -92,7 +92,8 @@ RUN \
-r requirements/pytorch/test.txt \
-r requirements/pytorch/strategies.txt \
--find-links="https://download.pytorch.org/whl/cu${CUDA_VERSION_MM//'.'/''}/torch_stable.html" \
--find-links="https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM//'.'/''}/torch_test.html"
--find-links="https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM//'.'/''}/torch" \
--find-links="https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM//'.'/''}/pytorch-triton"
RUN \
# Show what we have

View File

@ -79,6 +79,18 @@ The table below indicates the coverage of tested versions in our CI. Versions ou
- ``torch``
- ``torchmetrics``
- Python
* - 2.4
- 2.4
- 2.4
- ≥2.1, ≤2.4
- ≥0.7.0
- ≥3.9, ≤3.12
* - 2.3
- 2.3
- 2.3
- ≥2.0, ≤2.3
- ≥0.7.0
- ≥3.8, ≤3.11
* - 2.2
- 2.2
- 2.2

View File

@ -2,7 +2,7 @@
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
numpy >=1.21.0, <1.27.0
torch >=2.1.0, <2.4.0
torch >=2.1.0, <2.5.0
tqdm >=4.57.0, <4.67.0
PyYAML >=5.4, <6.1.0
fsspec[http] >=2022.5.0, <2024.4.0

View File

@ -2,7 +2,7 @@
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
requests <2.32.0
torchvision >=0.16.0, <0.19.0
torchvision >=0.16.0, <0.20.0
ipython[all] <8.15.0
torchmetrics >=0.10.0, <1.3.0
lightning-utilities >=0.8.0, <0.12.0

View File

@ -33,6 +33,7 @@ log = logging.getLogger(__name__)
def _load(
path_or_url: Union[IO, _PATH],
map_location: _MAP_LOCATION_TYPE = None,
weights_only: bool = False,
) -> Any:
"""Loads a checkpoint.
@ -46,15 +47,21 @@ def _load(
return torch.load(
path_or_url,
map_location=map_location, # type: ignore[arg-type] # upstream annotation is not correct
weights_only=weights_only,
)
if str(path_or_url).startswith("http"):
return torch.hub.load_state_dict_from_url(
str(path_or_url),
map_location=map_location, # type: ignore[arg-type]
weights_only=weights_only,
)
fs = get_filesystem(path_or_url)
with fs.open(path_or_url, "rb") as f:
return torch.load(f, map_location=map_location) # type: ignore[arg-type]
return torch.load(
f,
map_location=map_location, # type: ignore[arg-type]
weights_only=weights_only,
)
def get_filesystem(path: _PATH, **kwargs: Any) -> AbstractFileSystem:

View File

@ -19,6 +19,7 @@ from typing_extensions import override
import lightning.pytorch as pl
from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.types import Optimizable
from lightning.pytorch.plugins.precision.precision import Precision
from lightning.pytorch.utilities import GradClipAlgorithmType
@ -39,7 +40,7 @@ class MixedPrecision(Precision):
self,
precision: Literal["16-mixed", "bf16-mixed"],
device: str,
scaler: Optional[torch.cuda.amp.GradScaler] = None,
scaler: Optional["torch.cuda.amp.GradScaler"] = None,
) -> None:
if precision not in ("16-mixed", "bf16-mixed"):
raise ValueError(
@ -49,7 +50,7 @@ class MixedPrecision(Precision):
self.precision = precision
if scaler is None and self.precision == "16-mixed":
scaler = torch.cuda.amp.GradScaler()
scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler()
if scaler is not None and self.precision == "bf16-mixed":
raise MisconfigurationException(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
self.device = device

View File

@ -28,6 +28,7 @@ from torch.utils.hooks import RemovableHandle
from typing_extensions import override
from lightning.fabric.accelerators.cuda import is_cuda_available
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.pytorch.profilers.profiler import Profiler
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_warn
@ -295,7 +296,7 @@ class PyTorchProfiler(Profiler):
self._emit_nvtx = emit_nvtx
self._export_to_chrome = export_to_chrome
self._row_limit = row_limit
self._sort_by_key = sort_by_key or f"{'cuda' if profiler_kwargs.get('use_cuda', False) else 'cpu'}_time_total"
self._sort_by_key = sort_by_key or _default_sort_by_key(profiler_kwargs)
self._record_module_names = record_module_names
self._profiler_kwargs = profiler_kwargs
self._table_kwargs = table_kwargs if table_kwargs is not None else {}
@ -403,10 +404,16 @@ class PyTorchProfiler(Profiler):
activities: List[ProfilerActivity] = []
if not _KINETO_AVAILABLE:
return activities
if self._profiler_kwargs.get("use_cpu", True):
if _TORCH_GREATER_EQUAL_2_4:
activities.append(ProfilerActivity.CPU)
if self._profiler_kwargs.get("use_cuda", is_cuda_available()):
activities.append(ProfilerActivity.CUDA)
if is_cuda_available():
activities.append(ProfilerActivity.CUDA)
else:
# `use_cpu` and `use_cuda` are deprecated in PyTorch >= 2.4
if self._profiler_kwargs.get("use_cpu", True):
activities.append(ProfilerActivity.CPU)
if self._profiler_kwargs.get("use_cuda", is_cuda_available()):
activities.append(ProfilerActivity.CUDA)
return activities
@override
@ -565,3 +572,13 @@ class PyTorchProfiler(Profiler):
self._recording_map = {}
super().teardown(stage=stage)
def _default_sort_by_key(profiler_kwargs: dict) -> str:
activities = profiler_kwargs.get("activities", [])
is_cuda = (
profiler_kwargs.get("use_cuda", False) # `use_cuda` is deprecated in PyTorch >= 2.4
or (activities and ProfilerActivity.CUDA in activities)
or (not activities and is_cuda_available())
)
return f"{'cuda' if is_cuda else 'cpu'}_time_total"

View File

@ -254,7 +254,7 @@ class _MultiProcessingLauncher(_Launcher):
"""
# NOTE: `get_extra_results` needs to be called before
callback_metrics_bytes = extra["callback_metrics_bytes"]
callback_metrics = torch.load(io.BytesIO(callback_metrics_bytes))
callback_metrics = torch.load(io.BytesIO(callback_metrics_bytes), weights_only=True)
trainer.callback_metrics.update(callback_metrics)
@override

View File

@ -38,7 +38,7 @@ from lightning.fabric.utilities.distributed import (
_sync_ddp_if_available,
)
from lightning.fabric.utilities.distributed import group as _group
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.init import _materialize_distributed_module
from lightning.fabric.utilities.load import _METADATA_FILENAME
from lightning.fabric.utilities.optimizer import _optimizers_to_device
@ -64,7 +64,7 @@ class ModelParallelStrategy(ParallelStrategy):
Currently supports up to 2D parallelism. Specifically, it supports the combination of
Fully Sharded Data-Parallel 2 (FSDP2) with Tensor Parallelism (DTensor). These PyTorch APIs are currently still
experimental in PyTorch (see https://pytorch.org/docs/stable/distributed.tensor.parallel.html).
Requires PyTorch 2.3 or newer.
Requires PyTorch 2.4 or newer.
Arguments:
data_parallel_size: The number of devices within a data-parallel group. Defaults to ``"auto"``, which
@ -86,8 +86,8 @@ class ModelParallelStrategy(ParallelStrategy):
timeout: Optional[timedelta] = default_pg_timeout,
) -> None:
super().__init__()
if not _TORCH_GREATER_EQUAL_2_3:
raise ImportError(f"{type(self).__name__} requires PyTorch 2.3 or higher.")
if not _TORCH_GREATER_EQUAL_2_4:
raise ImportError(f"{type(self).__name__} requires PyTorch 2.4 or higher.")
self._data_parallel_size = data_parallel_size
self._tensor_parallel_size = tensor_parallel_size
self._save_distributed_checkpoint = save_distributed_checkpoint
@ -170,7 +170,7 @@ class ModelParallelStrategy(ParallelStrategy):
if any(isinstance(mod, FullyShardedDataParallel) for mod in self.model.modules()):
raise TypeError(
"Found modules that are wrapped with `torch.distributed.fsdp.FullyShardedDataParallel`."
f" The `{self.__class__.__name__}` only supports the new FSDP2 APIs in PyTorch >= 2.3."
f" The `{self.__class__.__name__}` only supports the new FSDP2 APIs in PyTorch >= 2.4."
)
_materialize_distributed_module(self.model, self.root_device)

View File

@ -156,6 +156,7 @@ def thread_police_duuu_daaa_duuu_daaa():
elif (
thread.name == "QueueFeederThread" # tensorboardX
or thread.name == "QueueManagerThread" # torch.compile
or "(_read_thread)" in thread.name # torch.compile
):
thread.join(timeout=20)
elif isinstance(thread, TMonitor):

View File

@ -19,6 +19,7 @@ import pytest
import torch
from fsspec.implementations.local import LocalFileSystem
from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_4
from lightning.pytorch.core.module import LightningModule
from lightning.pytorch.demos.boring_classes import BoringModel
@ -26,6 +27,7 @@ from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleRNN
from tests_pytorch.helpers.runif import RunIf
@pytest.mark.skipif(_IS_WINDOWS and _TORCH_GREATER_EQUAL_2_4, reason="not close on Windows + PyTorch 2.4")
@pytest.mark.parametrize("modelclass", [BoringModel, ParityModuleRNN, BasicGAN])
def test_torchscript_input_output(modelclass):
"""Test that scripted LightningModule forward works."""
@ -45,6 +47,7 @@ def test_torchscript_input_output(modelclass):
assert torch.allclose(script_output, model_output)
@pytest.mark.skipif(_IS_WINDOWS and _TORCH_GREATER_EQUAL_2_4, reason="not close on Windows + PyTorch 2.4")
@pytest.mark.parametrize("modelclass", [BoringModel, ParityModuleRNN, BasicGAN])
def test_torchscript_example_input_output_trace(modelclass):
"""Test that traced LightningModule forward works with example_input_array."""

View File

@ -15,6 +15,7 @@ from unittest.mock import Mock
import torch
from lightning.fabric import seed_everything
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.plugins.precision import MixedPrecision
@ -28,7 +29,8 @@ class FusedOptimizerParityModel(BoringModel):
self.fused = fused
def configure_optimizers(self):
assert isinstance(self.trainer.precision_plugin.scaler, torch.cuda.amp.GradScaler)
scaler_cls = torch.amp.GradScaler if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler
assert isinstance(self.trainer.precision_plugin.scaler, scaler_cls)
return torch.optim.Adam(self.parameters(), lr=1.0, fused=self.fused)

View File

@ -21,6 +21,7 @@ from unittest.mock import patch
import numpy as np
import pytest
import torch
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.pytorch import Callback, Trainer
from lightning.pytorch.callbacks import EarlyStopping, StochasticWeightAveraging
from lightning.pytorch.demos.boring_classes import BoringModel, ManualOptimBoringModel
@ -430,7 +431,8 @@ def test_pytorch_profiler_trainer(fn, step_name, boring_model_cls, tmp_path):
def test_pytorch_profiler_nested(tmp_path):
"""Ensure that the profiler handles nested context."""
pytorch_profiler = PyTorchProfiler(use_cuda=False, dirpath=tmp_path, filename="profiler", schedule=None)
kwargs = {} if _TORCH_GREATER_EQUAL_2_4 else {"use_cuda": False}
pytorch_profiler = PyTorchProfiler(dirpath=tmp_path, filename="profiler", schedule=None, **kwargs)
with pytorch_profiler.profile("a"):
a = torch.ones(42)
@ -475,13 +477,14 @@ def test_pytorch_profiler_multiple_loggers(tmp_path):
def test_register_record_function(tmp_path):
use_cuda = torch.cuda.is_available()
kwargs = {} if _TORCH_GREATER_EQUAL_2_4 else {"use_cuda": torch.cuda.is_available()}
pytorch_profiler = PyTorchProfiler(
export_to_chrome=False,
use_cuda=use_cuda,
dirpath=tmp_path,
filename="profiler",
schedule=None,
on_trace_ready=None,
**kwargs,
)
class TestModel(BoringModel):

View File

@ -28,7 +28,7 @@ from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelChec
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.plugins import DeepSpeedPrecision
from lightning.pytorch.strategies.deepspeed import _DEEPSPEED_AVAILABLE, DeepSpeedStrategy
from lightning.pytorch.strategies.deepspeed import DeepSpeedStrategy
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_11 as _TM_GE_0_11
from torch import Tensor, nn
@ -38,11 +38,6 @@ from torchmetrics import Accuracy
from tests_pytorch.helpers.datamodules import ClassifDataModule
from tests_pytorch.helpers.runif import RunIf
if _DEEPSPEED_AVAILABLE:
import deepspeed
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
class ModelParallelBoringModel(BoringModel):
def __init__(self):
@ -245,6 +240,7 @@ def test_deepspeed_auto_batch_size_config_select(_, __, tmp_path, dataset_cls, v
def test_deepspeed_run_configure_optimizers(tmp_path):
"""Test end to end that deepspeed works with defaults (without ZeRO as that requires compilation), whilst using
configure_optimizers for optimizers and schedulers."""
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
class TestCB(Callback):
def on_train_start(self, trainer, pl_module) -> None:
@ -284,6 +280,7 @@ def test_deepspeed_run_configure_optimizers(tmp_path):
def test_deepspeed_config(tmp_path, deepspeed_zero_config):
"""Test to ensure deepspeed works correctly when passed a DeepSpeed config object including optimizers/schedulers
and saves the model weights to load correctly."""
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
class TestCB(Callback):
def on_train_start(self, trainer, pl_module) -> None:
@ -397,6 +394,8 @@ def test_deepspeed_custom_activation_checkpointing_params():
def test_deepspeed_custom_activation_checkpointing_params_forwarded(tmp_path):
"""Ensure if we modify the activation checkpointing parameters, we pass these to deepspeed.checkpointing.configure
correctly."""
import deepspeed
ds = DeepSpeedStrategy(
partition_activations=True,
cpu_checkpointing=True,
@ -453,6 +452,8 @@ def test_deepspeed_assert_config_zero_offload_disabled(tmp_path, deepspeed_zero_
@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
def test_deepspeed_multigpu(tmp_path):
"""Test to ensure that DeepSpeed with multiple GPUs works and deepspeed distributed is initialized correctly."""
import deepspeed
model = BoringModel()
trainer = Trainer(
default_root_dir=tmp_path,
@ -978,6 +979,8 @@ def test_deepspeed_strategy_env_variables(mock_deepspeed_distributed, tmp_path,
def _assert_save_model_is_equal(model, tmp_path, trainer):
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
checkpoint_path = os.path.join(tmp_path, "model.pt")
checkpoint_path = trainer.strategy.broadcast(checkpoint_path)
trainer.save_checkpoint(checkpoint_path)

View File

@ -217,6 +217,7 @@ def test_custom_mixed_precision():
assert strategy.mixed_precision_config == config
@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
def test_strategy_sync_batchnorm(tmp_path):
"""Test to ensure that sync_batchnorm works when using FSDP and GPU, and all stages can be run."""
@ -233,6 +234,7 @@ def test_strategy_sync_batchnorm(tmp_path):
_run_multiple_stages(trainer, model, os.path.join(tmp_path, "last.ckpt"))
@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_cuda_gpus=1, skip_windows=True)
def test_modules_without_parameters(tmp_path):
"""Test that TorchMetrics get moved to the device despite not having any parameters."""
@ -263,6 +265,7 @@ def test_modules_without_parameters(tmp_path):
trainer.fit(model)
@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
@pytest.mark.parametrize("state_dict_type", ["sharded", "full"])
@ -284,6 +287,7 @@ def custom_auto_wrap_policy(
return nonwrapped_numel >= 2
@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize("wrap_min_params", [2, 1024, 100000000])
def test_strategy_full_state_dict(tmp_path, wrap_min_params):
@ -319,6 +323,7 @@ def test_strategy_full_state_dict(tmp_path, wrap_min_params):
assert all(_ex == _co for _ex, _co in zip(full_state_dict.keys(), correct_state_dict.keys()))
@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize(
("model", "strategy", "strategy_cfg"),
@ -552,6 +557,7 @@ def test_strategy_load_optimizer_states_multiple(_, tmp_path):
strategy.load_checkpoint(tmp_path / "one-state.ckpt")
@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize("wrap_min_params", [2, 1024, 100000000])
def test_strategy_save_optimizer_states(tmp_path, wrap_min_params):
@ -610,6 +616,7 @@ def test_strategy_save_optimizer_states(tmp_path, wrap_min_params):
trainer.strategy.barrier()
@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize("wrap_min_params", [2, 1024, 100000000])
def test_strategy_load_optimizer_states(wrap_min_params, tmp_path):
@ -808,6 +815,7 @@ class TestFSDPCheckpointModel(BoringModel):
torch.testing.assert_close(p0, p1, atol=0, rtol=0, equal_nan=True)
@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_cuda_gpus=2, standalone=True)
def test_save_load_sharded_state_dict(tmp_path):
"""Test FSDP saving and loading with the sharded state dict format."""
@ -917,11 +925,20 @@ def test_module_init_context(precision, expected_dtype, tmp_path):
_run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))
@pytest.mark.filterwarnings("ignore::FutureWarning")
@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.3.0")
def test_save_sharded_and_consolidate_and_load(tmp_path):
"""Test the consolidation of a FSDP-sharded checkpoint into a single file."""
model = BoringModel()
class CustomModel(BoringModel):
def configure_optimizers(self):
# Use Adam instead of SGD for this test because it has state
# In PyTorch >= 2.4, saving an optimizer with empty state would result in a `KeyError: 'state'`
# when loading the optimizer state-dict back.
# TODO: To resolve this, switch to the new `torch.distributed.checkpoint` APIs in FSDPStrategy
return torch.optim.Adam(self.parameters(), lr=0.1)
model = CustomModel()
trainer = Trainer(
default_root_dir=tmp_path,
accelerator="cuda",
@ -942,7 +959,7 @@ def test_save_sharded_and_consolidate_and_load(tmp_path):
torch.save(checkpoint, checkpoint_path_full)
trainer.strategy.barrier()
model = BoringModel()
model = CustomModel()
trainer = Trainer(
default_root_dir=tmp_path,
accelerator="cuda",

View File

@ -28,20 +28,20 @@ from lightning.pytorch.strategies import ModelParallelStrategy
from tests_pytorch.helpers.runif import RunIf
@mock.patch("lightning.pytorch.strategies.model_parallel._TORCH_GREATER_EQUAL_2_3", False)
def test_torch_greater_equal_2_3():
with pytest.raises(ImportError, match="ModelParallelStrategy requires PyTorch 2.3 or higher"):
@mock.patch("lightning.pytorch.strategies.model_parallel._TORCH_GREATER_EQUAL_2_4", False)
def test_torch_greater_equal_2_4():
with pytest.raises(ImportError, match="ModelParallelStrategy requires PyTorch 2.4 or higher"):
ModelParallelStrategy()
@RunIf(min_torch="2.3")
@RunIf(min_torch="2.4")
def test_device_mesh_access():
strategy = ModelParallelStrategy()
with pytest.raises(RuntimeError, match="Accessing the device mesh .* not allowed"):
_ = strategy.device_mesh
@RunIf(min_torch="2.3")
@RunIf(min_torch="2.4")
@pytest.mark.parametrize(
("num_nodes", "devices", "invalid_dp_size", "invalid_tp_size"),
[
@ -69,7 +69,7 @@ def test_validate_device_mesh_dimensions(num_nodes, devices, invalid_dp_size, in
strategy.setup_environment()
@RunIf(min_torch="2.3")
@RunIf(min_torch="2.4")
def test_fsdp_v1_modules_unsupported():
"""Test that the strategy won't allow setting up a module wrapped with the legacy FSDP API."""
from torch.distributed.fsdp import FullyShardedDataParallel
@ -85,11 +85,11 @@ def test_fsdp_v1_modules_unsupported():
strategy._lightning_module = model
strategy._accelerator = Mock()
with pytest.raises(TypeError, match="only supports the new FSDP2 APIs in PyTorch >= 2.3"):
with pytest.raises(TypeError, match="only supports the new FSDP2 APIs in PyTorch >= 2.4"):
strategy.setup(Mock())
@RunIf(min_torch="2.3")
@RunIf(min_torch="2.4")
def test_configure_model_required():
class Model1(LightningModule):
pass
@ -114,7 +114,7 @@ def test_configure_model_required():
strategy.setup(Mock())
@RunIf(min_torch="2.3")
@RunIf(min_torch="2.4")
def test_save_checkpoint_storage_options(tmp_path):
"""Test that the strategy does not accept storage options for saving checkpoints."""
strategy = ModelParallelStrategy()
@ -124,7 +124,7 @@ def test_save_checkpoint_storage_options(tmp_path):
strategy.save_checkpoint(checkpoint=Mock(), filepath=tmp_path, storage_options=Mock())
@RunIf(min_torch="2.3")
@RunIf(min_torch="2.4")
@mock.patch("lightning.pytorch.strategies.model_parallel.ModelParallelStrategy.broadcast", lambda _, x: x)
@mock.patch("lightning.fabric.plugins.io.torch_io._atomic_save")
@mock.patch("lightning.pytorch.strategies.model_parallel.shutil")
@ -174,7 +174,7 @@ def test_save_checkpoint_path_exists(shutil_mock, torch_save_mock, tmp_path):
assert path.is_dir()
@RunIf(min_torch="2.3")
@RunIf(min_torch="2.4")
@mock.patch("lightning.fabric.strategies.model_parallel._has_dtensor_modules", return_value=True)
def test_load_unknown_checkpoint_type(_, tmp_path):
"""Test that the strategy validates the contents at the checkpoint path."""
@ -187,7 +187,7 @@ def test_load_unknown_checkpoint_type(_, tmp_path):
strategy.load_checkpoint(checkpoint_path=path)
@RunIf(min_torch="2.3")
@RunIf(min_torch="2.4")
@mock.patch("lightning.pytorch.strategies.model_parallel._setup_device_mesh")
@mock.patch("torch.distributed.init_process_group")
def test_set_timeout(init_process_group_mock, _):
@ -207,7 +207,7 @@ def test_set_timeout(init_process_group_mock, _):
)
@RunIf(min_torch="2.3")
@RunIf(min_torch="2.4")
def test_meta_device_materialization():
"""Test that the `setup()` method materializes meta-device tensors in the LightningModule."""

View File

@ -111,7 +111,7 @@ class FSDP2TensorParallelModel(TemplateModel):
_parallelize_feed_forward_fsdp2_tp(self.model, device_mesh=self.device_mesh)
@RunIf(min_torch="2.3", standalone=True, min_cuda_gpus=4)
@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4)
def test_setup_device_mesh():
from torch.distributed.device_mesh import DeviceMesh
@ -168,7 +168,7 @@ def test_setup_device_mesh():
trainer.fit(model)
@RunIf(min_torch="2.3", standalone=True, min_cuda_gpus=2)
@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=2)
def test_tensor_parallel():
from torch.distributed._tensor import DTensor
@ -209,7 +209,7 @@ def test_tensor_parallel():
trainer.fit(model)
@RunIf(min_torch="2.3", standalone=True, min_cuda_gpus=4)
@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4)
def test_fsdp2_tensor_parallel():
from torch.distributed._tensor import DTensor
@ -266,7 +266,7 @@ def test_fsdp2_tensor_parallel():
trainer.fit(model)
@RunIf(min_torch="2.3", min_cuda_gpus=2, standalone=True)
@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
def test_modules_without_parameters(tmp_path):
"""Test that TorchMetrics get moved to the device despite not having any parameters."""
@ -297,7 +297,7 @@ def test_modules_without_parameters(tmp_path):
trainer.fit(model)
@RunIf(min_torch="2.3", min_cuda_gpus=2, standalone=True)
@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
@pytest.mark.parametrize(
("precision", "expected_dtype"),
[
@ -343,7 +343,7 @@ def test_module_init_context(precision, expected_dtype, tmp_path):
_run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))
@RunIf(min_torch="2.3", min_cuda_gpus=2, skip_windows=True, standalone=True)
@RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.parametrize("save_distributed_checkpoint", [True, False])
def test_strategy_state_dict(tmp_path, save_distributed_checkpoint):
"""Test that the strategy returns the correct state dict of the LightningModule."""
@ -377,7 +377,7 @@ def test_strategy_state_dict(tmp_path, save_distributed_checkpoint):
assert list(state_dict.keys()) == list(correct_state_dict.keys())
@RunIf(min_torch="2.3", min_cuda_gpus=2, skip_windows=True, standalone=True)
@RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True)
def test_load_full_state_checkpoint_into_regular_model(tmp_path):
"""Test that a full-state checkpoint saved from a distributed model can be loaded back into a regular model."""

View File

@ -1067,7 +1067,7 @@ def test_bitsandbytes_precision_cuda_required(monkeypatch):
_AcceleratorConnector(accelerator="cpu", plugins=BitsandbytesPrecision(mode="int8"))
@RunIf(min_torch="2.3")
@RunIf(min_torch="2.4")
@pytest.mark.parametrize(
("precision", "raises"),
[("32-true", False), ("16-true", False), ("bf16-true", False), ("16-mixed", True), ("bf16-mixed", False)],

View File

@ -11,16 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from contextlib import nullcontext
from unittest import mock
import pytest
import torch
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2, _TORCH_GREATER_EQUAL_2_4
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.utilities.compile import from_compiled, to_uncompiled
from lightning_utilities.core import module_available
from lightning_utilities.core.imports import RequirementCache
from tests_pytorch.conftest import mock_cuda_count
from tests_pytorch.helpers.runif import RunIf
@ -67,10 +69,20 @@ def test_trainer_compiled_model(_, tmp_path, monkeypatch, mps_count_0):
assert trainer.model._compiler_ctx is None
# some strategies do not support it
if module_available("deepspeed"):
if RequirementCache("deepspeed"):
compiled_model = torch.compile(model)
mock_cuda_count(monkeypatch, 2)
trainer = Trainer(strategy="deepspeed", accelerator="cuda", **trainer_kwargs)
# TODO: Update deepspeed to avoid deprecation warning for `torch.cuda.amp.custom_fwd` on import
warn_context = (
pytest.warns(FutureWarning, match="torch.cuda.amp.*is deprecated")
if _TORCH_GREATER_EQUAL_2_4
else nullcontext()
)
with warn_context:
trainer = Trainer(strategy="deepspeed", accelerator="cuda", **trainer_kwargs)
with pytest.raises(RuntimeError, match="Using a compiled model is incompatible with the current strategy.*"):
trainer.fit(compiled_model)
@ -122,6 +134,7 @@ def test_compile_uncompile():
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, strict=False, reason="RuntimeError: Failed to import"
)
@RunIf(dynamo=True)
@mock.patch.dict(os.environ, {})
def test_trainer_compiled_model_that_logs(tmp_path):
class MyModel(BoringModel):
def training_step(self, batch, batch_idx):
@ -152,6 +165,7 @@ def test_trainer_compiled_model_that_logs(tmp_path):
sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, strict=False, reason="RuntimeError: Failed to import"
)
@RunIf(dynamo=True)
@mock.patch.dict(os.environ, {})
def test_trainer_compiled_model_test(tmp_path):
model = BoringModel()
compiled_model = torch.compile(model)