Drop PyTorch 2.0 from the test matrix (#20009)

This commit is contained in:
awaelchli 2024-07-01 00:02:00 +02:00 committed by GitHub
parent 5636fe4a9c
commit 14493c0685
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
42 changed files with 118 additions and 295 deletions

View File

@ -19,29 +19,23 @@ subprojects:
- "!*.md"
- "!**/*.md"
checks:
- "pl-cpu (macOS-13, lightning, 3.8, 2.0, oldest)"
- "pl-cpu (macOS-14, lightning, 3.10, 2.0)"
- "pl-cpu (macOS-13, lightning, 3.8, 2.1, oldest)"
- "pl-cpu (macOS-14, lightning, 3.10, 2.1)"
- "pl-cpu (macOS-14, lightning, 3.10, 2.2)"
- "pl-cpu (macOS-14, lightning, 3.10, 2.3)"
- "pl-cpu (ubuntu-20.04, lightning, 3.8, 2.0, oldest)"
- "pl-cpu (ubuntu-20.04, lightning, 3.10, 2.0)"
- "pl-cpu (ubuntu-20.04, lightning, 3.8, 2.1, oldest)"
- "pl-cpu (ubuntu-20.04, lightning, 3.10, 2.1)"
- "pl-cpu (ubuntu-20.04, lightning, 3.10, 2.2)"
- "pl-cpu (ubuntu-20.04, lightning, 3.10, 2.3)"
- "pl-cpu (windows-2022, lightning, 3.8, 2.0, oldest)"
- "pl-cpu (windows-2022, lightning, 3.10, 2.0)"
- "pl-cpu (windows-2022, lightning, 3.8, 2.1, oldest)"
- "pl-cpu (windows-2022, lightning, 3.10, 2.1)"
- "pl-cpu (windows-2022, lightning, 3.10, 2.2)"
- "pl-cpu (windows-2022, lightning, 3.10, 2.3)"
- "pl-cpu (macOS-14, pytorch, 3.8, 2.0)"
- "pl-cpu (ubuntu-20.04, pytorch, 3.8, 2.0)"
- "pl-cpu (windows-2022, pytorch, 3.8, 2.0)"
- "pl-cpu (macOS-12, pytorch, 3.11, 2.0)"
- "pl-cpu (macOS-14, pytorch, 3.8, 2.1)"
- "pl-cpu (ubuntu-20.04, pytorch, 3.8, 2.1)"
- "pl-cpu (windows-2022, pytorch, 3.8, 2.1)"
- "pl-cpu (macOS-12, pytorch, 3.11, 2.1)"
- "pl-cpu (ubuntu-22.04, pytorch, 3.11, 2.0)"
- "pl-cpu (ubuntu-22.04, pytorch, 3.11, 2.1)"
- "pl-cpu (windows-2022, pytorch, 3.11, 2.0)"
- "pl-cpu (windows-2022, pytorch, 3.11, 2.1)"
- id: "pytorch_lightning: Azure GPU"
@ -144,13 +138,11 @@ subprojects:
- "!*.md"
- "!**/*.md"
checks:
- "build-cuda (3.10, 2.0, 11.8.0)"
- "build-cuda (3.10, 2.1, 12.1.0)"
- "build-cuda (3.10, 2.2, 12.1.0)"
- "build-cuda (3.11, 2.1, 12.1.0)"
- "build-cuda (3.11, 2.2, 12.1.0)"
#- "build-NGC"
- "build-pl (3.10, 2.0, 11.8.0)"
- "build-pl (3.10, 2.1, 12.1.0)"
- "build-pl (3.10, 2.2, 12.1.0)"
- "build-pl (3.11, 2.1, 12.1.0)"
@ -171,29 +163,23 @@ subprojects:
- "!*.md"
- "!**/*.md"
checks:
- "fabric-cpu (macOS-13, lightning, 3.8, 2.0, oldest)"
- "fabric-cpu (macOS-14, lightning, 3.10, 2.0)"
- "fabric-cpu (macOS-13, lightning, 3.8, 2.1, oldest)"
- "fabric-cpu (macOS-14, lightning, 3.11, 2.1)"
- "fabric-cpu (macOS-14, lightning, 3.11, 2.2)"
- "fabric-cpu (macOS-14, lightning, 3.10, 2.3)"
- "fabric-cpu (ubuntu-20.04, lightning, 3.8, 2.0, oldest)"
- "fabric-cpu (ubuntu-20.04, lightning, 3.10, 2.0)"
- "fabric-cpu (ubuntu-20.04, lightning, 3.8, 2.1, oldest)"
- "fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.1)"
- "fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.2)"
- "fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.3)"
- "fabric-cpu (windows-2022, lightning, 3.8, 2.0, oldest)"
- "fabric-cpu (windows-2022, lightning, 3.10, 2.0)"
- "fabric-cpu (windows-2022, lightning, 3.8, 2.1, oldest)"
- "fabric-cpu (windows-2022, lightning, 3.11, 2.1)"
- "fabric-cpu (windows-2022, lightning, 3.11, 2.2)"
- "fabric-cpu (windows-2022, lightning, 3.11, 2.3)"
- "fabric-cpu (macOS-14, fabric, 3.8, 2.0)"
- "fabric-cpu (ubuntu-20.04, fabric, 3.8, 2.0)"
- "fabric-cpu (windows-2022, fabric, 3.8, 2.0)"
- "fabric-cpu (macOS-12, fabric, 3.11, 2.0)"
- "fabric-cpu (macOS-14, fabric, 3.8, 2.1)"
- "fabric-cpu (ubuntu-20.04, fabric, 3.8, 2.1)"
- "fabric-cpu (windows-2022, fabric, 3.8, 2.1)"
- "fabric-cpu (macOS-12, fabric, 3.11, 2.1)"
- "fabric-cpu (ubuntu-22.04, fabric, 3.11, 2.0)"
- "fabric-cpu (ubuntu-22.04, fabric, 3.11, 2.1)"
- "fabric-cpu (windows-2022, fabric, 3.11, 2.0)"
- "fabric-cpu (windows-2022, fabric, 3.11, 2.1)"
- id: "lightning_fabric: Azure GPU"

View File

@ -39,9 +39,6 @@ jobs:
fail-fast: false
matrix:
include:
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" }
# only run PyTorch latest
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.1" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.1" }
@ -53,32 +50,29 @@ jobs:
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
# only run PyTorch latest with Python latest, use Fabric scope to limit dependency issues
- { os: "macOS-12", pkg-name: "fabric", python-version: "3.11", pytorch-version: "2.0" }
- { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.11", pytorch-version: "2.0" }
- { os: "windows-2022", pkg-name: "fabric", python-version: "3.11", pytorch-version: "2.0" }
- { os: "macOS-12", pkg-name: "fabric", python-version: "3.11", pytorch-version: "2.1" }
- { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.11", pytorch-version: "2.1" }
- { os: "windows-2022", pkg-name: "fabric", python-version: "3.11", pytorch-version: "2.1" }
# "oldest" versions tests, only on minimum Python
- { os: "macOS-13", pkg-name: "lightning", python-version: "3.8", pytorch-version: "2.0", requires: "oldest" }
- { os: "macOS-13", pkg-name: "lightning", python-version: "3.8", pytorch-version: "2.1", requires: "oldest" }
- {
os: "ubuntu-20.04",
pkg-name: "lightning",
python-version: "3.8",
pytorch-version: "2.0",
pytorch-version: "2.1",
requires: "oldest",
}
- {
os: "windows-2022",
pkg-name: "lightning",
python-version: "3.8",
pytorch-version: "2.0",
pytorch-version: "2.1",
requires: "oldest",
}
# "fabric" installs the standalone package
- { os: "macOS-14", pkg-name: "fabric", python-version: "3.8", pytorch-version: "2.0" }
- { os: "ubuntu-20.04", pkg-name: "fabric", python-version: "3.8", pytorch-version: "2.0" }
- { os: "windows-2022", pkg-name: "fabric", python-version: "3.8", pytorch-version: "2.0" }
- { os: "macOS-14", pkg-name: "fabric", python-version: "3.8", pytorch-version: "2.1" }
- { os: "ubuntu-20.04", pkg-name: "fabric", python-version: "3.8", pytorch-version: "2.1" }
- { os: "windows-2022", pkg-name: "fabric", python-version: "3.8", pytorch-version: "2.1" }
timeout-minutes: 25 # because of building grpcio on Mac
env:
PACKAGE_NAME: ${{ matrix.pkg-name }}

View File

@ -43,9 +43,6 @@ jobs:
fail-fast: false
matrix:
include:
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" }
# only run PyTorch latest
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
@ -57,32 +54,29 @@ jobs:
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.3" }
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.3" }
# only run PyTorch latest with Python latest, use PyTorch scope to limit dependency issues
- { os: "macOS-12", pkg-name: "pytorch", python-version: "3.11", pytorch-version: "2.0" }
- { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.11", pytorch-version: "2.0" }
- { os: "windows-2022", pkg-name: "pytorch", python-version: "3.11", pytorch-version: "2.0" }
- { os: "macOS-12", pkg-name: "pytorch", python-version: "3.11", pytorch-version: "2.1" }
- { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.11", pytorch-version: "2.1" }
- { os: "windows-2022", pkg-name: "pytorch", python-version: "3.11", pytorch-version: "2.1" }
# "oldest" versions tests, only on minimum Python
- { os: "macOS-13", pkg-name: "lightning", python-version: "3.8", pytorch-version: "2.0", requires: "oldest" }
- { os: "macOS-13", pkg-name: "lightning", python-version: "3.8", pytorch-version: "2.1", requires: "oldest" }
- {
os: "ubuntu-20.04",
pkg-name: "lightning",
python-version: "3.8",
pytorch-version: "2.0",
pytorch-version: "2.1",
requires: "oldest",
}
- {
os: "windows-2022",
pkg-name: "lightning",
python-version: "3.8",
pytorch-version: "2.0",
pytorch-version: "2.1",
requires: "oldest",
}
# "pytorch" installs the standalone package
- { os: "macOS-14", pkg-name: "pytorch", python-version: "3.8", pytorch-version: "2.0" }
- { os: "ubuntu-20.04", pkg-name: "pytorch", python-version: "3.8", pytorch-version: "2.0" }
- { os: "windows-2022", pkg-name: "pytorch", python-version: "3.8", pytorch-version: "2.0" }
- { os: "macOS-14", pkg-name: "pytorch", python-version: "3.8", pytorch-version: "2.1" }
- { os: "ubuntu-20.04", pkg-name: "pytorch", python-version: "3.8", pytorch-version: "2.1" }
- { os: "windows-2022", pkg-name: "pytorch", python-version: "3.8", pytorch-version: "2.1" }
timeout-minutes: 50
env:
PACKAGE_NAME: ${{ matrix.pkg-name }}

View File

@ -43,7 +43,6 @@ jobs:
include:
# We only release one docker image per PyTorch version.
# Make sure the matrix here matches the one below.
- { python_version: "3.10", pytorch_version: "2.0", cuda_version: "11.8.0" }
- { python_version: "3.10", pytorch_version: "2.1", cuda_version: "12.1.0" }
- { python_version: "3.10", pytorch_version: "2.2", cuda_version: "12.1.0" }
- { python_version: "3.11", pytorch_version: "2.1", cuda_version: "12.1.0" }
@ -104,7 +103,6 @@ jobs:
include:
# These are the base images for PL release docker images.
# Make sure the matrix here matches the one above.
- { python_version: "3.10", pytorch_version: "2.0", cuda_version: "11.8.0" }
- { python_version: "3.10", pytorch_version: "2.1", cuda_version: "12.1.0" }
- { python_version: "3.10", pytorch_version: "2.2", cuda_version: "12.1.0" }
- { python_version: "3.11", pytorch_version: "2.1", cuda_version: "12.1.0" }

View File

@ -5,10 +5,6 @@ Speed up models by compiling them
Compiling your PyTorch model can result in significant speedups, especially on the latest generations of GPUs.
This guide shows you how to apply `torch.compile <https://pytorch.org/docs/2.2/generated/torch.compile.html>`_ correctly in your code.
.. note::
This requires PyTorch >= 2.0.
----

View File

@ -81,4 +81,4 @@ When training distributed models with :doc:`FSDP/TP <model_parallel/index>` or D
.. note::
Empty-init is experimental and the behavior may change in the future.
For distributed models on PyTorch 2.1+, it is required that all user-defined modules that manage parameters implement a ``reset_parameters()`` method (all PyTorch built-in modules have this too).
For distributed models, it is required that all user-defined modules that manage parameters implement a ``reset_parameters()`` method (all PyTorch built-in modules have this too).

View File

@ -5,10 +5,6 @@ Speed up models by compiling them
Compiling your LightningModule can result in significant speedups, especially on the latest generations of GPUs.
This guide shows you how to apply `torch.compile <https://pytorch.org/docs/2.2/generated/torch.compile.html>`_ correctly in your code.
.. note::
This requires PyTorch >= 2.0.
----

View File

@ -1,8 +1,8 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
numpy >=1.17.2, <1.27.0
torch >=2.0.0, <2.4.0
numpy >=1.21.0, <1.27.0
torch >=2.1.0, <2.4.0
fsspec[http] >=2022.5.0, <2024.4.0
packaging >=20.0, <=23.1
typing-extensions >=4.4.0, <4.10.0

View File

@ -1,6 +1,6 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
torchvision >=0.15.0, <0.19.0
torchvision >=0.16.0, <0.19.0
torchmetrics >=0.10.0, <1.3.0
lightning-utilities >=0.8.0, <0.12.0

View File

@ -1,8 +1,8 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
numpy >=1.17.2, <1.27.0
torch >=2.0.0, <2.4.0
numpy >=1.21.0, <1.27.0
torch >=2.1.0, <2.4.0
tqdm >=4.57.0, <4.67.0
PyYAML >=5.4, <6.1.0
fsspec[http] >=2022.5.0, <2024.4.0

View File

@ -2,7 +2,7 @@
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
requests <2.32.0
torchvision >=0.15.0, <0.19.0
torchvision >=0.16.0, <0.19.0
ipython[all] <8.15.0
torchmetrics >=0.10.0, <1.3.0
lightning-utilities >=0.8.0, <0.12.0

View File

@ -8,8 +8,8 @@ pytest-random-order ==1.1.0
# needed in tests
cloudpickle >=1.3, <2.3.0
scikit-learn >0.22.1, <1.4.0
onnx >=0.14.0, <1.15.0
onnxruntime >=0.15.0, <1.17.0
onnx >=1.12.0, <1.15.0
onnxruntime >=1.12.0, <1.17.0
psutil <5.9.6 # for `DeviceStatsMonitor`
pandas >1.0, <2.2.0 # needed in benchmarks
fastapi # for `ServableModuleValidator` # not setting version as re-defined in App

View File

@ -27,7 +27,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Removed
-
- Removed support for PyTorch 2.1 ([#20009](https://github.com/Lightning-AI/lightning/pull/20009))
-

View File

@ -21,7 +21,7 @@ if not _root_logger.hasHandlers():
_logger.propagate = False
# In PyTorch 2.0+, setting this variable will force `torch.cuda.is_available()` and `torch.cuda.device_count()`
# Setting this variable will force `torch.cuda.is_available()` and `torch.cuda.device_count()`
# to use an NVML-based implementation that doesn't poison forks.
# https://github.com/pytorch/pytorch/issues/83973
os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] = "1"

View File

@ -225,8 +225,7 @@ class Fabric:
_reapply_compile: If ``True`` (default), and the model was ``torch.compile``d before, the
corresponding :class:`~torch._dynamo.OptimizedModule` wrapper will be removed and reapplied with the
same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP,
FSDP etc.). Only applies on PyTorch >= 2.1. Set it to ``False`` if compiling DDP/FSDP is causing
issues.
FSDP etc.). Set it to ``False`` if compiling DDP/FSDP is causing issues.
Returns:
The tuple containing wrapped module and the optimizers, in the same order they were passed in.
@ -292,8 +291,7 @@ class Fabric:
_reapply_compile: If ``True`` (default), and the model was ``torch.compile``d before, the
corresponding :class:`~torch._dynamo.OptimizedModule` wrapper will be removed and reapplied with the
same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP,
FSDP etc.). Only applies on PyTorch >= 2.1. Set it to ``False`` if compiling DDP/FSDP is causing
issues.
FSDP etc.). Set it to ``False`` if compiling DDP/FSDP is causing issues.
Returns:
The wrapped model.

View File

@ -63,11 +63,10 @@ from lightning.fabric.utilities.distributed import (
)
from lightning.fabric.utilities.distributed import group as _group
from lightning.fabric.utilities.imports import (
_TORCH_GREATER_EQUAL_2_1,
_TORCH_GREATER_EQUAL_2_2,
_TORCH_GREATER_EQUAL_2_3,
)
from lightning.fabric.utilities.init import _EmptyInit, _has_meta_device_parameters_or_buffers
from lightning.fabric.utilities.init import _has_meta_device_parameters_or_buffers
from lightning.fabric.utilities.load import _METADATA_FILENAME, _lazy_load, _materialize_tensors, _move_state_into
from lightning.fabric.utilities.rank_zero import rank_zero_deprecation, rank_zero_only, rank_zero_warn
from lightning.fabric.utilities.seed import reset_seed
@ -325,7 +324,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
if self._fsdp_kwargs.get("use_orig_params"):
return super().setup_optimizer(optimizer)
if not _optimizer_has_flat_params(optimizer):
# We avoid this limitation in PyTorch >= 2.0 by setting `use_orig_params=True`
# We avoid this limitation by setting `use_orig_params=True`
raise ValueError(
"The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the optimizer"
" after setting up the model."
@ -340,15 +339,12 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager:
precision_init_ctx = self.precision.module_init_context()
module_sharded_ctx = self.module_sharded_context()
empty_ctx = _EmptyInit(enabled=bool(empty_init))
stack = ExitStack()
if _TORCH_GREATER_EQUAL_2_1 and empty_init:
if empty_init:
# Materialization happens in `setup`. When modules get wrapped by FSDP, the sequence of operations is:
# 1) materialize module 2) call `reset_parameters()` 3) shard the module.
# These operations are applied to each submodule 'bottom up' in the module hierarchy.
stack.enter_context(torch.device("meta"))
else:
stack.enter_context(empty_ctx)
stack.enter_context(precision_init_ctx)
stack.enter_context(module_sharded_ctx)
return stack
@ -697,18 +693,13 @@ def _activation_checkpointing_kwargs(
classes = tuple(activation_checkpointing)
else:
classes = (activation_checkpointing,)
if _TORCH_GREATER_EQUAL_2_1:
rank_zero_deprecation(
f"`FSDPStrategy(activation_checkpointing={activation_checkpointing})` is deprecated, use "
f"`FSDPStrategy(activation_checkpointing_policy={set(classes)})` instead."
)
rank_zero_deprecation(
f"`FSDPStrategy(activation_checkpointing={activation_checkpointing})` is deprecated, use "
f"`FSDPStrategy(activation_checkpointing_policy={set(classes)})` instead."
)
return {"check_fn": lambda submodule: isinstance(submodule, classes)}
if isinstance(activation_checkpointing_policy, set):
if _TORCH_GREATER_EQUAL_2_1:
return _auto_wrap_policy_kwargs(activation_checkpointing_policy, {})
return {"check_fn": lambda submodule: isinstance(submodule, tuple(activation_checkpointing_policy))}
if not _TORCH_GREATER_EQUAL_2_1:
raise ValueError("`activation_checkpointing_policy` requires torch >= 2.1.0. HINT: `pip install -U torch`")
return _auto_wrap_policy_kwargs(activation_checkpointing_policy, {})
return {"auto_wrap_policy": activation_checkpointing_policy}
@ -716,15 +707,10 @@ def _auto_wrap_policy_kwargs(policy: Optional["_POLICY"], kwargs: Dict) -> Dict:
if policy is None:
return kwargs
if isinstance(policy, set):
if _TORCH_GREATER_EQUAL_2_1:
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
policy = ModuleWrapPolicy(policy)
else:
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
policy = ModuleWrapPolicy(policy)
# this is not transformer specific despite the name
policy = partial(transformer_auto_wrap_policy, transformer_layer_cls=policy)
kwargs["auto_wrap_policy"] = policy
return kwargs
@ -829,11 +815,8 @@ def _get_full_state_dict_context(
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import FullOptimStateDictConfig
# In PyTorch < 2.1, offload to CPU in combination with `world_size=1` is not possible
offload_to_cpu = world_size > 1 or _TORCH_GREATER_EQUAL_2_1
state_dict_config = FullStateDictConfig(offload_to_cpu=offload_to_cpu, rank0_only=rank0_only)
optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=offload_to_cpu, rank0_only=rank0_only)
state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=rank0_only)
optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=rank0_only)
state_dict_type_context = FSDP.state_dict_type(
module=module,
state_dict_type=StateDictType.FULL_STATE_DICT,

View File

@ -26,13 +26,10 @@ _IS_WINDOWS = platform.system() == "Windows"
# 2. The inspection mode via `python -i`: https://stackoverflow.com/a/6879085/1162383
_IS_INTERACTIVE = hasattr(sys, "ps1") or bool(sys.flags.interactive)
_TORCH_GREATER_EQUAL_2_1 = compare_version("torch", operator.ge, "2.1.0")
_TORCH_GREATER_EQUAL_2_2 = compare_version("torch", operator.ge, "2.2.0")
_TORCH_GREATER_EQUAL_2_3 = compare_version("torch", operator.ge, "2.3.0")
_TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0", use_base_version=True)
_TORCH_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0") and not _TORCH_GREATER_EQUAL_2_1
_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)

View File

@ -20,7 +20,6 @@ from torch.optim import Optimizer
from torch.overrides import TorchFunctionMode
from typing_extensions import override
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
from lightning.fabric.utilities.rank_zero import rank_zero_warn
from lightning.fabric.utilities.types import _DEVICE
@ -61,8 +60,6 @@ class _EmptyInit(TorchFunctionMode):
def _materialize(module: Module, device: _DEVICE) -> None:
"""Materialize a module."""
if not _TORCH_GREATER_EQUAL_2_1:
raise RuntimeError("recurse=False requires torch 2.1")
module.to_empty(device=device, recurse=False)
if not hasattr(module, "reset_parameters"):
raise TypeError(

View File

@ -24,7 +24,6 @@ from lightning.fabric.accelerators import XLAAccelerator
from lightning.fabric.accelerators.cuda import num_cuda_devices
from lightning.fabric.accelerators.mps import MPSAccelerator
from lightning.fabric.strategies.deepspeed import _DEEPSPEED_AVAILABLE
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
def _runif_reasons(
@ -116,13 +115,9 @@ def _runif_reasons(
reasons.append("Deepspeed")
if dynamo:
if _TORCH_GREATER_EQUAL_2_1:
from torch._dynamo.eval_frame import is_dynamo_supported
from torch._dynamo.eval_frame import is_dynamo_supported
cond = not is_dynamo_supported()
else:
cond = sys.platform == "win32" or sys.version_info >= (3, 11)
if cond:
if not is_dynamo_supported():
reasons.append("torch.dynamo")
return reasons, kwargs

View File

@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Ty
import torch
from typing_extensions import override
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
from lightning.fabric.utilities.rank_zero import rank_zero_only, rank_zero_warn
if TYPE_CHECKING:
@ -292,8 +291,6 @@ def measure_flops(
FLOPs will be included in the result.
"""
if not _TORCH_GREATER_EQUAL_2_1:
raise ImportError("`measure_flops` requires PyTorch >= 2.1.")
from torch.utils.flop_counter import FlopCounterMode
flop_counter = FlopCounterMode(display=False)

View File

@ -27,7 +27,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Removed
-
- Removed support for PyTorch 2.1 ([#20009](https://github.com/Lightning-AI/lightning/pull/20009))
-

View File

@ -51,7 +51,6 @@ from lightning.fabric.loggers import Logger as FabricLogger
from lightning.fabric.utilities.apply_func import convert_to_tensors
from lightning.fabric.utilities.cloud_io import get_filesystem
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_1
from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
from lightning.fabric.wrappers import _FabricOptimizer
from lightning.pytorch.callbacks.callback import Callback
@ -67,7 +66,7 @@ from lightning.pytorch.utilities import GradClipAlgorithmType
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_9_1
from lightning.pytorch.utilities.model_helpers import _restricted_classmethod
from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_debug, rank_zero_warn
from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_warn
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
from lightning.pytorch.utilities.types import (
_METRIC,
@ -140,7 +139,6 @@ class LightningModule(
self._current_fx_name: Optional[str] = None
self._param_requires_grad_state: Dict[str, bool] = {}
self._metric_attributes: Optional[Dict[int, str]] = None
self._register_sharded_tensor_state_dict_hooks_if_available()
self._compiler_ctx: Optional[Dict[str, Any]] = None
# attributes only used when using fabric
@ -1390,9 +1388,7 @@ class LightningModule(
"""
if not _ONNX_AVAILABLE:
raise ModuleNotFoundError(
f"`torch>=2.0` requires `onnx` to be installed to use `{type(self).__name__}.to_onnx()`"
)
raise ModuleNotFoundError(f"`{type(self).__name__}.to_onnx()` requires `onnx` to be installed.")
mode = self.training
@ -1599,24 +1595,6 @@ class LightningModule(
state["_trainer"] = None
return state
def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:
"""Adds ShardedTensor state dict hooks if ShardedTensors are supported.
These hooks ensure that ShardedTensors are included when saving, and are loaded the LightningModule correctly.
"""
if _TORCH_GREATER_EQUAL_2_1:
# ShardedTensor is deprecated in favor of DistributedTensor
return
if _IS_WINDOWS or not torch.distributed.is_available():
rank_zero_debug("Could not register sharded tensor state dict hooks")
return
from torch.distributed._shard.sharded_tensor import pre_load_state_dict_hook, state_dict_hook
self._register_state_dict_hook(state_dict_hook)
self._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
@contextmanager
def _jit_is_scripting() -> Generator:

View File

@ -85,8 +85,7 @@ class PositionalEncoding(nn.Module):
if self.pe is None:
# 1) can't use buffer, see https://github.com/pytorch/pytorch/issues/68407
# 2) can't use parameter becauses pe gets sliced and DDP requires all params to participate in forward
# 3) can't make it a `requires_grad=False` parameter because FSDP in PyTorch < 2.1 needs all params to
# require grad
# TODO: Could make this a `nn.Parameter` with `requires_grad=False`
self.pe = self._init_pos_encoding(device=x.device)
x + self.pe[: x.size(0), :]

View File

@ -21,7 +21,6 @@ from torch import Tensor
import lightning.pytorch as pl
from lightning.fabric.utilities.distributed import _distributed_is_initialized
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_0
from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.pytorch.accelerators.xla import XLAAccelerator
from lightning.pytorch.callbacks.timer import Timer
@ -171,9 +170,6 @@ def _no_grad_context(loop_run: Callable) -> Callable:
elif isinstance(self.trainer.strategy, FSDPStrategy):
# https://github.com/pytorch/pytorch/issues/95957
context_manager = torch.no_grad
elif _TORCH_EQUAL_2_0 and self.trainer.lightning_module._compiler_ctx is not None:
# avoid: `RuntimeError: Inference tensors do not track version counter` fixed in v2.1
context_manager = torch.no_grad
elif self.inference_mode:
context_manager = torch.inference_mode
else:

View File

@ -67,11 +67,8 @@ from lightning.fabric.utilities.distributed import (
_sync_ddp_if_available,
)
from lightning.fabric.utilities.distributed import group as _group
from lightning.fabric.utilities.imports import (
_TORCH_GREATER_EQUAL_2_1,
_TORCH_GREATER_EQUAL_2_2,
)
from lightning.fabric.utilities.init import _EmptyInit, _has_meta_device_parameters_or_buffers
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from lightning.fabric.utilities.init import _has_meta_device_parameters_or_buffers
from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors
from lightning.fabric.utilities.optimizer import _optimizers_to_device
from lightning.fabric.utilities.seed import reset_seed
@ -368,8 +365,8 @@ class FSDPStrategy(ParallelStrategy):
invalid_params_error = False
try:
# In PyTorch < 2.0, or if `use_orig_params=False` the user needs to do access
# `self.trainer.model.parameters()` in configure_optimizers()
# If `use_orig_params=False` the user needs to do access `self.trainer.model.parameters()` in
# `configure_optimizers()`
super().setup_optimizers(trainer)
except ValueError as ex:
if "optimizer got an empty parameter list" not in str(ex):
@ -377,7 +374,7 @@ class FSDPStrategy(ParallelStrategy):
invalid_params_error = True
if invalid_params_error or any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers):
# We avoid this limitation in PyTorch >= 2.0 by setting `use_orig_params=True`
# We avoid this limitation by setting `use_orig_params=True`
raise ValueError(
"The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the"
" optimizer after setting up the model by referencing `self.trainer.model.parameters()` in the"
@ -393,14 +390,10 @@ class FSDPStrategy(ParallelStrategy):
@contextmanager
@override
def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[None, None, None]:
empty_init_context: Union[torch.device, _EmptyInit, nullcontext]
if _TORCH_GREATER_EQUAL_2_1 and empty_init:
# Materialization happens in `setup`. When modules get wrapped by FSDP, the sequence of operations is:
# 1) materialize module 2) call `reset_parameters()` 3) shard the module.
# These operations are applied to each submodule 'bottom up' in the module hierarchy.
empty_init_context = torch.device("meta")
else:
empty_init_context = _EmptyInit(enabled=bool(empty_init))
# Materialization happens in `setup`. When modules get wrapped by FSDP, the sequence of operations is:
# 1) materialize module 2) call `reset_parameters()` 3) shard the module.
# These operations are applied to each submodule 'bottom up' in the module hierarchy.
empty_init_context = torch.device("meta") if empty_init else nullcontext()
with empty_init_context, self.precision_plugin.tensor_init_context():
yield

View File

@ -24,7 +24,6 @@ from typing_extensions import TypedDict, override
from lightning.fabric.utilities import move_data_to_device
from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars
from lightning.fabric.utilities.distributed import _distributed_is_initialized
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_0
from lightning.pytorch.utilities.data import extract_batch_size
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_1_0_0
@ -112,7 +111,7 @@ class _Metadata:
on_step: bool = False
on_epoch: bool = True
# https://github.com/pytorch/pytorch/issues/96197
reduce_fx: Callable = "mean" if _TORCH_EQUAL_2_0 else torch.mean # type: ignore[assignment]
reduce_fx: Callable = torch.mean
enable_graph: bool = False
add_dataloader_idx: bool = True
dataloader_idx: Optional[int] = None
@ -362,7 +361,7 @@ class _ResultCollection(dict):
on_step: bool = False,
on_epoch: bool = True,
# https://github.com/pytorch/pytorch/issues/96197
reduce_fx: Callable = "mean" if _TORCH_EQUAL_2_0 else torch.mean, # type: ignore[assignment]
reduce_fx: Callable = torch.mean,
enable_graph: bool = False,
sync_dist: bool = False,
sync_dist_fn: Callable = _Sync.no_op,

View File

@ -17,7 +17,6 @@ import torch
from torch._dynamo import OptimizedModule
import lightning.pytorch as pl
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
from lightning.pytorch.strategies import DDPStrategy, DeepSpeedStrategy, FSDPStrategy, SingleDeviceStrategy, Strategy
from lightning.pytorch.utilities.model_helpers import _check_mixed_imports
@ -56,11 +55,7 @@ def from_compiled(model: OptimizedModule) -> "pl.LightningModule":
}
orig_module.forward = model.dynamo_ctx(orig_module.forward) # type: ignore[method-assign]
if not _TORCH_GREATER_EQUAL_2_1: # https://github.com/pytorch/pytorch/issues/95630
orig_module.forward._torchdynamo_inline = orig_module.forward
orig_module.training_step = model.dynamo_ctx(orig_module.training_step) # type: ignore[method-assign]
if not _TORCH_GREATER_EQUAL_2_1: # https://github.com/pytorch/pytorch/issues/95630
orig_module.training_step._torchdynamo_inline = orig_module.training_step
orig_module.validation_step = model.dynamo_ctx(orig_module.validation_step) # type: ignore[method-assign]
orig_module.test_step = model.dynamo_ctx(orig_module.test_step) # type: ignore[method-assign]
orig_module.predict_step = model.dynamo_ctx(orig_module.predict_step) # type: ignore[method-assign]

View File

@ -101,7 +101,10 @@ def thread_police_duuu_daaa_duuu_daaa():
assert not thread.is_alive()
elif isinstance(thread, _ChildProcessObserver):
thread.join(timeout=10)
elif thread.name == "QueueFeederThread": # tensorboardX
elif (
thread.name == "QueueFeederThread" # tensorboardX
or thread.name == "QueueManagerThread" # torch.compile
):
thread.join(timeout=20)
elif (
sys.version_info >= (3, 9)

View File

@ -148,7 +148,7 @@ def test_bitsandbytes_layers(args, expected):
assert model.l.weight.dtype == expected
@RunIf(min_cuda_gpus=1, min_torch="2.1")
@RunIf(min_cuda_gpus=1)
@pytest.mark.skipif(not _BITSANDBYTES_AVAILABLE, reason="bitsandbytes unavailable")
@pytest.mark.parametrize(
("args", "expected"),
@ -232,7 +232,7 @@ def test_bitsandbytes_layers_meta_device(args, expected, tmp_path):
assert model.l.weight.dtype == expected
@RunIf(min_cuda_gpus=1, min_torch="2.1")
@RunIf(min_cuda_gpus=1)
@pytest.mark.skipif(not _BITSANDBYTES_AVAILABLE, reason="bitsandbytes unavailable")
def test_load_quantized_checkpoint(tmp_path):
"""Test that a checkpoint saved from a quantized model can be loaded back into a quantized model."""

View File

@ -75,7 +75,7 @@ def _run_ddp_save_load(fabric, tmp_path):
assert_params_equal(params_before, wrapped_model.parameters())
@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.1.0", dynamo=True)
@RunIf(min_cuda_gpus=2, standalone=True, dynamo=True)
@mock.patch("lightning.fabric.wrappers.torch.compile", Mock(wraps=torch.compile))
@mock.patch.dict(os.environ, {})
def test_reapply_compile():

View File

@ -16,7 +16,6 @@ from re import escape
from unittest import mock
from unittest.mock import ANY, MagicMock, Mock
import lightning.fabric
import pytest
import torch
import torch.nn as nn
@ -28,8 +27,9 @@ from lightning.fabric.strategies.fsdp import (
_get_full_state_dict_context,
_is_sharded_checkpoint,
)
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1, _TORCH_GREATER_EQUAL_2_2
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.optim import Adam
@ -147,13 +147,6 @@ def test_no_backward_sync():
module.no_sync.assert_called_once()
def test_activation_checkpointing_support(monkeypatch):
"""Test that we error out if activation checkpointing requires a newer PyTorch version."""
monkeypatch.setattr(lightning.fabric.strategies.fsdp, "_TORCH_GREATER_EQUAL_2_1", False)
with pytest.raises(ValueError, match="activation_checkpointing_policy` requires torch >= 2.1.0"):
FSDPStrategy(activation_checkpointing_policy=Mock())
def test_activation_checkpointing():
"""Test that the FSDP strategy can apply activation checkpointing to the given layers."""
@ -170,28 +163,13 @@ def test_activation_checkpointing():
self.layer1 = Block2(2, 2)
self.layer2 = nn.Linear(3, 3)
if _TORCH_GREATER_EQUAL_2_1:
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
strategy = FSDPStrategy(activation_checkpointing_policy={Block1})
assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)
strategy = FSDPStrategy(activation_checkpointing_policy={Block1})
assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)
strategy = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({Block1, Block2}))
assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)
else:
strategy = FSDPStrategy(activation_checkpointing=Block1)
assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}
strategy = FSDPStrategy(activation_checkpointing=[Block1, Block2])
assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}
strategy = FSDPStrategy(activation_checkpointing_policy={Block1})
assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}
strategy = FSDPStrategy(activation_checkpointing_policy={Block1, Block2})
assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}
strategy = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({Block1, Block2}))
assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)
strategy._parallel_devices = [torch.device("cuda", 0)]
with mock.patch("torch.distributed.fsdp.FullyShardedDataParallel", new=MagicMock), mock.patch(
@ -401,15 +379,13 @@ def test_set_timeout(init_process_group_mock):
)
@pytest.mark.parametrize("torch_ge_2_1", [True, False])
@mock.patch("torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel.set_state_dict_type")
def test_get_full_state_dict_context_offload(set_type_mock, monkeypatch, torch_ge_2_1):
"""Test that the state dict context manager handles CPU offloading depending on the PyTorch version."""
monkeypatch.setattr("lightning.fabric.strategies.fsdp._TORCH_GREATER_EQUAL_2_1", torch_ge_2_1)
def test_get_full_state_dict_context_offload(set_type_mock, monkeypatch):
"""Test that the state dict context manager handles CPU offloading."""
with _get_full_state_dict_context(module=Mock(spec=FullyShardedDataParallel), world_size=1):
assert set_type_mock.call_args_list[0][0][2].offload_to_cpu is torch_ge_2_1 # model config
assert set_type_mock.call_args_list[0][0][3].offload_to_cpu is torch_ge_2_1 # optim config
assert set_type_mock.call_args_list[0][0][2].offload_to_cpu # model config
assert set_type_mock.call_args_list[0][0][3].offload_to_cpu # optim config
set_type_mock.reset_mock()

View File

@ -23,7 +23,6 @@ import torch.nn as nn
from lightning.fabric import Fabric
from lightning.fabric.plugins import FSDPPrecision
from lightning.fabric.strategies import FSDPStrategy
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
from lightning.fabric.utilities.load import _load_distributed_checkpoint
from lightning.fabric.wrappers import _FabricOptimizer
from torch._dynamo import OptimizedModule
@ -400,7 +399,7 @@ def test_setup_with_orig_params_and_multiple_param_groups():
assert not isinstance(layer.weight, FlatParameter)
@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.1.0", dynamo=True, skip_windows=True)
@RunIf(min_cuda_gpus=2, standalone=True, dynamo=True, skip_windows=True)
@mock.patch("lightning.fabric.wrappers.torch.compile", Mock(wraps=torch.compile))
@mock.patch.dict(os.environ, {})
def test_reapply_compile():
@ -466,12 +465,8 @@ def test_module_init_context(precision, expected_dtype):
# Case 1: No empty init
_run_setup_assertions(empty_init=False, expected_device=torch.device("cpu"))
if _TORCH_GREATER_EQUAL_2_1:
# Case 2: Empty-init with PyTorch >= 2.1 supports meta device
_run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))
else:
# Case 2: Empty-init with PyTorch < 2.1 only supports `torch.empty()`-init
_run_setup_assertions(empty_init=True, expected_device=torch.device("cpu"))
# Case 2: Empty-init with meta device
_run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))
@RunIf(min_cuda_gpus=2, standalone=True)
@ -538,9 +533,6 @@ def test_rewrap_warnings():
assert not isinstance(model._forward_module, FullyShardedDataParallel)
assert isinstance(model._forward_module[2], FullyShardedDataParallel)
if not _TORCH_GREATER_EQUAL_2_1:
return
with fabric.init_module(empty_init=True):
model = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.ReLU(), wrap(torch.nn.Linear(1, 1)))
assert model[0].weight.is_meta

View File

@ -289,7 +289,7 @@ def test_setup_optimizers_not_supported(strategy_cls):
fabric.setup_optimizers(optimizer)
@RunIf(min_cuda_gpus=1, min_torch="2.1")
@RunIf(min_cuda_gpus=1)
def test_setup_optimizer_on_meta_device():
"""Test that the setup-methods validate that the optimizer doesn't have references to meta-device parameters."""
fabric = Fabric(strategy="fsdp", devices=1)
@ -867,8 +867,6 @@ def test_init_module_context(monkeypatch):
def test_init_tensor_context(monkeypatch):
"""Test that `.init_tensor()` warns if using PyTorch < 2.0."""
fabric = Fabric(accelerator="cpu")
strategy = SingleDeviceStrategy(device=torch.device("cuda"))
strategy.tensor_init_context = Mock(wraps=strategy.tensor_init_context)

View File

@ -19,7 +19,6 @@ import torch
from lightning.fabric.fabric import Fabric
from lightning.fabric.plugins import Precision
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
from lightning.fabric.wrappers import (
_FabricDataLoader,
_FabricModule,
@ -268,14 +267,13 @@ def test_fabric_module_state_dict_access():
assert torch.equal(fabric_module.layer.weight, weight)
assert torch.equal(fabric_module.layer.bias, bias)
if _TORCH_GREATER_EQUAL_2_1:
# Can use additional `assign` argument in PyTorch >= 2.1
with torch.device("meta"):
original_module = OriginalModule()
fabric_module = _FabricModule(wrapped_module, Mock(), original_module=original_module)
assert fabric_module.layer.weight.is_meta
fabric_module.load_state_dict({"layer.weight": weight, "layer.bias": bias}, assign=True)
assert not fabric_module.layer.weight.is_meta
# Can use additional `assign` argument
with torch.device("meta"):
original_module = OriginalModule()
fabric_module = _FabricModule(wrapped_module, Mock(), original_module=original_module)
assert fabric_module.layer.weight.is_meta
fabric_module.load_state_dict({"layer.weight": weight, "layer.bias": bias}, assign=True)
assert not fabric_module.layer.weight.is_meta
@pytest.mark.parametrize(

View File

@ -58,7 +58,6 @@ def test_empty_init_speed():
assert normal_init_time > 2 * empty_init_time
@RunIf(min_torch="2.1")
def test_materialize_meta_tensors():
class Submodule(torch.nn.Module):
def __init__(self):

View File

@ -13,11 +13,9 @@ from lightning.fabric.utilities.throughput import (
measure_flops,
)
from tests_fabric.helpers.runif import RunIf
from tests_fabric.test_fabric import BoringModel
@RunIf(min_torch="2.1")
def test_measure_flops():
with torch.device("meta"):
model = BoringModel()

View File

@ -8,10 +8,7 @@ from lightning.pytorch import Trainer
from lightning.pytorch.callbacks.throughput_monitor import ThroughputMonitor
from lightning.pytorch.demos.boring_classes import BoringModel
from tests_pytorch.helpers.runif import RunIf
@RunIf(min_torch="2.1")
def test_measure_flops():
with torch.device("meta"):
model = BoringModel()

View File

@ -153,7 +153,10 @@ def thread_police_duuu_daaa_duuu_daaa():
assert not thread.is_alive()
elif isinstance(thread, _ChildProcessObserver):
thread.join(timeout=10)
elif thread.name == "QueueFeederThread": # tensorboardX
elif (
thread.name == "QueueFeederThread" # tensorboardX
or thread.name == "QueueManagerThread" # torch.compile
):
thread.join(timeout=20)
elif isinstance(thread, TMonitor):
thread.exit()

View File

@ -35,14 +35,10 @@ def test_ddp_is_distributed():
_ = strategy.is_distributed
def test_fsdp_activation_checkpointing(monkeypatch):
def test_fsdp_activation_checkpointing():
with pytest.raises(ValueError, match="cannot set both `activation_checkpointing"):
FSDPStrategy(activation_checkpointing=torch.nn.Linear, activation_checkpointing_policy=lambda *_: True)
monkeypatch.setattr(lightning.fabric.strategies.fsdp, "_TORCH_GREATER_EQUAL_2_1", True)
with pytest.deprecated_call(match=r"use `FSDPStrategy\(activation_checkpointing_policy"):
FSDPStrategy(activation_checkpointing=torch.nn.Linear)
def test_double_precision_wrapper():
with pytest.deprecated_call(match=r"The `LightningDoublePrecisionModule` is deprecated and no longer needed"):

View File

@ -14,7 +14,7 @@ import torch
import torch.nn as nn
from lightning.fabric.plugins.environments import LightningEnvironment
from lightning.fabric.strategies.fsdp import _is_sharded_checkpoint
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1, _TORCH_GREATER_EQUAL_2_2
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
from lightning.fabric.utilities.load import _load_distributed_checkpoint
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
@ -334,10 +334,9 @@ def test_strategy_full_state_dict(tmp_path, wrap_min_params):
TestFSDPModelAutoWrapped(),
FSDPStrategy,
{
"auto_wrap_policy": ModuleWrapPolicy({nn.Linear}) if _TORCH_GREATER_EQUAL_2_1 else None,
"auto_wrap_policy": ModuleWrapPolicy({nn.Linear}),
"use_orig_params": True,
},
marks=RunIf(min_torch="2.1.0"),
id="autowrap_use_orig_params",
),
],
@ -380,19 +379,12 @@ def test_invalid_parameters_in_optimizer(use_orig_params):
fast_dev_run=1,
)
error_context = (
nullcontext()
if _TORCH_GREATER_EQUAL_2_1 or use_orig_params is not False
else pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters")
)
class EmptyParametersModel(BoringModel):
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=1e-2)
model = EmptyParametersModel()
with error_context:
trainer.fit(model)
trainer.fit(model)
class NoFlatParametersModel(BoringModel):
def configure_optimizers(self):
@ -435,28 +427,13 @@ def test_activation_checkpointing():
self.layer1 = Block2(2, 2)
self.layer2 = nn.Linear(3, 3)
if _TORCH_GREATER_EQUAL_2_1:
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
strategy = FSDPStrategy(activation_checkpointing_policy={Block1})
assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)
strategy = FSDPStrategy(activation_checkpointing_policy={Block1})
assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)
strategy = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({Block1, Block2}))
assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)
else:
strategy = FSDPStrategy(activation_checkpointing=Block1)
assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}
strategy = FSDPStrategy(activation_checkpointing=[Block1, Block2])
assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}
strategy = FSDPStrategy(activation_checkpointing_policy={Block1})
assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}
strategy = FSDPStrategy(activation_checkpointing_policy={Block1, Block2})
assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}
strategy = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({Block1, Block2}))
assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)
model = Model()
strategy._parallel_devices = [torch.device("cuda", 0)]
@ -608,7 +585,7 @@ def test_strategy_save_optimizer_states(tmp_path, wrap_min_params):
if trainer.global_rank != 0:
assert len(model_state_dict) == 0
if trainer.global_rank != 0 and _TORCH_GREATER_EQUAL_2_1:
if trainer.global_rank != 0:
assert len(optimizer_state_dict) == 0
# restore model to ddp
@ -679,7 +656,7 @@ def test_strategy_load_optimizer_states(wrap_min_params, tmp_path):
if trainer.global_rank != 0:
assert len(restored_model_state_dict) == 0
if trainer.global_rank != 0 and _TORCH_GREATER_EQUAL_2_1:
if trainer.global_rank != 0:
assert len(restored_optimizer_state_dict) == 0
if trainer.global_rank == 0:
@ -936,12 +913,8 @@ def test_module_init_context(precision, expected_dtype, tmp_path):
# Case 1: No empty init
_run_setup_assertions(empty_init=False, expected_device=torch.device("cpu"))
if _TORCH_GREATER_EQUAL_2_1:
# Case 2: Empty-init with PyTorch >= 2.1 supports meta device
_run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))
else:
# Case 2: Empty-init with PyTorch < 2.1 only supports `torch.empty()`-init
_run_setup_assertions(empty_init=True, expected_device=torch.device("cpu"))
# Case 2: Empty-init with meta device
_run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))
@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.3.0")

View File

@ -339,7 +339,7 @@ def test_module_init_context(precision, expected_dtype, tmp_path):
# Case 1: No empty init
_run_setup_assertions(empty_init=False, expected_device=torch.device("cpu"))
# Case 2: Empty-init with PyTorch >= 2.1 supports meta device
# Case 2: Empty-init with meta device
_run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))

View File

@ -16,7 +16,6 @@ from unittest.mock import Mock
import pytest
import torch
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_0
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loops import _Loop
@ -81,7 +80,5 @@ def test_no_grad_context():
f.run()
no_grad_mock.assert_called_once_with()
f.inference_mode = True
with mock.patch("torch.inference_mode") as inference_mode_mock:
with mock.patch("torch.inference_mode"):
f.run()
if not _TORCH_EQUAL_2_0:
inference_mode_mock.assert_called_once_with()