Add testing for PyTorch 2.4 (Trainer) (#20010)
This commit is contained in:
parent 96b75df41a
commit bf25167bbf
@@ -55,6 +55,9 @@ jobs:
       "Lightning | latest":
         image: "pytorchlightning/pytorch_lightning:base-cuda-py3.11-torch2.3-cuda12.1.0"
         PACKAGE_NAME: "lightning"
+      "Lightning | future":
+        image: "pytorchlightning/pytorch_lightning:base-cuda-py3.11-torch2.4-cuda12.1.0"
+        PACKAGE_NAME: "lightning"
   pool: lit-rtx-3090
   variables:
     DEVICES: $( python -c 'print("$(Agent.Name)".split("_")[-1])' )
@@ -76,9 +79,12 @@ jobs:
       echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/cu${cuda_ver}/torch_stable.html"
       scope=$(python -c 'n = "$(PACKAGE_NAME)" ; print(dict(pytorch="pytorch_lightning").get(n, n))')
       echo "##vso[task.setvariable variable=COVERAGE_SOURCE]$scope"
+      python_ver=$(python -c "import sys; print(f'{sys.version_info.major}{sys.version_info.minor}')")
+      echo "##vso[task.setvariable variable=PYTHON_VERSION_MM]$python_ver"
     displayName: "set env. vars"
   - bash: |
-      echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM}/torch_test.html"
+      echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM}"
+      echo "##vso[task.setvariable variable=TORCHVISION_URL]https://download.pytorch.org/whl/test/cu124/torchvision-0.19.0%2Bcu124-cp${PYTHON_VERSION_MM}-cp${PYTHON_VERSION_MM}-linux_x86_64.whl"
     condition: endsWith(variables['Agent.JobName'], 'future')
     displayName: "set env. vars 4 future"

@@ -107,7 +113,7 @@ jobs:

   - bash: |
       extra=$(python -c "print({'lightning': 'pytorch-'}.get('$(PACKAGE_NAME)', ''))")
-      pip install -e ".[${extra}dev]" pytest-timeout -U --find-links="${TORCH_URL}"
+      pip install -e ".[${extra}dev]" pytest-timeout -U --find-links="${TORCH_URL}" --find-links="${TORCHVISION_URL}"
     displayName: "Install package & dependencies"

   - bash: pip uninstall -y lightning
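The two `python -c` one-liners above are the only non-obvious pieces: one derives the device index from the Azure agent name, the other builds the `MAJORMINOR` tag that appears in the torchvision wheel filename. A minimal standalone sketch of the same logic (the agent name is an illustrative stand-in for `$(Agent.Name)`):

    import sys

    # e.g. Azure agent "lit-rtx-3090_2" -> DEVICES="2"
    agent_name = "lit-rtx-3090_2"  # hypothetical agent name
    devices = agent_name.split("_")[-1]

    # e.g. CPython 3.11 -> PYTHON_VERSION_MM="311" (shows up as cp311 in wheel names)
    python_version_mm = f"{sys.version_info.major}{sys.version_info.minor}"

    print(devices, python_version_mm)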
@@ -142,11 +142,15 @@ subprojects:
       - "build-cuda (3.10, 2.2, 12.1.0)"
       - "build-cuda (3.11, 2.1, 12.1.0)"
       - "build-cuda (3.11, 2.2, 12.1.0)"
       - "build-cuda (3.11, 2.3, 12.1.0)"
+      - "build-cuda (3.11, 2.4, 12.1.0)"
       #- "build-NGC"
       - "build-pl (3.10, 2.1, 12.1.0)"
       - "build-pl (3.10, 2.2, 12.1.0)"
       - "build-pl (3.11, 2.1, 12.1.0)"
       - "build-pl (3.11, 2.2, 12.1.0)"
+      - "build-pl (3.11, 2.3, 12.1.0)"
+      - "build-pl (3.11, 2.4, 12.1.0)"

   # SECTION: lightning_fabric
@@ -53,6 +53,9 @@ jobs:
           - { os: "macOS-14", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.3" }
           - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.3" }
           - { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.3" }
+          - { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.4" }
+          - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.4" }
+          - { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.4" }
           # only run PyTorch latest with Python latest, use PyTorch scope to limit dependency issues
           - { os: "macOS-12", pkg-name: "pytorch", python-version: "3.11", pytorch-version: "2.1" }
           - { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.11", pytorch-version: "2.1" }
@@ -82,7 +85,7 @@ jobs:
       PACKAGE_NAME: ${{ matrix.pkg-name }}
-      TORCH_URL: "https://download.pytorch.org/whl/cpu/torch_stable.html"
+      TORCH_URL_STABLE: "https://download.pytorch.org/whl/cpu/torch_stable.html"
-      TORCH_URL_TEST: "https://download.pytorch.org/whl/test/cpu/torch_test.html"
+      TORCH_URL_TEST: "https://download.pytorch.org/whl/test/cpu/torch"
       FREEZE_REQUIREMENTS: ${{ ! (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }}
       PYPI_CACHE_DIR: "_pip-wheels"
       # TODO: Remove this - Enable running MPS tests on this platform
@@ -124,11 +127,13 @@ jobs:
       - name: Env. variables
         run: |
           # Switch PyTorch URL
-          python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.pytorch-version }}' == '2.3' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV
+          python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.pytorch-version }}' == '2.4' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV
           # Switch coverage scope
           python -c "print('COVERAGE_SCOPE=' + str('lightning' if '${{matrix.pkg-name}}' == 'lightning' else 'pytorch_lightning'))" >> $GITHUB_ENV
           # if you install mono-package set dependency only for this subpackage
           python -c "print('EXTRA_PREFIX=' + str('' if '${{matrix.pkg-name}}' != 'lightning' else 'pytorch-'))" >> $GITHUB_ENV
+          # Avoid issue on Windows with PyTorch 2.4: "RuntimeError: use_libuv was requested but PyTorch was build without libuv support"
+          python -c "print('USE_LIBUV=0' if '${{matrix.os}}' == 'windows-2022' and '${{matrix.pytorch-version}}' == '2.4' else '')" >> $GITHUB_ENV

       - name: Install package & dependencies
         timeout-minutes: 20
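The `Env. variables` step drives everything through tiny `python -c` one-liners. A self-contained sketch of the same selection logic, with literal values standing in for the workflow's `matrix.*`/`env.*` substitutions:

    # Illustrative values; in the workflow these come from `matrix.*` and `env.*`.
    pytorch_version = "2.4"
    os_name = "windows-2022"
    torch_url_stable = "https://download.pytorch.org/whl/cpu/torch_stable.html"
    torch_url_test = "https://download.pytorch.org/whl/test/cpu/torch"

    # PyTorch 2.4 is still a release candidate here, so it installs from the test channel.
    torch_url = torch_url_test if pytorch_version == "2.4" else torch_url_stable

    # Work around "use_libuv was requested but PyTorch was build without libuv support"
    # on Windows wheels of PyTorch 2.4 by disabling libuv for the TCPStore.
    use_libuv = "0" if (os_name == "windows-2022" and pytorch_version == "2.4") else None

    print(torch_url, use_libuv)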
@@ -47,6 +47,8 @@ jobs:
           - { python_version: "3.10", pytorch_version: "2.2", cuda_version: "12.1.0" }
           - { python_version: "3.11", pytorch_version: "2.1", cuda_version: "12.1.0" }
           - { python_version: "3.11", pytorch_version: "2.2", cuda_version: "12.1.0" }
+          - { python_version: "3.11", pytorch_version: "2.3", cuda_version: "12.1.0" }
+          - { python_version: "3.11", pytorch_version: "2.4", cuda_version: "12.1.0" }
     steps:
       - uses: actions/checkout@v4
         with:
@@ -74,7 +76,7 @@ jobs:
             tags = [f"latest-py{py_ver}-torch{pt_ver}-cuda{cuda_ver}"]
             if ver:
                 tags += [f"{ver}-py{py_ver}-torch{pt_ver}-cuda{cuda_ver}"]
-            if py_ver == '3.10' and pt_ver == '2.1' and cuda_ver == '12.1.0':
+            if py_ver == '3.11' and pt_ver == '2.3' and cuda_ver == '12.1.0':
                 tags += ["latest"]

             tags = [f"{repo}:{tag}" for tag in tags]
@@ -108,6 +110,7 @@ jobs:
           - { python_version: "3.11", pytorch_version: "2.1", cuda_version: "12.1.0" }
           - { python_version: "3.11", pytorch_version: "2.2", cuda_version: "12.1.0" }
           - { python_version: "3.11", pytorch_version: "2.3", cuda_version: "12.1.0" }
+          - { python_version: "3.11", pytorch_version: "2.4", cuda_version: "12.1.0" }
           # - { python_version: "3.12", pytorch_version: "2.2", cuda_version: "12.1.0" }  # todo: pending on `onnxruntime`
     steps:
       - uses: actions/checkout@v4
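The tag-assembly step is plain Python embedded in the workflow. A runnable sketch of how it expands one matrix entry into image tags (`repo` and `ver` are illustrative values):

    repo = "pytorchlightning/pytorch_lightning"  # illustrative repository name
    ver = "2.4.0"  # package version; empty for nightly builds
    py_ver, pt_ver, cuda_ver = "3.11", "2.4", "12.1.0"

    tags = [f"latest-py{py_ver}-torch{pt_ver}-cuda{cuda_ver}"]
    if ver:
        tags += [f"{ver}-py{py_ver}-torch{pt_ver}-cuda{cuda_ver}"]
    # exactly one matrix entry is also published as the plain "latest" tag
    if py_ver == "3.11" and pt_ver == "2.3" and cuda_ver == "12.1.0":
        tags += ["latest"]

    tags = [f"{repo}:{tag}" for tag in tags]
    print(tags)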
@@ -20,7 +20,7 @@ FROM nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}

 ARG PYTHON_VERSION=3.10
 ARG PYTORCH_VERSION=2.1
-ARG MAX_ALLOWED_NCCL=2.17.1
+ARG MAX_ALLOWED_NCCL=2.22.3

 SHELL ["/bin/bash", "-c"]
 # https://techoverflow.net/2019/05/18/how-to-fix-configuring-tzdata-interactive-input-when-building-docker-images/
@@ -92,7 +92,8 @@ RUN \
     -r requirements/pytorch/test.txt \
     -r requirements/pytorch/strategies.txt \
     --find-links="https://download.pytorch.org/whl/cu${CUDA_VERSION_MM//'.'/''}/torch_stable.html" \
-    --find-links="https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM//'.'/''}/torch_test.html"
+    --find-links="https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM//'.'/''}/torch" \
+    --find-links="https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM//'.'/''}/pytorch-triton"

 RUN \
     # Show what we have
@@ -79,6 +79,18 @@ The table below indicates the coverage of tested versions in our CI. Versions ou
      - ``torch``
      - ``torchmetrics``
      - Python
+   * - 2.4
+     - 2.4
+     - 2.4
+     - ≥2.1, ≤2.4
+     - ≥0.7.0
+     - ≥3.9, ≤3.12
+   * - 2.3
+     - 2.3
+     - 2.3
+     - ≥2.0, ≤2.3
+     - ≥0.7.0
+     - ≥3.8, ≤3.11
    * - 2.2
      - 2.2
      - 2.2
@@ -2,7 +2,7 @@
 # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

 numpy >=1.21.0, <1.27.0
-torch >=2.1.0, <2.4.0
+torch >=2.1.0, <2.5.0
 tqdm >=4.57.0, <4.67.0
 PyYAML >=5.4, <6.1.0
 fsspec[http] >=2022.5.0, <2024.4.0
@@ -2,7 +2,7 @@
 # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

 requests <2.32.0
-torchvision >=0.16.0, <0.19.0
+torchvision >=0.16.0, <0.20.0
 ipython[all] <8.15.0
 torchmetrics >=0.10.0, <1.3.0
 lightning-utilities >=0.8.0, <0.12.0
@@ -33,6 +33,7 @@ log = logging.getLogger(__name__)
 def _load(
     path_or_url: Union[IO, _PATH],
     map_location: _MAP_LOCATION_TYPE = None,
+    weights_only: bool = False,
 ) -> Any:
     """Loads a checkpoint.

@@ -46,15 +47,21 @@ def _load(
         return torch.load(
             path_or_url,
             map_location=map_location,  # type: ignore[arg-type] # upstream annotation is not correct
+            weights_only=weights_only,
         )
     if str(path_or_url).startswith("http"):
         return torch.hub.load_state_dict_from_url(
             str(path_or_url),
             map_location=map_location,  # type: ignore[arg-type]
+            weights_only=weights_only,
         )
     fs = get_filesystem(path_or_url)
     with fs.open(path_or_url, "rb") as f:
-        return torch.load(f, map_location=map_location)  # type: ignore[arg-type]
+        return torch.load(
+            f,
+            map_location=map_location,  # type: ignore[arg-type]
+            weights_only=weights_only,
+        )


 def get_filesystem(path: _PATH, **kwargs: Any) -> AbstractFileSystem:
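For context, `weights_only=True` makes `torch.load` use a restricted unpickler that only reconstructs tensors and other allow-listed types. A minimal sketch of the behavior this plumbing enables (the file name is illustrative):

    import torch

    # Save a plain state dict (tensors only), then load it back safely.
    state = {"weight": torch.randn(2, 2)}
    torch.save(state, "checkpoint.pt")  # illustrative path

    # weights_only=True refuses to execute arbitrary pickled code,
    # which hardens loading of untrusted checkpoints.
    restored = torch.load("checkpoint.pt", weights_only=True)
    print(restored["weight"].shape)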
@@ -19,6 +19,7 @@ from typing_extensions import override

 import lightning.pytorch as pl
 from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
 from lightning.fabric.utilities.types import Optimizable
 from lightning.pytorch.plugins.precision.precision import Precision
 from lightning.pytorch.utilities import GradClipAlgorithmType
@@ -39,7 +40,7 @@ class MixedPrecision(Precision):
         self,
         precision: Literal["16-mixed", "bf16-mixed"],
         device: str,
-        scaler: Optional[torch.cuda.amp.GradScaler] = None,
+        scaler: Optional["torch.cuda.amp.GradScaler"] = None,
     ) -> None:
         if precision not in ("16-mixed", "bf16-mixed"):
             raise ValueError(
@@ -49,7 +50,7 @@ class MixedPrecision(Precision):

         self.precision = precision
         if scaler is None and self.precision == "16-mixed":
-            scaler = torch.cuda.amp.GradScaler()
+            scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler()
         if scaler is not None and self.precision == "bf16-mixed":
             raise MisconfigurationException(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
         self.device = device
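PyTorch 2.4 deprecates `torch.cuda.amp.GradScaler` in favor of the device-agnostic `torch.amp.GradScaler`. A minimal sketch of the same version gate outside Lightning; the `packaging`-based check is an assumed equivalent of Lightning's `_TORCH_GREATER_EQUAL_2_4` flag:

    import torch
    from packaging.version import Version

    # Assumed equivalent of Lightning's import-time flag.
    _TORCH_GE_2_4 = Version(torch.__version__.split("+")[0]) >= Version("2.4.0")

    def make_grad_scaler(device: str = "cuda"):
        # The new API takes the device; the CUDA-only class is deprecated in 2.4.
        if _TORCH_GE_2_4:
            return torch.amp.GradScaler(device=device)
        return torch.cuda.amp.GradScaler()

    scaler = make_grad_scaler("cuda")
    print(type(scaler))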
@@ -28,6 +28,7 @@ from torch.utils.hooks import RemovableHandle
 from typing_extensions import override

 from lightning.fabric.accelerators.cuda import is_cuda_available
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
 from lightning.pytorch.profilers.profiler import Profiler
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_warn
@@ -295,7 +296,7 @@ class PyTorchProfiler(Profiler):
         self._emit_nvtx = emit_nvtx
         self._export_to_chrome = export_to_chrome
         self._row_limit = row_limit
-        self._sort_by_key = sort_by_key or f"{'cuda' if profiler_kwargs.get('use_cuda', False) else 'cpu'}_time_total"
+        self._sort_by_key = sort_by_key or _default_sort_by_key(profiler_kwargs)
         self._record_module_names = record_module_names
         self._profiler_kwargs = profiler_kwargs
         self._table_kwargs = table_kwargs if table_kwargs is not None else {}
@@ -403,10 +404,16 @@ class PyTorchProfiler(Profiler):
         activities: List[ProfilerActivity] = []
         if not _KINETO_AVAILABLE:
             return activities
-        if self._profiler_kwargs.get("use_cpu", True):
+        if _TORCH_GREATER_EQUAL_2_4:
             activities.append(ProfilerActivity.CPU)
-        if self._profiler_kwargs.get("use_cuda", is_cuda_available()):
-            activities.append(ProfilerActivity.CUDA)
+            if is_cuda_available():
+                activities.append(ProfilerActivity.CUDA)
+        else:
+            # `use_cpu` and `use_cuda` are deprecated in PyTorch >= 2.4
+            if self._profiler_kwargs.get("use_cpu", True):
+                activities.append(ProfilerActivity.CPU)
+            if self._profiler_kwargs.get("use_cuda", is_cuda_available()):
+                activities.append(ProfilerActivity.CUDA)
         return activities

     @override
@@ -565,3 +572,13 @@ class PyTorchProfiler(Profiler):
         self._recording_map = {}

         super().teardown(stage=stage)
+
+
+def _default_sort_by_key(profiler_kwargs: dict) -> str:
+    activities = profiler_kwargs.get("activities", [])
+    is_cuda = (
+        profiler_kwargs.get("use_cuda", False)  # `use_cuda` is deprecated in PyTorch >= 2.4
+        or (activities and ProfilerActivity.CUDA in activities)
+        or (not activities and is_cuda_available())
+    )
+    return f"{'cuda' if is_cuda else 'cpu'}_time_total"
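Since `use_cpu`/`use_cuda` are deprecated from PyTorch 2.4, callers select devices through the upstream `activities` argument instead. A minimal sketch of the equivalent usage against the raw `torch.profiler` API:

    import torch
    from torch.profiler import ProfilerActivity, profile

    # Select devices explicitly via `activities` (the 2.4-friendly spelling)
    # instead of the deprecated `use_cpu`/`use_cuda` kwargs.
    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)

    with profile(activities=activities) as prof:
        torch.ones(42) + torch.ones(42)

    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))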
@@ -254,7 +254,7 @@ class _MultiProcessingLauncher(_Launcher):
         """
         # NOTE: `get_extra_results` needs to be called before
         callback_metrics_bytes = extra["callback_metrics_bytes"]
-        callback_metrics = torch.load(io.BytesIO(callback_metrics_bytes))
+        callback_metrics = torch.load(io.BytesIO(callback_metrics_bytes), weights_only=True)
         trainer.callback_metrics.update(callback_metrics)

     @override
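A small sketch of the byte-level round trip this launcher performs between processes, with `weights_only=True` applied on the receiving side:

    import io

    import torch

    # Sending side: serialize the metrics dict to raw bytes.
    metrics = {"val_loss": torch.tensor(0.25)}
    buffer = io.BytesIO()
    torch.save(metrics, buffer)
    payload = buffer.getvalue()

    # Receiving side: deserialize with the restricted unpickler.
    restored = torch.load(io.BytesIO(payload), weights_only=True)
    print(restored["val_loss"])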
@@ -38,7 +38,7 @@ from lightning.fabric.utilities.distributed import (
     _sync_ddp_if_available,
 )
 from lightning.fabric.utilities.distributed import group as _group
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
 from lightning.fabric.utilities.init import _materialize_distributed_module
 from lightning.fabric.utilities.load import _METADATA_FILENAME
 from lightning.fabric.utilities.optimizer import _optimizers_to_device
@@ -64,7 +64,7 @@ class ModelParallelStrategy(ParallelStrategy):
     Currently supports up to 2D parallelism. Specifically, it supports the combination of
     Fully Sharded Data-Parallel 2 (FSDP2) with Tensor Parallelism (DTensor). These PyTorch APIs are currently still
     experimental in PyTorch (see https://pytorch.org/docs/stable/distributed.tensor.parallel.html).
-    Requires PyTorch 2.3 or newer.
+    Requires PyTorch 2.4 or newer.

     Arguments:
         data_parallel_size: The number of devices within a data-parallel group. Defaults to ``"auto"``, which
@@ -86,8 +86,8 @@ class ModelParallelStrategy(ParallelStrategy):
         timeout: Optional[timedelta] = default_pg_timeout,
     ) -> None:
         super().__init__()
-        if not _TORCH_GREATER_EQUAL_2_3:
-            raise ImportError(f"{type(self).__name__} requires PyTorch 2.3 or higher.")
+        if not _TORCH_GREATER_EQUAL_2_4:
+            raise ImportError(f"{type(self).__name__} requires PyTorch 2.4 or higher.")
         self._data_parallel_size = data_parallel_size
         self._tensor_parallel_size = tensor_parallel_size
         self._save_distributed_checkpoint = save_distributed_checkpoint
@@ -170,7 +170,7 @@ class ModelParallelStrategy(ParallelStrategy):
         if any(isinstance(mod, FullyShardedDataParallel) for mod in self.model.modules()):
             raise TypeError(
                 "Found modules that are wrapped with `torch.distributed.fsdp.FullyShardedDataParallel`."
-                f" The `{self.__class__.__name__}` only supports the new FSDP2 APIs in PyTorch >= 2.3."
+                f" The `{self.__class__.__name__}` only supports the new FSDP2 APIs in PyTorch >= 2.4."
             )

         _materialize_distributed_module(self.model, self.root_device)
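The `_TORCH_GREATER_EQUAL_*` flags are simple import-time version checks. A hedged sketch of an equivalent guard, using `compare_version` from lightning-utilities (the assumption being that this mirrors how the flag is defined):

    import operator

    from lightning_utilities.core.imports import compare_version

    # Assumed equivalent of Lightning's import-time flag.
    _TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0")

    class RequiresTorch24:
        """Illustrative guard mirroring ModelParallelStrategy's constructor check."""

        def __init__(self) -> None:
            if not _TORCH_GREATER_EQUAL_2_4:
                raise ImportError(f"{type(self).__name__} requires PyTorch 2.4 or higher.")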
@@ -156,6 +156,7 @@ def thread_police_duuu_daaa_duuu_daaa():
         elif (
             thread.name == "QueueFeederThread"  # tensorboardX
             or thread.name == "QueueManagerThread"  # torch.compile
+            or "(_read_thread)" in thread.name  # torch.compile
         ):
             thread.join(timeout=20)
         elif isinstance(thread, TMonitor):
@@ -19,6 +19,7 @@ import pytest
 import torch
 from fsspec.implementations.local import LocalFileSystem
 from lightning.fabric.utilities.cloud_io import get_filesystem
+from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_4
 from lightning.pytorch.core.module import LightningModule
 from lightning.pytorch.demos.boring_classes import BoringModel
@@ -26,6 +27,7 @@ from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleRNN
 from tests_pytorch.helpers.runif import RunIf


+@pytest.mark.skipif(_IS_WINDOWS and _TORCH_GREATER_EQUAL_2_4, reason="not close on Windows + PyTorch 2.4")
 @pytest.mark.parametrize("modelclass", [BoringModel, ParityModuleRNN, BasicGAN])
 def test_torchscript_input_output(modelclass):
     """Test that scripted LightningModule forward works."""
@@ -45,6 +47,7 @@ def test_torchscript_input_output(modelclass):
     assert torch.allclose(script_output, model_output)


+@pytest.mark.skipif(_IS_WINDOWS and _TORCH_GREATER_EQUAL_2_4, reason="not close on Windows + PyTorch 2.4")
 @pytest.mark.parametrize("modelclass", [BoringModel, ParityModuleRNN, BasicGAN])
 def test_torchscript_example_input_output_trace(modelclass):
     """Test that traced LightningModule forward works with example_input_array."""
@@ -15,6 +15,7 @@ from unittest.mock import Mock

 import torch
 from lightning.fabric import seed_everything
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.plugins.precision import MixedPrecision
@@ -28,7 +29,8 @@ class FusedOptimizerParityModel(BoringModel):
         self.fused = fused

     def configure_optimizers(self):
-        assert isinstance(self.trainer.precision_plugin.scaler, torch.cuda.amp.GradScaler)
+        scaler_cls = torch.amp.GradScaler if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler
+        assert isinstance(self.trainer.precision_plugin.scaler, scaler_cls)
         return torch.optim.Adam(self.parameters(), lr=1.0, fused=self.fused)
@@ -21,6 +21,7 @@ from unittest.mock import patch
 import numpy as np
 import pytest
 import torch
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
 from lightning.pytorch import Callback, Trainer
 from lightning.pytorch.callbacks import EarlyStopping, StochasticWeightAveraging
 from lightning.pytorch.demos.boring_classes import BoringModel, ManualOptimBoringModel
@@ -430,7 +431,8 @@ def test_pytorch_profiler_trainer(fn, step_name, boring_model_cls, tmp_path):

 def test_pytorch_profiler_nested(tmp_path):
     """Ensure that the profiler handles nested context."""
-    pytorch_profiler = PyTorchProfiler(use_cuda=False, dirpath=tmp_path, filename="profiler", schedule=None)
+    kwargs = {} if _TORCH_GREATER_EQUAL_2_4 else {"use_cuda": False}
+    pytorch_profiler = PyTorchProfiler(dirpath=tmp_path, filename="profiler", schedule=None, **kwargs)

     with pytorch_profiler.profile("a"):
         a = torch.ones(42)
@@ -475,13 +477,14 @@ def test_pytorch_profiler_multiple_loggers(tmp_path):

 def test_register_record_function(tmp_path):
-    use_cuda = torch.cuda.is_available()
+    kwargs = {} if _TORCH_GREATER_EQUAL_2_4 else {"use_cuda": torch.cuda.is_available()}
     pytorch_profiler = PyTorchProfiler(
         export_to_chrome=False,
-        use_cuda=use_cuda,
         dirpath=tmp_path,
         filename="profiler",
         schedule=None,
         on_trace_ready=None,
+        **kwargs,
     )

     class TestModel(BoringModel):
@@ -28,7 +28,7 @@ from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelChec
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
 from lightning.pytorch.loggers import CSVLogger
 from lightning.pytorch.plugins import DeepSpeedPrecision
-from lightning.pytorch.strategies.deepspeed import _DEEPSPEED_AVAILABLE, DeepSpeedStrategy
+from lightning.pytorch.strategies.deepspeed import DeepSpeedStrategy
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_11 as _TM_GE_0_11
 from torch import Tensor, nn
@@ -38,11 +38,6 @@ from torchmetrics import Accuracy
 from tests_pytorch.helpers.datamodules import ClassifDataModule
 from tests_pytorch.helpers.runif import RunIf

-if _DEEPSPEED_AVAILABLE:
-    import deepspeed
-    from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
-    from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
-

 class ModelParallelBoringModel(BoringModel):
     def __init__(self):
@@ -245,6 +240,7 @@ def test_deepspeed_auto_batch_size_config_select(_, __, tmp_path, dataset_cls, v
 def test_deepspeed_run_configure_optimizers(tmp_path):
     """Test end to end that deepspeed works with defaults (without ZeRO as that requires compilation), whilst using
     configure_optimizers for optimizers and schedulers."""
+    from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer

     class TestCB(Callback):
         def on_train_start(self, trainer, pl_module) -> None:
@@ -284,6 +280,7 @@ def test_deepspeed_run_configure_optimizers(tmp_path):
 def test_deepspeed_config(tmp_path, deepspeed_zero_config):
     """Test to ensure deepspeed works correctly when passed a DeepSpeed config object including optimizers/schedulers
     and saves the model weights to load correctly."""
+    from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer

     class TestCB(Callback):
         def on_train_start(self, trainer, pl_module) -> None:
@@ -397,6 +394,8 @@ def test_deepspeed_custom_activation_checkpointing_params():
 def test_deepspeed_custom_activation_checkpointing_params_forwarded(tmp_path):
     """Ensure if we modify the activation checkpointing parameters, we pass these to deepspeed.checkpointing.configure
     correctly."""
+    import deepspeed
+
     ds = DeepSpeedStrategy(
         partition_activations=True,
         cpu_checkpointing=True,
@@ -453,6 +452,8 @@ def test_deepspeed_assert_config_zero_offload_disabled(tmp_path, deepspeed_zero_
 @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
 def test_deepspeed_multigpu(tmp_path):
     """Test to ensure that DeepSpeed with multiple GPUs works and deepspeed distributed is initialized correctly."""
+    import deepspeed
+
     model = BoringModel()
     trainer = Trainer(
         default_root_dir=tmp_path,
@@ -978,6 +979,8 @@ def test_deepspeed_strategy_env_variables(mock_deepspeed_distributed, tmp_path,


 def _assert_save_model_is_equal(model, tmp_path, trainer):
+    from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
+
     checkpoint_path = os.path.join(tmp_path, "model.pt")
     checkpoint_path = trainer.strategy.broadcast(checkpoint_path)
     trainer.save_checkpoint(checkpoint_path)
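Moving the optional `deepspeed` imports from module scope into the tests that need them keeps test collection working when the package is absent (the suite gates execution with `RunIf(deepspeed=True)`). A generic sketch of the pattern, using `pytest.importorskip` as a stock alternative (test and module names are illustrative):

    import pytest

    def test_uses_optional_dependency():
        # Import inside the test so collection never fails when the
        # optional package is missing; skip the test cleanly instead.
        deepspeed = pytest.importorskip("deepspeed")
        assert deepspeed.__version__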
@@ -217,6 +217,7 @@ def test_custom_mixed_precision():
     assert strategy.mixed_precision_config == config


+@pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
 def test_strategy_sync_batchnorm(tmp_path):
     """Test to ensure that sync_batchnorm works when using FSDP and GPU, and all stages can be run."""
@@ -233,6 +234,7 @@ def test_strategy_sync_batchnorm(tmp_path):
     _run_multiple_stages(trainer, model, os.path.join(tmp_path, "last.ckpt"))


+@pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_cuda_gpus=1, skip_windows=True)
 def test_modules_without_parameters(tmp_path):
     """Test that TorchMetrics get moved to the device despite not having any parameters."""
@@ -263,6 +265,7 @@ def test_modules_without_parameters(tmp_path):
     trainer.fit(model)


+@pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
 @pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
 @pytest.mark.parametrize("state_dict_type", ["sharded", "full"])
@@ -284,6 +287,7 @@ def custom_auto_wrap_policy(
     return nonwrapped_numel >= 2


+@pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
 @pytest.mark.parametrize("wrap_min_params", [2, 1024, 100000000])
 def test_strategy_full_state_dict(tmp_path, wrap_min_params):
@@ -319,6 +323,7 @@ def test_strategy_full_state_dict(tmp_path, wrap_min_params):
     assert all(_ex == _co for _ex, _co in zip(full_state_dict.keys(), correct_state_dict.keys()))


+@pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
 @pytest.mark.parametrize(
     ("model", "strategy", "strategy_cfg"),
@@ -552,6 +557,7 @@ def test_strategy_load_optimizer_states_multiple(_, tmp_path):
     strategy.load_checkpoint(tmp_path / "one-state.ckpt")


+@pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
 @pytest.mark.parametrize("wrap_min_params", [2, 1024, 100000000])
 def test_strategy_save_optimizer_states(tmp_path, wrap_min_params):
@@ -610,6 +616,7 @@ def test_strategy_save_optimizer_states(tmp_path, wrap_min_params):
     trainer.strategy.barrier()


+@pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
 @pytest.mark.parametrize("wrap_min_params", [2, 1024, 100000000])
 def test_strategy_load_optimizer_states(wrap_min_params, tmp_path):
@@ -808,6 +815,7 @@ class TestFSDPCheckpointModel(BoringModel):
     torch.testing.assert_close(p0, p1, atol=0, rtol=0, equal_nan=True)


+@pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_cuda_gpus=2, standalone=True)
 def test_save_load_sharded_state_dict(tmp_path):
     """Test FSDP saving and loading with the sharded state dict format."""
@@ -917,11 +925,20 @@ def test_module_init_context(precision, expected_dtype, tmp_path):
     _run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))


+@pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.3.0")
 def test_save_sharded_and_consolidate_and_load(tmp_path):
     """Test the consolidation of a FSDP-sharded checkpoint into a single file."""

-    model = BoringModel()
+    class CustomModel(BoringModel):
+        def configure_optimizers(self):
+            # Use Adam instead of SGD for this test because it has state
+            # In PyTorch >= 2.4, saving an optimizer with empty state would result in a `KeyError: 'state'`
+            # when loading the optimizer state-dict back.
+            # TODO: To resolve this, switch to the new `torch.distributed.checkpoint` APIs in FSDPStrategy
+            return torch.optim.Adam(self.parameters(), lr=0.1)
+
+    model = CustomModel()
     trainer = Trainer(
         default_root_dir=tmp_path,
         accelerator="cuda",
@@ -942,7 +959,7 @@ def test_save_sharded_and_consolidate_and_load(tmp_path):
     torch.save(checkpoint, checkpoint_path_full)
     trainer.strategy.barrier()

-    model = BoringModel()
+    model = CustomModel()
     trainer = Trainer(
         default_root_dir=tmp_path,
         accelerator="cuda",
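The repeated `filterwarnings` marker silences the `FutureWarning`s that PyTorch 2.4 emits for the legacy FSDP/AMP APIs these tests still exercise. A minimal sketch of the mechanism (the deprecated call is a stand-in):

    import warnings

    import pytest

    def _legacy_api():
        # Stand-in for a call that is deprecated in PyTorch 2.4.
        warnings.warn("torch.cuda.amp.GradScaler is deprecated", FutureWarning)
        return 42

    @pytest.mark.filterwarnings("ignore::FutureWarning")
    def test_legacy_api_still_works():
        # Without the marker, `-W error::FutureWarning` runs would fail here.
        assert _legacy_api() == 42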
@@ -28,20 +28,20 @@ from lightning.pytorch.strategies import ModelParallelStrategy
 from tests_pytorch.helpers.runif import RunIf


-@mock.patch("lightning.pytorch.strategies.model_parallel._TORCH_GREATER_EQUAL_2_3", False)
-def test_torch_greater_equal_2_3():
-    with pytest.raises(ImportError, match="ModelParallelStrategy requires PyTorch 2.3 or higher"):
+@mock.patch("lightning.pytorch.strategies.model_parallel._TORCH_GREATER_EQUAL_2_4", False)
+def test_torch_greater_equal_2_4():
+    with pytest.raises(ImportError, match="ModelParallelStrategy requires PyTorch 2.4 or higher"):
         ModelParallelStrategy()


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 def test_device_mesh_access():
     strategy = ModelParallelStrategy()
     with pytest.raises(RuntimeError, match="Accessing the device mesh .* not allowed"):
         _ = strategy.device_mesh


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 @pytest.mark.parametrize(
     ("num_nodes", "devices", "invalid_dp_size", "invalid_tp_size"),
     [
@@ -69,7 +69,7 @@ def test_validate_device_mesh_dimensions(num_nodes, devices, invalid_dp_size, in
     strategy.setup_environment()


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 def test_fsdp_v1_modules_unsupported():
     """Test that the strategy won't allow setting up a module wrapped with the legacy FSDP API."""
     from torch.distributed.fsdp import FullyShardedDataParallel
@@ -85,11 +85,11 @@ def test_fsdp_v1_modules_unsupported():
     strategy._lightning_module = model
     strategy._accelerator = Mock()

-    with pytest.raises(TypeError, match="only supports the new FSDP2 APIs in PyTorch >= 2.3"):
+    with pytest.raises(TypeError, match="only supports the new FSDP2 APIs in PyTorch >= 2.4"):
         strategy.setup(Mock())


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 def test_configure_model_required():
     class Model1(LightningModule):
         pass
@@ -114,7 +114,7 @@ def test_configure_model_required():
     strategy.setup(Mock())


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 def test_save_checkpoint_storage_options(tmp_path):
     """Test that the strategy does not accept storage options for saving checkpoints."""
     strategy = ModelParallelStrategy()
@@ -124,7 +124,7 @@ def test_save_checkpoint_storage_options(tmp_path):
     strategy.save_checkpoint(checkpoint=Mock(), filepath=tmp_path, storage_options=Mock())


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 @mock.patch("lightning.pytorch.strategies.model_parallel.ModelParallelStrategy.broadcast", lambda _, x: x)
 @mock.patch("lightning.fabric.plugins.io.torch_io._atomic_save")
 @mock.patch("lightning.pytorch.strategies.model_parallel.shutil")
@@ -174,7 +174,7 @@ def test_save_checkpoint_path_exists(shutil_mock, torch_save_mock, tmp_path):
     assert path.is_dir()


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 @mock.patch("lightning.fabric.strategies.model_parallel._has_dtensor_modules", return_value=True)
 def test_load_unknown_checkpoint_type(_, tmp_path):
     """Test that the strategy validates the contents at the checkpoint path."""
@@ -187,7 +187,7 @@ def test_load_unknown_checkpoint_type(_, tmp_path):
     strategy.load_checkpoint(checkpoint_path=path)


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 @mock.patch("lightning.pytorch.strategies.model_parallel._setup_device_mesh")
 @mock.patch("torch.distributed.init_process_group")
 def test_set_timeout(init_process_group_mock, _):
@@ -207,7 +207,7 @@ def test_set_timeout(init_process_group_mock, _):
     )


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 def test_meta_device_materialization():
     """Test that the `setup()` method materializes meta-device tensors in the LightningModule."""
@@ -111,7 +111,7 @@ class FSDP2TensorParallelModel(TemplateModel):
         _parallelize_feed_forward_fsdp2_tp(self.model, device_mesh=self.device_mesh)


-@RunIf(min_torch="2.3", standalone=True, min_cuda_gpus=4)
+@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4)
 def test_setup_device_mesh():
     from torch.distributed.device_mesh import DeviceMesh
@@ -168,7 +168,7 @@ def test_setup_device_mesh():
     trainer.fit(model)


-@RunIf(min_torch="2.3", standalone=True, min_cuda_gpus=2)
+@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=2)
 def test_tensor_parallel():
     from torch.distributed._tensor import DTensor
@@ -209,7 +209,7 @@ def test_tensor_parallel():
     trainer.fit(model)


-@RunIf(min_torch="2.3", standalone=True, min_cuda_gpus=4)
+@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4)
 def test_fsdp2_tensor_parallel():
     from torch.distributed._tensor import DTensor
@@ -266,7 +266,7 @@ def test_fsdp2_tensor_parallel():
     trainer.fit(model)


-@RunIf(min_torch="2.3", min_cuda_gpus=2, standalone=True)
+@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
 def test_modules_without_parameters(tmp_path):
     """Test that TorchMetrics get moved to the device despite not having any parameters."""
@@ -297,7 +297,7 @@ def test_modules_without_parameters(tmp_path):
     trainer.fit(model)


-@RunIf(min_torch="2.3", min_cuda_gpus=2, standalone=True)
+@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
 @pytest.mark.parametrize(
     ("precision", "expected_dtype"),
     [
@@ -343,7 +343,7 @@ def test_module_init_context(precision, expected_dtype, tmp_path):
     _run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))


-@RunIf(min_torch="2.3", min_cuda_gpus=2, skip_windows=True, standalone=True)
+@RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True)
 @pytest.mark.parametrize("save_distributed_checkpoint", [True, False])
 def test_strategy_state_dict(tmp_path, save_distributed_checkpoint):
     """Test that the strategy returns the correct state dict of the LightningModule."""
@@ -377,7 +377,7 @@ def test_strategy_state_dict(tmp_path, save_distributed_checkpoint):
     assert list(state_dict.keys()) == list(correct_state_dict.keys())


-@RunIf(min_torch="2.3", min_cuda_gpus=2, skip_windows=True, standalone=True)
+@RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True)
 def test_load_full_state_checkpoint_into_regular_model(tmp_path):
     """Test that a full-state checkpoint saved from a distributed model can be loaded back into a regular model."""
@@ -1067,7 +1067,7 @@ def test_bitsandbytes_precision_cuda_required(monkeypatch):
     _AcceleratorConnector(accelerator="cpu", plugins=BitsandbytesPrecision(mode="int8"))


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 @pytest.mark.parametrize(
     ("precision", "raises"),
     [("32-true", False), ("16-true", False), ("bf16-true", False), ("16-mixed", True), ("bf16-mixed", False)],
@@ -11,16 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
 import sys
+from contextlib import nullcontext
 from unittest import mock

 import pytest
 import torch
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2, _TORCH_GREATER_EQUAL_2_4
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.utilities.compile import from_compiled, to_uncompiled
-from lightning_utilities.core import module_available
+from lightning_utilities.core.imports import RequirementCache

 from tests_pytorch.conftest import mock_cuda_count
 from tests_pytorch.helpers.runif import RunIf
@@ -67,10 +69,20 @@ def test_trainer_compiled_model(_, tmp_path, monkeypatch, mps_count_0):
     assert trainer.model._compiler_ctx is None

     # some strategies do not support it
-    if module_available("deepspeed"):
+    if RequirementCache("deepspeed"):
         compiled_model = torch.compile(model)
         mock_cuda_count(monkeypatch, 2)
-        trainer = Trainer(strategy="deepspeed", accelerator="cuda", **trainer_kwargs)
+
+        # TODO: Update deepspeed to avoid deprecation warning for `torch.cuda.amp.custom_fwd` on import
+        warn_context = (
+            pytest.warns(FutureWarning, match="torch.cuda.amp.*is deprecated")
+            if _TORCH_GREATER_EQUAL_2_4
+            else nullcontext()
+        )
+
+        with warn_context:
+            trainer = Trainer(strategy="deepspeed", accelerator="cuda", **trainer_kwargs)
+
         with pytest.raises(RuntimeError, match="Using a compiled model is incompatible with the current strategy.*"):
             trainer.fit(compiled_model)
@@ -122,6 +134,7 @@ def test_compile_uncompile():
     sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, strict=False, reason="RuntimeError: Failed to import"
 )
 @RunIf(dynamo=True)
+@mock.patch.dict(os.environ, {})
 def test_trainer_compiled_model_that_logs(tmp_path):
     class MyModel(BoringModel):
         def training_step(self, batch, batch_idx):
@@ -152,6 +165,7 @@ def test_trainer_compiled_model_that_logs(tmp_path):
     sys.platform == "win32" and _TORCH_GREATER_EQUAL_2_2, strict=False, reason="RuntimeError: Failed to import"
 )
 @RunIf(dynamo=True)
+@mock.patch.dict(os.environ, {})
 def test_trainer_compiled_model_test(tmp_path):
     model = BoringModel()
     compiled_model = torch.compile(model)
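The `warn_context` trick above asserts a warning only on the torch versions that actually emit it. A standalone sketch of the same pattern (the flag and the warning text are illustrative):

    import warnings
    from contextlib import nullcontext

    import pytest

    EMITS_WARNING = True  # stand-in for a version flag like _TORCH_GREATER_EQUAL_2_4

    def noisy_import():
        if EMITS_WARNING:
            warnings.warn("torch.cuda.amp.custom_fwd is deprecated", FutureWarning)

    def test_conditionally_expect_warning():
        # Expect the FutureWarning only where the flag says it will appear;
        # otherwise run the block with no expectation at all.
        ctx = pytest.warns(FutureWarning, match="deprecated") if EMITS_WARNING else nullcontext()
        with ctx:
            noisy_import()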