Drop PyTorch 2.0 from the test matrix (#20009)
This commit is contained in:
parent
5636fe4a9c
commit
14493c0685
@ -19,29 +19,23 @@ subprojects:
|
|||
- "!*.md"
|
||||
- "!**/*.md"
|
||||
checks:
|
||||
- "pl-cpu (macOS-13, lightning, 3.8, 2.0, oldest)"
|
||||
- "pl-cpu (macOS-14, lightning, 3.10, 2.0)"
|
||||
- "pl-cpu (macOS-13, lightning, 3.8, 2.1, oldest)"
|
||||
- "pl-cpu (macOS-14, lightning, 3.10, 2.1)"
|
||||
- "pl-cpu (macOS-14, lightning, 3.10, 2.2)"
|
||||
- "pl-cpu (macOS-14, lightning, 3.10, 2.3)"
|
||||
- "pl-cpu (ubuntu-20.04, lightning, 3.8, 2.0, oldest)"
|
||||
- "pl-cpu (ubuntu-20.04, lightning, 3.10, 2.0)"
|
||||
- "pl-cpu (ubuntu-20.04, lightning, 3.8, 2.1, oldest)"
|
||||
- "pl-cpu (ubuntu-20.04, lightning, 3.10, 2.1)"
|
||||
- "pl-cpu (ubuntu-20.04, lightning, 3.10, 2.2)"
|
||||
- "pl-cpu (ubuntu-20.04, lightning, 3.10, 2.3)"
|
||||
- "pl-cpu (windows-2022, lightning, 3.8, 2.0, oldest)"
|
||||
- "pl-cpu (windows-2022, lightning, 3.10, 2.0)"
|
||||
- "pl-cpu (windows-2022, lightning, 3.8, 2.1, oldest)"
|
||||
- "pl-cpu (windows-2022, lightning, 3.10, 2.1)"
|
||||
- "pl-cpu (windows-2022, lightning, 3.10, 2.2)"
|
||||
- "pl-cpu (windows-2022, lightning, 3.10, 2.3)"
|
||||
- "pl-cpu (macOS-14, pytorch, 3.8, 2.0)"
|
||||
- "pl-cpu (ubuntu-20.04, pytorch, 3.8, 2.0)"
|
||||
- "pl-cpu (windows-2022, pytorch, 3.8, 2.0)"
|
||||
- "pl-cpu (macOS-12, pytorch, 3.11, 2.0)"
|
||||
- "pl-cpu (macOS-14, pytorch, 3.8, 2.1)"
|
||||
- "pl-cpu (ubuntu-20.04, pytorch, 3.8, 2.1)"
|
||||
- "pl-cpu (windows-2022, pytorch, 3.8, 2.1)"
|
||||
- "pl-cpu (macOS-12, pytorch, 3.11, 2.1)"
|
||||
- "pl-cpu (ubuntu-22.04, pytorch, 3.11, 2.0)"
|
||||
- "pl-cpu (ubuntu-22.04, pytorch, 3.11, 2.1)"
|
||||
- "pl-cpu (windows-2022, pytorch, 3.11, 2.0)"
|
||||
- "pl-cpu (windows-2022, pytorch, 3.11, 2.1)"
|
||||
|
||||
- id: "pytorch_lightning: Azure GPU"
|
||||
|
@ -144,13 +138,11 @@ subprojects:
|
|||
- "!*.md"
|
||||
- "!**/*.md"
|
||||
checks:
|
||||
- "build-cuda (3.10, 2.0, 11.8.0)"
|
||||
- "build-cuda (3.10, 2.1, 12.1.0)"
|
||||
- "build-cuda (3.10, 2.2, 12.1.0)"
|
||||
- "build-cuda (3.11, 2.1, 12.1.0)"
|
||||
- "build-cuda (3.11, 2.2, 12.1.0)"
|
||||
#- "build-NGC"
|
||||
- "build-pl (3.10, 2.0, 11.8.0)"
|
||||
- "build-pl (3.10, 2.1, 12.1.0)"
|
||||
- "build-pl (3.10, 2.2, 12.1.0)"
|
||||
- "build-pl (3.11, 2.1, 12.1.0)"
|
||||
|
@ -171,29 +163,23 @@ subprojects:
|
|||
- "!*.md"
|
||||
- "!**/*.md"
|
||||
checks:
|
||||
- "fabric-cpu (macOS-13, lightning, 3.8, 2.0, oldest)"
|
||||
- "fabric-cpu (macOS-14, lightning, 3.10, 2.0)"
|
||||
- "fabric-cpu (macOS-13, lightning, 3.8, 2.1, oldest)"
|
||||
- "fabric-cpu (macOS-14, lightning, 3.11, 2.1)"
|
||||
- "fabric-cpu (macOS-14, lightning, 3.11, 2.2)"
|
||||
- "fabric-cpu (macOS-14, lightning, 3.10, 2.3)"
|
||||
- "fabric-cpu (ubuntu-20.04, lightning, 3.8, 2.0, oldest)"
|
||||
- "fabric-cpu (ubuntu-20.04, lightning, 3.10, 2.0)"
|
||||
- "fabric-cpu (ubuntu-20.04, lightning, 3.8, 2.1, oldest)"
|
||||
- "fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.1)"
|
||||
- "fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.2)"
|
||||
- "fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.3)"
|
||||
- "fabric-cpu (windows-2022, lightning, 3.8, 2.0, oldest)"
|
||||
- "fabric-cpu (windows-2022, lightning, 3.10, 2.0)"
|
||||
- "fabric-cpu (windows-2022, lightning, 3.8, 2.1, oldest)"
|
||||
- "fabric-cpu (windows-2022, lightning, 3.11, 2.1)"
|
||||
- "fabric-cpu (windows-2022, lightning, 3.11, 2.2)"
|
||||
- "fabric-cpu (windows-2022, lightning, 3.11, 2.3)"
|
||||
- "fabric-cpu (macOS-14, fabric, 3.8, 2.0)"
|
||||
- "fabric-cpu (ubuntu-20.04, fabric, 3.8, 2.0)"
|
||||
- "fabric-cpu (windows-2022, fabric, 3.8, 2.0)"
|
||||
- "fabric-cpu (macOS-12, fabric, 3.11, 2.0)"
|
||||
- "fabric-cpu (macOS-14, fabric, 3.8, 2.1)"
|
||||
- "fabric-cpu (ubuntu-20.04, fabric, 3.8, 2.1)"
|
||||
- "fabric-cpu (windows-2022, fabric, 3.8, 2.1)"
|
||||
- "fabric-cpu (macOS-12, fabric, 3.11, 2.1)"
|
||||
- "fabric-cpu (ubuntu-22.04, fabric, 3.11, 2.0)"
|
||||
- "fabric-cpu (ubuntu-22.04, fabric, 3.11, 2.1)"
|
||||
- "fabric-cpu (windows-2022, fabric, 3.11, 2.0)"
|
||||
- "fabric-cpu (windows-2022, fabric, 3.11, 2.1)"
|
||||
|
||||
- id: "lightning_fabric: Azure GPU"
|
||||
@ -39,9 +39,6 @@ jobs:
|
|||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" }
|
||||
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" }
|
||||
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" }
|
||||
# only run PyTorch latest
|
||||
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.1" }
|
||||
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.1" }
|
||||
|
@ -53,32 +50,29 @@ jobs:
|
|||
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
|
||||
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
|
||||
# only run PyTorch latest with Python latest, use Fabric scope to limit dependency issues
|
||||
- { os: "macOS-12", pkg-name: "fabric", python-version: "3.11", pytorch-version: "2.0" }
|
||||
- { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.11", pytorch-version: "2.0" }
|
||||
- { os: "windows-2022", pkg-name: "fabric", python-version: "3.11", pytorch-version: "2.0" }
|
||||
- { os: "macOS-12", pkg-name: "fabric", python-version: "3.11", pytorch-version: "2.1" }
|
||||
- { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.11", pytorch-version: "2.1" }
|
||||
- { os: "windows-2022", pkg-name: "fabric", python-version: "3.11", pytorch-version: "2.1" }
|
||||
# "oldest" versions tests, only on minimum Python
|
||||
- { os: "macOS-13", pkg-name: "lightning", python-version: "3.8", pytorch-version: "2.0", requires: "oldest" }
|
||||
- { os: "macOS-13", pkg-name: "lightning", python-version: "3.8", pytorch-version: "2.1", requires: "oldest" }
|
||||
- {
|
||||
os: "ubuntu-20.04",
|
||||
pkg-name: "lightning",
|
||||
python-version: "3.8",
|
||||
pytorch-version: "2.0",
|
||||
pytorch-version: "2.1",
|
||||
requires: "oldest",
|
||||
}
|
||||
- {
|
||||
os: "windows-2022",
|
||||
pkg-name: "lightning",
|
||||
python-version: "3.8",
|
||||
pytorch-version: "2.0",
|
||||
pytorch-version: "2.1",
|
||||
requires: "oldest",
|
||||
}
|
||||
# "fabric" installs the standalone package
|
||||
- { os: "macOS-14", pkg-name: "fabric", python-version: "3.8", pytorch-version: "2.0" }
|
||||
- { os: "ubuntu-20.04", pkg-name: "fabric", python-version: "3.8", pytorch-version: "2.0" }
|
||||
- { os: "windows-2022", pkg-name: "fabric", python-version: "3.8", pytorch-version: "2.0" }
|
||||
- { os: "macOS-14", pkg-name: "fabric", python-version: "3.8", pytorch-version: "2.1" }
|
||||
- { os: "ubuntu-20.04", pkg-name: "fabric", python-version: "3.8", pytorch-version: "2.1" }
|
||||
- { os: "windows-2022", pkg-name: "fabric", python-version: "3.8", pytorch-version: "2.1" }
|
||||
timeout-minutes: 25 # because of building grpcio on Mac
|
||||
env:
|
||||
PACKAGE_NAME: ${{ matrix.pkg-name }}
|
||||
@ -43,9 +43,6 @@ jobs:
|
|||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" }
|
||||
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" }
|
||||
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" }
|
||||
# only run PyTorch latest
|
||||
- { os: "macOS-14", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
|
||||
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
|
||||
|
@ -57,32 +54,29 @@ jobs:
|
|||
- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.3" }
|
||||
- { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.3" }
|
||||
# only run PyTorch latest with Python latest, use PyTorch scope to limit dependency issues
|
||||
- { os: "macOS-12", pkg-name: "pytorch", python-version: "3.11", pytorch-version: "2.0" }
|
||||
- { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.11", pytorch-version: "2.0" }
|
||||
- { os: "windows-2022", pkg-name: "pytorch", python-version: "3.11", pytorch-version: "2.0" }
|
||||
- { os: "macOS-12", pkg-name: "pytorch", python-version: "3.11", pytorch-version: "2.1" }
|
||||
- { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.11", pytorch-version: "2.1" }
|
||||
- { os: "windows-2022", pkg-name: "pytorch", python-version: "3.11", pytorch-version: "2.1" }
|
||||
# "oldest" versions tests, only on minimum Python
|
||||
- { os: "macOS-13", pkg-name: "lightning", python-version: "3.8", pytorch-version: "2.0", requires: "oldest" }
|
||||
- { os: "macOS-13", pkg-name: "lightning", python-version: "3.8", pytorch-version: "2.1", requires: "oldest" }
|
||||
- {
|
||||
os: "ubuntu-20.04",
|
||||
pkg-name: "lightning",
|
||||
python-version: "3.8",
|
||||
pytorch-version: "2.0",
|
||||
pytorch-version: "2.1",
|
||||
requires: "oldest",
|
||||
}
|
||||
- {
|
||||
os: "windows-2022",
|
||||
pkg-name: "lightning",
|
||||
python-version: "3.8",
|
||||
pytorch-version: "2.0",
|
||||
pytorch-version: "2.1",
|
||||
requires: "oldest",
|
||||
}
|
||||
# "pytorch" installs the standalone package
|
||||
- { os: "macOS-14", pkg-name: "pytorch", python-version: "3.8", pytorch-version: "2.0" }
|
||||
- { os: "ubuntu-20.04", pkg-name: "pytorch", python-version: "3.8", pytorch-version: "2.0" }
|
||||
- { os: "windows-2022", pkg-name: "pytorch", python-version: "3.8", pytorch-version: "2.0" }
|
||||
- { os: "macOS-14", pkg-name: "pytorch", python-version: "3.8", pytorch-version: "2.1" }
|
||||
- { os: "ubuntu-20.04", pkg-name: "pytorch", python-version: "3.8", pytorch-version: "2.1" }
|
||||
- { os: "windows-2022", pkg-name: "pytorch", python-version: "3.8", pytorch-version: "2.1" }
|
||||
timeout-minutes: 50
|
||||
env:
|
||||
PACKAGE_NAME: ${{ matrix.pkg-name }}
|
||||
@ -43,7 +43,6 @@ jobs:
|
|||
include:
|
||||
# We only release one docker image per PyTorch version.
|
||||
# Make sure the matrix here matches the one below.
|
||||
- { python_version: "3.10", pytorch_version: "2.0", cuda_version: "11.8.0" }
|
||||
- { python_version: "3.10", pytorch_version: "2.1", cuda_version: "12.1.0" }
|
||||
- { python_version: "3.10", pytorch_version: "2.2", cuda_version: "12.1.0" }
|
||||
- { python_version: "3.11", pytorch_version: "2.1", cuda_version: "12.1.0" }
|
||||
|
@ -104,7 +103,6 @@ jobs:
|
|||
include:
|
||||
# These are the base images for PL release docker images.
|
||||
# Make sure the matrix here matches the one above.
|
||||
- { python_version: "3.10", pytorch_version: "2.0", cuda_version: "11.8.0" }
|
||||
- { python_version: "3.10", pytorch_version: "2.1", cuda_version: "12.1.0" }
|
||||
- { python_version: "3.10", pytorch_version: "2.2", cuda_version: "12.1.0" }
|
||||
- { python_version: "3.11", pytorch_version: "2.1", cuda_version: "12.1.0" }
|
||||
@ -5,10 +5,6 @@ Speed up models by compiling them
|
|||
Compiling your PyTorch model can result in significant speedups, especially on the latest generations of GPUs.
|
||||
This guide shows you how to apply `torch.compile <https://pytorch.org/docs/2.2/generated/torch.compile.html>`_ correctly in your code.
|
||||
|
||||
.. note::
|
||||
|
||||
This requires PyTorch >= 2.0.
|
||||
|
||||
|
||||
----
|
||||
|
||||
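A minimal sketch of the call the guide documents (illustrative only, not part of this diff)::

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 1))
    compiled = torch.compile(model)   # returns a torch._dynamo.OptimizedModule wrapping `model`
    x = torch.randn(8, 64)
    y = compiled(x)                   # the first call triggers compilation; later calls reuse it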
@ -81,4 +81,4 @@ When training distributed models with :doc:`FSDP/TP <model_parallel/index>` or D
|
|||
|
||||
.. note::
|
||||
Empty-init is experimental and the behavior may change in the future.
|
||||
For distributed models on PyTorch 2.1+, it is required that all user-defined modules that manage parameters implement a ``reset_parameters()`` method (all PyTorch built-in modules have this too).
|
||||
For distributed models, it is required that all user-defined modules that manage parameters implement a ``reset_parameters()`` method (all PyTorch built-in modules have this too).
|
||||
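A minimal example of a user-defined module that satisfies this requirement (illustrative sketch, not part of this diff)::

    import torch
    import torch.nn as nn

    class ScaledBias(nn.Module):
        """Owns a parameter directly, so it must know how to (re)initialize it."""

        def __init__(self, dim: int) -> None:
            super().__init__()
            self.bias = nn.Parameter(torch.empty(dim))
            self.reset_parameters()

        def reset_parameters(self) -> None:
            # called when the module is materialized from the meta device
            nn.init.zeros_(self.bias)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x + self.bias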
@ -5,10 +5,6 @@ Speed up models by compiling them
|
|||
Compiling your LightningModule can result in significant speedups, especially on the latest generations of GPUs.
|
||||
This guide shows you how to apply `torch.compile <https://pytorch.org/docs/2.2/generated/torch.compile.html>`_ correctly in your code.
|
||||
|
||||
.. note::
|
||||
|
||||
This requires PyTorch >= 2.0.
|
||||
|
||||
|
||||
----
|
||||
|
||||
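For the LightningModule case, the pattern the guide describes is to compile the module before handing it to the Trainer; a self-contained sketch (illustrative only, `LitModel` and the toy data are made up for this example):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from lightning.pytorch import LightningModule, Trainer

    class LitModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.mse_loss(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.01)

    model = torch.compile(LitModel())  # compile first, then train as usual
    data = DataLoader(TensorDataset(torch.randn(64, 32), torch.randn(64, 1)), batch_size=8)
    Trainer(max_epochs=1, logger=False, enable_checkpointing=False).fit(model, data)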
@ -1,8 +1,8 @@
|
|||
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
|
||||
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
|
||||
|
||||
numpy >=1.17.2, <1.27.0
|
||||
torch >=2.0.0, <2.4.0
|
||||
numpy >=1.21.0, <1.27.0
|
||||
torch >=2.1.0, <2.4.0
|
||||
fsspec[http] >=2022.5.0, <2024.4.0
|
||||
packaging >=20.0, <=23.1
|
||||
typing-extensions >=4.4.0, <4.10.0
|
||||
@ -1,6 +1,6 @@
|
|||
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
|
||||
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
|
||||
|
||||
torchvision >=0.15.0, <0.19.0
|
||||
torchvision >=0.16.0, <0.19.0
|
||||
torchmetrics >=0.10.0, <1.3.0
|
||||
lightning-utilities >=0.8.0, <0.12.0
|
||||
@ -1,8 +1,8 @@
|
|||
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
|
||||
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
|
||||
|
||||
numpy >=1.17.2, <1.27.0
|
||||
torch >=2.0.0, <2.4.0
|
||||
numpy >=1.21.0, <1.27.0
|
||||
torch >=2.1.0, <2.4.0
|
||||
tqdm >=4.57.0, <4.67.0
|
||||
PyYAML >=5.4, <6.1.0
|
||||
fsspec[http] >=2022.5.0, <2024.4.0
|
||||
@ -2,7 +2,7 @@
|
|||
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
|
||||
|
||||
requests <2.32.0
|
||||
torchvision >=0.15.0, <0.19.0
|
||||
torchvision >=0.16.0, <0.19.0
|
||||
ipython[all] <8.15.0
|
||||
torchmetrics >=0.10.0, <1.3.0
|
||||
lightning-utilities >=0.8.0, <0.12.0
|
||||
@ -8,8 +8,8 @@ pytest-random-order ==1.1.0
|
|||
# needed in tests
|
||||
cloudpickle >=1.3, <2.3.0
|
||||
scikit-learn >0.22.1, <1.4.0
|
||||
onnx >=0.14.0, <1.15.0
|
||||
onnxruntime >=0.15.0, <1.17.0
|
||||
onnx >=1.12.0, <1.15.0
|
||||
onnxruntime >=1.12.0, <1.17.0
|
||||
psutil <5.9.6 # for `DeviceStatsMonitor`
|
||||
pandas >1.0, <2.2.0 # needed in benchmarks
|
||||
fastapi # for `ServableModuleValidator` # not setting version as re-defined in App
|
||||
@ -27,7 +27,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
### Removed
|
||||
|
||||
-
|
||||
- Removed support for PyTorch 2.0 ([#20009](https://github.com/Lightning-AI/lightning/pull/20009))
|
||||
|
||||
|
||||
-
|
||||
|
||||
@ -21,7 +21,7 @@ if not _root_logger.hasHandlers():
|
|||
_logger.propagate = False
|
||||
|
||||
|
||||
# In PyTorch 2.0+, setting this variable will force `torch.cuda.is_available()` and `torch.cuda.device_count()`
|
||||
# Setting this variable will force `torch.cuda.is_available()` and `torch.cuda.device_count()`
|
||||
# to use an NVML-based implementation that doesn't poison forks.
|
||||
# https://github.com/pytorch/pytorch/issues/83973
|
||||
os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] = "1"
|
||||
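The same mechanism can be used in standalone scripts; a short sketch (illustrative only):

    import os

    os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] = "1"  # must be set before the first CUDA query

    import torch

    torch.cuda.is_available()  # NVML-based check that does not initialize a CUDA context,
                               # so processes forked afterwards can still use CUDA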
@ -225,8 +225,7 @@ class Fabric:
|
|||
_reapply_compile: If ``True`` (default), and the model was ``torch.compile``d before, the
|
||||
corresponding :class:`~torch._dynamo.OptimizedModule` wrapper will be removed and reapplied with the
|
||||
same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP,
|
||||
FSDP etc.). Only applies on PyTorch >= 2.1. Set it to ``False`` if compiling DDP/FSDP is causing
|
||||
issues.
|
||||
FSDP etc.). Set it to ``False`` if compiling DDP/FSDP is causing issues.
|
||||
|
||||
Returns:
|
||||
The tuple containing wrapped module and the optimizers, in the same order they were passed in.
|
||||
|
@ -292,8 +291,7 @@ class Fabric:
|
|||
_reapply_compile: If ``True`` (default), and the model was ``torch.compile``d before, the
|
||||
corresponding :class:`~torch._dynamo.OptimizedModule` wrapper will be removed and reapplied with the
|
||||
same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP,
|
||||
FSDP etc.). Only applies on PyTorch >= 2.1. Set it to ``False`` if compiling DDP/FSDP is causing
|
||||
issues.
|
||||
FSDP etc.). Set it to ``False`` if compiling DDP/FSDP is causing issues.
|
||||
Returns:
|
||||
The wrapped model.
|
||||
|
||||
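The idea behind `_reapply_compile` in a few lines (a simplified sketch, not Lightning's actual implementation; the real code also preserves the original compile settings, and `setup_fn` stands in for the strategy's wrapping step):

    import torch
    from torch._dynamo import OptimizedModule

    def setup_with_reapplied_compile(module, setup_fn):
        if isinstance(module, OptimizedModule):
            inner = module._orig_mod          # the original, uncompiled module
            wrapped = setup_fn(inner)         # e.g. wrap it in DDP/FSDP
            return torch.compile(wrapped)     # compile again on top of the wrapper
        return setup_fn(module)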
@ -63,11 +63,10 @@ from lightning.fabric.utilities.distributed import (
|
|||
)
|
||||
from lightning.fabric.utilities.distributed import group as _group
|
||||
from lightning.fabric.utilities.imports import (
|
||||
_TORCH_GREATER_EQUAL_2_1,
|
||||
_TORCH_GREATER_EQUAL_2_2,
|
||||
_TORCH_GREATER_EQUAL_2_3,
|
||||
)
|
||||
from lightning.fabric.utilities.init import _EmptyInit, _has_meta_device_parameters_or_buffers
|
||||
from lightning.fabric.utilities.init import _has_meta_device_parameters_or_buffers
|
||||
from lightning.fabric.utilities.load import _METADATA_FILENAME, _lazy_load, _materialize_tensors, _move_state_into
|
||||
from lightning.fabric.utilities.rank_zero import rank_zero_deprecation, rank_zero_only, rank_zero_warn
|
||||
from lightning.fabric.utilities.seed import reset_seed
|
||||
|
@ -325,7 +324,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
|
|||
if self._fsdp_kwargs.get("use_orig_params"):
|
||||
return super().setup_optimizer(optimizer)
|
||||
if not _optimizer_has_flat_params(optimizer):
|
||||
# We avoid this limitation in PyTorch >= 2.0 by setting `use_orig_params=True`
|
||||
# We avoid this limitation by setting `use_orig_params=True`
|
||||
raise ValueError(
|
||||
"The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the optimizer"
|
||||
" after setting up the model."
|
||||
|
@ -340,15 +339,12 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
|
|||
def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager:
|
||||
precision_init_ctx = self.precision.module_init_context()
|
||||
module_sharded_ctx = self.module_sharded_context()
|
||||
empty_ctx = _EmptyInit(enabled=bool(empty_init))
|
||||
stack = ExitStack()
|
||||
if _TORCH_GREATER_EQUAL_2_1 and empty_init:
|
||||
if empty_init:
|
||||
# Materialization happens in `setup`. When modules get wrapped by FSDP, the sequence of operations is:
|
||||
# 1) materialize module 2) call `reset_parameters()` 3) shard the module.
|
||||
# These operations are applied to each submodule 'bottom up' in the module hierarchy.
|
||||
stack.enter_context(torch.device("meta"))
|
||||
else:
|
||||
stack.enter_context(empty_ctx)
|
||||
stack.enter_context(precision_init_ctx)
|
||||
stack.enter_context(module_sharded_ctx)
|
||||
return stack
|
||||
|
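Outside of FSDP, the materialization sequence described in the comments looks roughly like this (illustrative sketch):

    import torch
    import torch.nn as nn

    with torch.device("meta"):
        layer = nn.Linear(4, 4)      # parameters exist but have no storage

    layer.to_empty(device="cpu")     # 1) materialize: allocate uninitialized storage
    layer.reset_parameters()         # 2) fill in real values
    # 3) FSDP then shards the materialized module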
@ -697,18 +693,13 @@ def _activation_checkpointing_kwargs(
|
|||
classes = tuple(activation_checkpointing)
|
||||
else:
|
||||
classes = (activation_checkpointing,)
|
||||
if _TORCH_GREATER_EQUAL_2_1:
|
||||
rank_zero_deprecation(
|
||||
f"`FSDPStrategy(activation_checkpointing={activation_checkpointing})` is deprecated, use "
|
||||
f"`FSDPStrategy(activation_checkpointing_policy={set(classes)})` instead."
|
||||
)
|
||||
rank_zero_deprecation(
|
||||
f"`FSDPStrategy(activation_checkpointing={activation_checkpointing})` is deprecated, use "
|
||||
f"`FSDPStrategy(activation_checkpointing_policy={set(classes)})` instead."
|
||||
)
|
||||
return {"check_fn": lambda submodule: isinstance(submodule, classes)}
|
||||
if isinstance(activation_checkpointing_policy, set):
|
||||
if _TORCH_GREATER_EQUAL_2_1:
|
||||
return _auto_wrap_policy_kwargs(activation_checkpointing_policy, {})
|
||||
return {"check_fn": lambda submodule: isinstance(submodule, tuple(activation_checkpointing_policy))}
|
||||
if not _TORCH_GREATER_EQUAL_2_1:
|
||||
raise ValueError("`activation_checkpointing_policy` requires torch >= 2.1.0. HINT: `pip install -U torch`")
|
||||
return _auto_wrap_policy_kwargs(activation_checkpointing_policy, {})
|
||||
return {"auto_wrap_policy": activation_checkpointing_policy}
|
||||
|
||||
|
||||
|
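Typical usage of the policy argument handled above (sketch; `Block` is a placeholder for a user-defined layer class):

    import torch.nn as nn
    from torch.distributed.fsdp.wrap import ModuleWrapPolicy
    from lightning.fabric.strategies import FSDPStrategy

    class Block(nn.Module): ...

    # pass a set of layer classes ...
    strategy = FSDPStrategy(activation_checkpointing_policy={Block})
    # ... or an explicit policy object
    strategy = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({Block}))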
@ -716,15 +707,10 @@ def _auto_wrap_policy_kwargs(policy: Optional["_POLICY"], kwargs: Dict) -> Dict:
|
|||
if policy is None:
|
||||
return kwargs
|
||||
if isinstance(policy, set):
|
||||
if _TORCH_GREATER_EQUAL_2_1:
|
||||
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
|
||||
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
|
||||
|
||||
policy = ModuleWrapPolicy(policy)
|
||||
else:
|
||||
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
|
||||
policy = ModuleWrapPolicy(policy)
|
||||
|
||||
# this is not transformer specific despite the name
|
||||
policy = partial(transformer_auto_wrap_policy, transformer_layer_cls=policy)
|
||||
kwargs["auto_wrap_policy"] = policy
|
||||
return kwargs
|
||||
|
||||
|
@ -829,11 +815,8 @@ def _get_full_state_dict_context(
|
|||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.distributed.fsdp.api import FullOptimStateDictConfig
|
||||
|
||||
# In PyTorch < 2.1, offload to CPU in combination with `world_size=1` is not possible
|
||||
offload_to_cpu = world_size > 1 or _TORCH_GREATER_EQUAL_2_1
|
||||
state_dict_config = FullStateDictConfig(offload_to_cpu=offload_to_cpu, rank0_only=rank0_only)
|
||||
|
||||
optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=offload_to_cpu, rank0_only=rank0_only)
|
||||
state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=rank0_only)
|
||||
optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=rank0_only)
|
||||
state_dict_type_context = FSDP.state_dict_type(
|
||||
module=module,
|
||||
state_dict_type=StateDictType.FULL_STATE_DICT,
|
||||
@ -26,13 +26,10 @@ _IS_WINDOWS = platform.system() == "Windows"
|
|||
# 2. The inspection mode via `python -i`: https://stackoverflow.com/a/6879085/1162383
|
||||
_IS_INTERACTIVE = hasattr(sys, "ps1") or bool(sys.flags.interactive)
|
||||
|
||||
_TORCH_GREATER_EQUAL_2_1 = compare_version("torch", operator.ge, "2.1.0")
|
||||
_TORCH_GREATER_EQUAL_2_2 = compare_version("torch", operator.ge, "2.2.0")
|
||||
_TORCH_GREATER_EQUAL_2_3 = compare_version("torch", operator.ge, "2.3.0")
|
||||
_TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0", use_base_version=True)
|
||||
|
||||
_TORCH_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0") and not _TORCH_GREATER_EQUAL_2_1
|
||||
|
||||
_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
|
||||
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
|
||||
|
||||
@ -20,7 +20,6 @@ from torch.optim import Optimizer
|
|||
from torch.overrides import TorchFunctionMode
|
||||
from typing_extensions import override
|
||||
|
||||
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
|
||||
from lightning.fabric.utilities.rank_zero import rank_zero_warn
|
||||
from lightning.fabric.utilities.types import _DEVICE
|
||||
|
||||
|
@ -61,8 +60,6 @@ class _EmptyInit(TorchFunctionMode):
|
|||
|
||||
def _materialize(module: Module, device: _DEVICE) -> None:
|
||||
"""Materialize a module."""
|
||||
if not _TORCH_GREATER_EQUAL_2_1:
|
||||
raise RuntimeError("recurse=False requires torch 2.1")
|
||||
module.to_empty(device=device, recurse=False)
|
||||
if not hasattr(module, "reset_parameters"):
|
||||
raise TypeError(
|
||||
@ -24,7 +24,6 @@ from lightning.fabric.accelerators import XLAAccelerator
|
|||
from lightning.fabric.accelerators.cuda import num_cuda_devices
|
||||
from lightning.fabric.accelerators.mps import MPSAccelerator
|
||||
from lightning.fabric.strategies.deepspeed import _DEEPSPEED_AVAILABLE
|
||||
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
|
||||
|
||||
|
||||
def _runif_reasons(
|
||||
|
@ -116,13 +115,9 @@ def _runif_reasons(
|
|||
reasons.append("Deepspeed")
|
||||
|
||||
if dynamo:
|
||||
if _TORCH_GREATER_EQUAL_2_1:
|
||||
from torch._dynamo.eval_frame import is_dynamo_supported
|
||||
from torch._dynamo.eval_frame import is_dynamo_supported
|
||||
|
||||
cond = not is_dynamo_supported()
|
||||
else:
|
||||
cond = sys.platform == "win32" or sys.version_info >= (3, 11)
|
||||
if cond:
|
||||
if not is_dynamo_supported():
|
||||
reasons.append("torch.dynamo")
|
||||
|
||||
return reasons, kwargs
|
||||
@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Ty
|
|||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
|
||||
from lightning.fabric.utilities.rank_zero import rank_zero_only, rank_zero_warn
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -292,8 +291,6 @@ def measure_flops(
|
|||
FLOPs will be included in the result.
|
||||
|
||||
"""
|
||||
if not _TORCH_GREATER_EQUAL_2_1:
|
||||
raise ImportError("`measure_flops` requires PyTorch >= 2.1.")
|
||||
from torch.utils.flop_counter import FlopCounterMode
|
||||
|
||||
flop_counter = FlopCounterMode(display=False)
|
||||
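Typical usage of `measure_flops` (sketch mirroring the tests touched later in this diff):

    import torch
    from lightning.fabric.utilities.throughput import measure_flops

    with torch.device("meta"):        # meta tensors keep the measurement cheap
        model = torch.nn.Linear(32, 2)
        x = torch.randn(4, 32)

    fwd_flops = measure_flops(model, lambda: model(x))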
@ -27,7 +27,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
### Removed
|
||||
|
||||
-
|
||||
- Removed support for PyTorch 2.0 ([#20009](https://github.com/Lightning-AI/lightning/pull/20009))
|
||||
|
||||
|
||||
-
|
||||
|
||||
@ -51,7 +51,6 @@ from lightning.fabric.loggers import Logger as FabricLogger
|
|||
from lightning.fabric.utilities.apply_func import convert_to_tensors
|
||||
from lightning.fabric.utilities.cloud_io import get_filesystem
|
||||
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
|
||||
from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_1
|
||||
from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
|
||||
from lightning.fabric.wrappers import _FabricOptimizer
|
||||
from lightning.pytorch.callbacks.callback import Callback
|
||||
|
@ -67,7 +66,7 @@ from lightning.pytorch.utilities import GradClipAlgorithmType
|
|||
from lightning.pytorch.utilities.exceptions import MisconfigurationException
|
||||
from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_9_1
|
||||
from lightning.pytorch.utilities.model_helpers import _restricted_classmethod
|
||||
from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_debug, rank_zero_warn
|
||||
from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_warn
|
||||
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
|
||||
from lightning.pytorch.utilities.types import (
|
||||
_METRIC,
|
||||
|
@ -140,7 +139,6 @@ class LightningModule(
|
|||
self._current_fx_name: Optional[str] = None
|
||||
self._param_requires_grad_state: Dict[str, bool] = {}
|
||||
self._metric_attributes: Optional[Dict[int, str]] = None
|
||||
self._register_sharded_tensor_state_dict_hooks_if_available()
|
||||
self._compiler_ctx: Optional[Dict[str, Any]] = None
|
||||
|
||||
# attributes only used when using fabric
|
||||
|
@ -1390,9 +1388,7 @@ class LightningModule(
|
|||
|
||||
"""
|
||||
if not _ONNX_AVAILABLE:
|
||||
raise ModuleNotFoundError(
|
||||
f"`torch>=2.0` requires `onnx` to be installed to use `{type(self).__name__}.to_onnx()`"
|
||||
)
|
||||
raise ModuleNotFoundError(f"`{type(self).__name__}.to_onnx()` requires `onnx` to be installed.")
|
||||
|
||||
mode = self.training
|
||||
|
||||
|
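Usage sketch for the export path guarded above (illustrative; requires the `onnx` package to be installed):

    import torch
    from lightning.pytorch.demos.boring_classes import BoringModel

    model = BoringModel()                  # any LightningModule with a 32-feature input
    model.to_onnx("model.onnx", input_sample=torch.randn(1, 32))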
@ -1599,24 +1595,6 @@ class LightningModule(
|
|||
state["_trainer"] = None
|
||||
return state
|
||||
|
||||
def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:
|
||||
"""Adds ShardedTensor state dict hooks if ShardedTensors are supported.
|
||||
|
||||
These hooks ensure that ShardedTensors are included when saving, and are loaded into the LightningModule correctly.
|
||||
|
||||
"""
|
||||
if _TORCH_GREATER_EQUAL_2_1:
|
||||
# ShardedTensor is deprecated in favor of DistributedTensor
|
||||
return
|
||||
if _IS_WINDOWS or not torch.distributed.is_available():
|
||||
rank_zero_debug("Could not register sharded tensor state dict hooks")
|
||||
return
|
||||
|
||||
from torch.distributed._shard.sharded_tensor import pre_load_state_dict_hook, state_dict_hook
|
||||
|
||||
self._register_state_dict_hook(state_dict_hook)
|
||||
self._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _jit_is_scripting() -> Generator:
|
||||
@ -85,8 +85,7 @@ class PositionalEncoding(nn.Module):
|
|||
if self.pe is None:
|
||||
# 1) can't use buffer, see https://github.com/pytorch/pytorch/issues/68407
|
||||
# 2) can't use parameter becauses pe gets sliced and DDP requires all params to participate in forward
|
||||
# 3) can't make it a `requires_grad=False` parameter because FSDP in PyTorch < 2.1 needs all params to
|
||||
# require grad
|
||||
# TODO: Could make this a `nn.Parameter` with `requires_grad=False`
|
||||
self.pe = self._init_pos_encoding(device=x.device)
|
||||
|
||||
x + self.pe[: x.size(0), :]
|
||||
@ -21,7 +21,6 @@ from torch import Tensor
|
|||
|
||||
import lightning.pytorch as pl
|
||||
from lightning.fabric.utilities.distributed import _distributed_is_initialized
|
||||
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_0
|
||||
from lightning.fabric.utilities.warnings import PossibleUserWarning
|
||||
from lightning.pytorch.accelerators.xla import XLAAccelerator
|
||||
from lightning.pytorch.callbacks.timer import Timer
|
||||
|
@ -171,9 +170,6 @@ def _no_grad_context(loop_run: Callable) -> Callable:
|
|||
elif isinstance(self.trainer.strategy, FSDPStrategy):
|
||||
# https://github.com/pytorch/pytorch/issues/95957
|
||||
context_manager = torch.no_grad
|
||||
elif _TORCH_EQUAL_2_0 and self.trainer.lightning_module._compiler_ctx is not None:
|
||||
# avoid: `RuntimeError: Inference tensors do not track version counter` fixed in v2.1
|
||||
context_manager = torch.no_grad
|
||||
elif self.inference_mode:
|
||||
context_manager = torch.inference_mode
|
||||
else:
|
||||
@ -67,11 +67,8 @@ from lightning.fabric.utilities.distributed import (
|
|||
_sync_ddp_if_available,
|
||||
)
|
||||
from lightning.fabric.utilities.distributed import group as _group
|
||||
from lightning.fabric.utilities.imports import (
|
||||
_TORCH_GREATER_EQUAL_2_1,
|
||||
_TORCH_GREATER_EQUAL_2_2,
|
||||
)
|
||||
from lightning.fabric.utilities.init import _EmptyInit, _has_meta_device_parameters_or_buffers
|
||||
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
|
||||
from lightning.fabric.utilities.init import _has_meta_device_parameters_or_buffers
|
||||
from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors
|
||||
from lightning.fabric.utilities.optimizer import _optimizers_to_device
|
||||
from lightning.fabric.utilities.seed import reset_seed
|
||||
|
@ -368,8 +365,8 @@ class FSDPStrategy(ParallelStrategy):
|
|||
|
||||
invalid_params_error = False
|
||||
try:
|
||||
# In PyTorch < 2.0, or if `use_orig_params=False` the user needs to do access
|
||||
# `self.trainer.model.parameters()` in configure_optimizers()
|
||||
# If `use_orig_params=False` the user needs to do access `self.trainer.model.parameters()` in
|
||||
# `configure_optimizers()`
|
||||
super().setup_optimizers(trainer)
|
||||
except ValueError as ex:
|
||||
if "optimizer got an empty parameter list" not in str(ex):
|
||||
|
@ -377,7 +374,7 @@ class FSDPStrategy(ParallelStrategy):
|
|||
invalid_params_error = True
|
||||
|
||||
if invalid_params_error or any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers):
|
||||
# We avoid this limitation in PyTorch >= 2.0 by setting `use_orig_params=True`
|
||||
# We avoid this limitation by setting `use_orig_params=True`
|
||||
raise ValueError(
|
||||
"The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the"
|
||||
" optimizer after setting up the model by referencing `self.trainer.model.parameters()` in the"
|
||||
|
@ -393,14 +390,10 @@ class FSDPStrategy(ParallelStrategy):
|
|||
@contextmanager
|
||||
@override
|
||||
def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[None, None, None]:
|
||||
empty_init_context: Union[torch.device, _EmptyInit, nullcontext]
|
||||
if _TORCH_GREATER_EQUAL_2_1 and empty_init:
|
||||
# Materialization happens in `setup`. When modules get wrapped by FSDP, the sequence of operations is:
|
||||
# 1) materialize module 2) call `reset_parameters()` 3) shard the module.
|
||||
# These operations are applied to each submodule 'bottom up' in the module hierarchy.
|
||||
empty_init_context = torch.device("meta")
|
||||
else:
|
||||
empty_init_context = _EmptyInit(enabled=bool(empty_init))
|
||||
# Materialization happens in `setup`. When modules get wrapped by FSDP, the sequence of operations is:
|
||||
# 1) materialize module 2) call `reset_parameters()` 3) shard the module.
|
||||
# These operations are applied to each submodule 'bottom up' in the module hierarchy.
|
||||
empty_init_context = torch.device("meta") if empty_init else nullcontext()
|
||||
with empty_init_context, self.precision_plugin.tensor_init_context():
|
||||
yield
|
||||
|
||||
@ -24,7 +24,6 @@ from typing_extensions import TypedDict, override
|
|||
from lightning.fabric.utilities import move_data_to_device
|
||||
from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars
|
||||
from lightning.fabric.utilities.distributed import _distributed_is_initialized
|
||||
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_0
|
||||
from lightning.pytorch.utilities.data import extract_batch_size
|
||||
from lightning.pytorch.utilities.exceptions import MisconfigurationException
|
||||
from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_1_0_0
|
||||
|
@ -112,7 +111,7 @@ class _Metadata:
|
|||
on_step: bool = False
|
||||
on_epoch: bool = True
|
||||
# https://github.com/pytorch/pytorch/issues/96197
|
||||
reduce_fx: Callable = "mean" if _TORCH_EQUAL_2_0 else torch.mean # type: ignore[assignment]
|
||||
reduce_fx: Callable = torch.mean
|
||||
enable_graph: bool = False
|
||||
add_dataloader_idx: bool = True
|
||||
dataloader_idx: Optional[int] = None
|
||||
|
@ -362,7 +361,7 @@ class _ResultCollection(dict):
|
|||
on_step: bool = False,
|
||||
on_epoch: bool = True,
|
||||
# https://github.com/pytorch/pytorch/issues/96197
|
||||
reduce_fx: Callable = "mean" if _TORCH_EQUAL_2_0 else torch.mean, # type: ignore[assignment]
|
||||
reduce_fx: Callable = torch.mean,
|
||||
enable_graph: bool = False,
|
||||
sync_dist: bool = False,
|
||||
sync_dist_fn: Callable = _Sync.no_op,
|
||||
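How `reduce_fx` surfaces in user code now that the default is always `torch.mean` (illustrative sketch):

    import torch
    from lightning.pytorch.demos.boring_classes import BoringModel

    class MaxLossModel(BoringModel):
        def training_step(self, batch, batch_idx):
            out = super().training_step(batch, batch_idx)
            # epoch-level aggregation with something other than the default mean
            self.log("max_loss", out["loss"], on_epoch=True, reduce_fx=torch.max)
            return out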
@ -17,7 +17,6 @@ import torch
|
|||
from torch._dynamo import OptimizedModule
|
||||
|
||||
import lightning.pytorch as pl
|
||||
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
|
||||
from lightning.pytorch.strategies import DDPStrategy, DeepSpeedStrategy, FSDPStrategy, SingleDeviceStrategy, Strategy
|
||||
from lightning.pytorch.utilities.model_helpers import _check_mixed_imports
|
||||
|
||||
|
@ -56,11 +55,7 @@ def from_compiled(model: OptimizedModule) -> "pl.LightningModule":
|
|||
}
|
||||
|
||||
orig_module.forward = model.dynamo_ctx(orig_module.forward) # type: ignore[method-assign]
|
||||
if not _TORCH_GREATER_EQUAL_2_1: # https://github.com/pytorch/pytorch/issues/95630
|
||||
orig_module.forward._torchdynamo_inline = orig_module.forward
|
||||
orig_module.training_step = model.dynamo_ctx(orig_module.training_step) # type: ignore[method-assign]
|
||||
if not _TORCH_GREATER_EQUAL_2_1: # https://github.com/pytorch/pytorch/issues/95630
|
||||
orig_module.training_step._torchdynamo_inline = orig_module.training_step
|
||||
orig_module.validation_step = model.dynamo_ctx(orig_module.validation_step) # type: ignore[method-assign]
|
||||
orig_module.test_step = model.dynamo_ctx(orig_module.test_step) # type: ignore[method-assign]
|
||||
orig_module.predict_step = model.dynamo_ctx(orig_module.predict_step) # type: ignore[method-assign]
|
||||
@ -101,7 +101,10 @@ def thread_police_duuu_daaa_duuu_daaa():
|
|||
assert not thread.is_alive()
|
||||
elif isinstance(thread, _ChildProcessObserver):
|
||||
thread.join(timeout=10)
|
||||
elif thread.name == "QueueFeederThread": # tensorboardX
|
||||
elif (
|
||||
thread.name == "QueueFeederThread" # tensorboardX
|
||||
or thread.name == "QueueManagerThread" # torch.compile
|
||||
):
|
||||
thread.join(timeout=20)
|
||||
elif (
|
||||
sys.version_info >= (3, 9)
|
||||
@ -148,7 +148,7 @@ def test_bitsandbytes_layers(args, expected):
|
|||
assert model.l.weight.dtype == expected
|
||||
|
||||
|
||||
@RunIf(min_cuda_gpus=1, min_torch="2.1")
|
||||
@RunIf(min_cuda_gpus=1)
|
||||
@pytest.mark.skipif(not _BITSANDBYTES_AVAILABLE, reason="bitsandbytes unavailable")
|
||||
@pytest.mark.parametrize(
|
||||
("args", "expected"),
|
||||
|
@ -232,7 +232,7 @@ def test_bitsandbytes_layers_meta_device(args, expected, tmp_path):
|
|||
assert model.l.weight.dtype == expected
|
||||
|
||||
|
||||
@RunIf(min_cuda_gpus=1, min_torch="2.1")
|
||||
@RunIf(min_cuda_gpus=1)
|
||||
@pytest.mark.skipif(not _BITSANDBYTES_AVAILABLE, reason="bitsandbytes unavailable")
|
||||
def test_load_quantized_checkpoint(tmp_path):
|
||||
"""Test that a checkpoint saved from a quantized model can be loaded back into a quantized model."""
|
||||
@ -75,7 +75,7 @@ def _run_ddp_save_load(fabric, tmp_path):
|
|||
assert_params_equal(params_before, wrapped_model.parameters())
|
||||
|
||||
|
||||
@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.1.0", dynamo=True)
|
||||
@RunIf(min_cuda_gpus=2, standalone=True, dynamo=True)
|
||||
@mock.patch("lightning.fabric.wrappers.torch.compile", Mock(wraps=torch.compile))
|
||||
@mock.patch.dict(os.environ, {})
|
||||
def test_reapply_compile():
|
||||
@ -16,7 +16,6 @@ from re import escape
|
|||
from unittest import mock
|
||||
from unittest.mock import ANY, MagicMock, Mock
|
||||
|
||||
import lightning.fabric
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -28,8 +27,9 @@ from lightning.fabric.strategies.fsdp import (
|
|||
_get_full_state_dict_context,
|
||||
_is_sharded_checkpoint,
|
||||
)
|
||||
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1, _TORCH_GREATER_EQUAL_2_2
|
||||
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
|
||||
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
|
||||
from torch.optim import Adam
|
||||
|
||||
|
||||
|
@ -147,13 +147,6 @@ def test_no_backward_sync():
|
|||
module.no_sync.assert_called_once()
|
||||
|
||||
|
||||
def test_activation_checkpointing_support(monkeypatch):
|
||||
"""Test that we error out if activation checkpointing requires a newer PyTorch version."""
|
||||
monkeypatch.setattr(lightning.fabric.strategies.fsdp, "_TORCH_GREATER_EQUAL_2_1", False)
|
||||
with pytest.raises(ValueError, match="activation_checkpointing_policy` requires torch >= 2.1.0"):
|
||||
FSDPStrategy(activation_checkpointing_policy=Mock())
|
||||
|
||||
|
||||
def test_activation_checkpointing():
|
||||
"""Test that the FSDP strategy can apply activation checkpointing to the given layers."""
|
||||
|
||||
|
@ -170,28 +163,13 @@ def test_activation_checkpointing():
|
|||
self.layer1 = Block2(2, 2)
|
||||
self.layer2 = nn.Linear(3, 3)
|
||||
|
||||
if _TORCH_GREATER_EQUAL_2_1:
|
||||
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
|
||||
strategy = FSDPStrategy(activation_checkpointing_policy={Block1})
|
||||
assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
|
||||
assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)
|
||||
|
||||
strategy = FSDPStrategy(activation_checkpointing_policy={Block1})
|
||||
assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
|
||||
assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)
|
||||
|
||||
strategy = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({Block1, Block2}))
|
||||
assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
|
||||
assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)
|
||||
else:
|
||||
strategy = FSDPStrategy(activation_checkpointing=Block1)
|
||||
assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}
|
||||
|
||||
strategy = FSDPStrategy(activation_checkpointing=[Block1, Block2])
|
||||
assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}
|
||||
|
||||
strategy = FSDPStrategy(activation_checkpointing_policy={Block1})
|
||||
assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}
|
||||
|
||||
strategy = FSDPStrategy(activation_checkpointing_policy={Block1, Block2})
|
||||
assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}
|
||||
strategy = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({Block1, Block2}))
|
||||
assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
|
||||
assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)
|
||||
|
||||
strategy._parallel_devices = [torch.device("cuda", 0)]
|
||||
with mock.patch("torch.distributed.fsdp.FullyShardedDataParallel", new=MagicMock), mock.patch(
|
||||
|
@ -401,15 +379,13 @@ def test_set_timeout(init_process_group_mock):
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("torch_ge_2_1", [True, False])
|
||||
@mock.patch("torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel.set_state_dict_type")
|
||||
def test_get_full_state_dict_context_offload(set_type_mock, monkeypatch, torch_ge_2_1):
|
||||
"""Test that the state dict context manager handles CPU offloading depending on the PyTorch version."""
|
||||
monkeypatch.setattr("lightning.fabric.strategies.fsdp._TORCH_GREATER_EQUAL_2_1", torch_ge_2_1)
|
||||
def test_get_full_state_dict_context_offload(set_type_mock, monkeypatch):
|
||||
"""Test that the state dict context manager handles CPU offloading."""
|
||||
|
||||
with _get_full_state_dict_context(module=Mock(spec=FullyShardedDataParallel), world_size=1):
|
||||
assert set_type_mock.call_args_list[0][0][2].offload_to_cpu is torch_ge_2_1 # model config
|
||||
assert set_type_mock.call_args_list[0][0][3].offload_to_cpu is torch_ge_2_1 # optim config
|
||||
assert set_type_mock.call_args_list[0][0][2].offload_to_cpu # model config
|
||||
assert set_type_mock.call_args_list[0][0][3].offload_to_cpu # optim config
|
||||
|
||||
set_type_mock.reset_mock()
|
||||
|
||||
@ -23,7 +23,6 @@ import torch.nn as nn
|
|||
from lightning.fabric import Fabric
|
||||
from lightning.fabric.plugins import FSDPPrecision
|
||||
from lightning.fabric.strategies import FSDPStrategy
|
||||
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
|
||||
from lightning.fabric.utilities.load import _load_distributed_checkpoint
|
||||
from lightning.fabric.wrappers import _FabricOptimizer
|
||||
from torch._dynamo import OptimizedModule
|
||||
|
@ -400,7 +399,7 @@ def test_setup_with_orig_params_and_multiple_param_groups():
|
|||
assert not isinstance(layer.weight, FlatParameter)
|
||||
|
||||
|
||||
@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.1.0", dynamo=True, skip_windows=True)
|
||||
@RunIf(min_cuda_gpus=2, standalone=True, dynamo=True, skip_windows=True)
|
||||
@mock.patch("lightning.fabric.wrappers.torch.compile", Mock(wraps=torch.compile))
|
||||
@mock.patch.dict(os.environ, {})
|
||||
def test_reapply_compile():
|
||||
|
@ -466,12 +465,8 @@ def test_module_init_context(precision, expected_dtype):
|
|||
# Case 1: No empty init
|
||||
_run_setup_assertions(empty_init=False, expected_device=torch.device("cpu"))
|
||||
|
||||
if _TORCH_GREATER_EQUAL_2_1:
|
||||
# Case 2: Empty-init with PyTorch >= 2.1 supports meta device
|
||||
_run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))
|
||||
else:
|
||||
# Case 2: Empty-init with PyTorch < 2.1 only supports `torch.empty()`-init
|
||||
_run_setup_assertions(empty_init=True, expected_device=torch.device("cpu"))
|
||||
# Case 2: Empty-init with meta device
|
||||
_run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))
|
||||
|
||||
|
||||
@RunIf(min_cuda_gpus=2, standalone=True)
|
||||
|
@ -538,9 +533,6 @@ def test_rewrap_warnings():
|
|||
assert not isinstance(model._forward_module, FullyShardedDataParallel)
|
||||
assert isinstance(model._forward_module[2], FullyShardedDataParallel)
|
||||
|
||||
if not _TORCH_GREATER_EQUAL_2_1:
|
||||
return
|
||||
|
||||
with fabric.init_module(empty_init=True):
|
||||
model = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.ReLU(), wrap(torch.nn.Linear(1, 1)))
|
||||
assert model[0].weight.is_meta
|
||||
@ -289,7 +289,7 @@ def test_setup_optimizers_not_supported(strategy_cls):
|
|||
fabric.setup_optimizers(optimizer)
|
||||
|
||||
|
||||
@RunIf(min_cuda_gpus=1, min_torch="2.1")
|
||||
@RunIf(min_cuda_gpus=1)
|
||||
def test_setup_optimizer_on_meta_device():
|
||||
"""Test that the setup-methods validate that the optimizer doesn't have references to meta-device parameters."""
|
||||
fabric = Fabric(strategy="fsdp", devices=1)
|
||||
|
@ -867,8 +867,6 @@ def test_init_module_context(monkeypatch):
|
|||
|
||||
|
||||
def test_init_tensor_context(monkeypatch):
|
||||
"""Test that `.init_tensor()` warns if using PyTorch < 2.0."""
|
||||
|
||||
fabric = Fabric(accelerator="cpu")
|
||||
strategy = SingleDeviceStrategy(device=torch.device("cuda"))
|
||||
strategy.tensor_init_context = Mock(wraps=strategy.tensor_init_context)
|
||||
@ -19,7 +19,6 @@ import torch
|
|||
from lightning.fabric.fabric import Fabric
|
||||
from lightning.fabric.plugins import Precision
|
||||
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
|
||||
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
|
||||
from lightning.fabric.wrappers import (
|
||||
_FabricDataLoader,
|
||||
_FabricModule,
|
||||
|
@ -268,14 +267,13 @@ def test_fabric_module_state_dict_access():
|
|||
assert torch.equal(fabric_module.layer.weight, weight)
|
||||
assert torch.equal(fabric_module.layer.bias, bias)
|
||||
|
||||
if _TORCH_GREATER_EQUAL_2_1:
|
||||
# Can use additional `assign` argument in PyTorch >= 2.1
|
||||
with torch.device("meta"):
|
||||
original_module = OriginalModule()
|
||||
fabric_module = _FabricModule(wrapped_module, Mock(), original_module=original_module)
|
||||
assert fabric_module.layer.weight.is_meta
|
||||
fabric_module.load_state_dict({"layer.weight": weight, "layer.bias": bias}, assign=True)
|
||||
assert not fabric_module.layer.weight.is_meta
|
||||
# Can use additional `assign` argument
|
||||
with torch.device("meta"):
|
||||
original_module = OriginalModule()
|
||||
fabric_module = _FabricModule(wrapped_module, Mock(), original_module=original_module)
|
||||
assert fabric_module.layer.weight.is_meta
|
||||
fabric_module.load_state_dict({"layer.weight": weight, "layer.bias": bias}, assign=True)
|
||||
assert not fabric_module.layer.weight.is_meta
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -58,7 +58,6 @@ def test_empty_init_speed():
|
|||
assert normal_init_time > 2 * empty_init_time
|
||||
|
||||
|
||||
@RunIf(min_torch="2.1")
|
||||
def test_materialize_meta_tensors():
|
||||
class Submodule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
@ -13,11 +13,9 @@ from lightning.fabric.utilities.throughput import (
|
|||
measure_flops,
|
||||
)
|
||||
|
||||
from tests_fabric.helpers.runif import RunIf
|
||||
from tests_fabric.test_fabric import BoringModel
|
||||
|
||||
|
||||
@RunIf(min_torch="2.1")
|
||||
def test_measure_flops():
|
||||
with torch.device("meta"):
|
||||
model = BoringModel()
|
||||
@ -8,10 +8,7 @@ from lightning.pytorch import Trainer
|
|||
from lightning.pytorch.callbacks.throughput_monitor import ThroughputMonitor
|
||||
from lightning.pytorch.demos.boring_classes import BoringModel
|
||||
|
||||
from tests_pytorch.helpers.runif import RunIf
|
||||
|
||||
|
||||
@RunIf(min_torch="2.1")
|
||||
def test_measure_flops():
|
||||
with torch.device("meta"):
|
||||
model = BoringModel()
|
||||
@ -153,7 +153,10 @@ def thread_police_duuu_daaa_duuu_daaa():
|
|||
assert not thread.is_alive()
|
||||
elif isinstance(thread, _ChildProcessObserver):
|
||||
thread.join(timeout=10)
|
||||
elif thread.name == "QueueFeederThread": # tensorboardX
|
||||
elif (
|
||||
thread.name == "QueueFeederThread" # tensorboardX
|
||||
or thread.name == "QueueManagerThread" # torch.compile
|
||||
):
|
||||
thread.join(timeout=20)
|
||||
elif isinstance(thread, TMonitor):
|
||||
thread.exit()
|
||||
@ -35,14 +35,10 @@ def test_ddp_is_distributed():
|
|||
_ = strategy.is_distributed
|
||||
|
||||
|
||||
def test_fsdp_activation_checkpointing(monkeypatch):
|
||||
def test_fsdp_activation_checkpointing():
|
||||
with pytest.raises(ValueError, match="cannot set both `activation_checkpointing"):
|
||||
FSDPStrategy(activation_checkpointing=torch.nn.Linear, activation_checkpointing_policy=lambda *_: True)
|
||||
|
||||
monkeypatch.setattr(lightning.fabric.strategies.fsdp, "_TORCH_GREATER_EQUAL_2_1", True)
|
||||
with pytest.deprecated_call(match=r"use `FSDPStrategy\(activation_checkpointing_policy"):
|
||||
FSDPStrategy(activation_checkpointing=torch.nn.Linear)
|
||||
|
||||
|
||||
def test_double_precision_wrapper():
|
||||
with pytest.deprecated_call(match=r"The `LightningDoublePrecisionModule` is deprecated and no longer needed"):
|
||||
@ -14,7 +14,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
from lightning.fabric.plugins.environments import LightningEnvironment
|
||||
from lightning.fabric.strategies.fsdp import _is_sharded_checkpoint
|
||||
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1, _TORCH_GREATER_EQUAL_2_2
|
||||
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
|
||||
from lightning.fabric.utilities.load import _load_distributed_checkpoint
|
||||
from lightning.pytorch import Trainer
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint
|
||||
|
@ -334,10 +334,9 @@ def test_strategy_full_state_dict(tmp_path, wrap_min_params):
|
|||
TestFSDPModelAutoWrapped(),
|
||||
FSDPStrategy,
|
||||
{
|
||||
"auto_wrap_policy": ModuleWrapPolicy({nn.Linear}) if _TORCH_GREATER_EQUAL_2_1 else None,
|
||||
"auto_wrap_policy": ModuleWrapPolicy({nn.Linear}),
|
||||
"use_orig_params": True,
|
||||
},
|
||||
marks=RunIf(min_torch="2.1.0"),
|
||||
id="autowrap_use_orig_params",
|
||||
),
|
||||
],
|
||||
|
@ -380,19 +379,12 @@ def test_invalid_parameters_in_optimizer(use_orig_params):
|
|||
fast_dev_run=1,
|
||||
)
|
||||
|
||||
error_context = (
|
||||
nullcontext()
|
||||
if _TORCH_GREATER_EQUAL_2_1 or use_orig_params is not False
|
||||
else pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters")
|
||||
)
|
||||
|
||||
class EmptyParametersModel(BoringModel):
|
||||
def configure_optimizers(self):
|
||||
return torch.optim.Adam(self.parameters(), lr=1e-2)
|
||||
|
||||
model = EmptyParametersModel()
|
||||
with error_context:
|
||||
trainer.fit(model)
|
||||
trainer.fit(model)
|
||||
|
||||
class NoFlatParametersModel(BoringModel):
|
||||
def configure_optimizers(self):
|
||||
|
@ -435,28 +427,13 @@ def test_activation_checkpointing():
|
|||
self.layer1 = Block2(2, 2)
|
||||
self.layer2 = nn.Linear(3, 3)
|
||||
|
||||
if _TORCH_GREATER_EQUAL_2_1:
|
||||
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
|
||||
strategy = FSDPStrategy(activation_checkpointing_policy={Block1})
|
||||
assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
|
||||
assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)
|
||||
|
||||
strategy = FSDPStrategy(activation_checkpointing_policy={Block1})
|
||||
assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
|
||||
assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)
|
||||
|
||||
strategy = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({Block1, Block2}))
|
||||
assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
|
||||
assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)
|
||||
else:
|
||||
strategy = FSDPStrategy(activation_checkpointing=Block1)
|
||||
assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}
|
||||
|
||||
strategy = FSDPStrategy(activation_checkpointing=[Block1, Block2])
|
||||
assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}
|
||||
|
||||
strategy = FSDPStrategy(activation_checkpointing_policy={Block1})
|
||||
assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}
|
||||
|
||||
strategy = FSDPStrategy(activation_checkpointing_policy={Block1, Block2})
|
||||
assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"}
|
||||
strategy = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({Block1, Block2}))
|
||||
assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"}
|
||||
assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)
|
||||
|
||||
model = Model()
|
||||
strategy._parallel_devices = [torch.device("cuda", 0)]
|
||||
|
@ -608,7 +585,7 @@ def test_strategy_save_optimizer_states(tmp_path, wrap_min_params):
|
|||
if trainer.global_rank != 0:
|
||||
assert len(model_state_dict) == 0
|
||||
|
||||
if trainer.global_rank != 0 and _TORCH_GREATER_EQUAL_2_1:
|
||||
if trainer.global_rank != 0:
|
||||
assert len(optimizer_state_dict) == 0
|
||||
|
||||
# restore model to ddp
|
||||
|
@ -679,7 +656,7 @@ def test_strategy_load_optimizer_states(wrap_min_params, tmp_path):
|
|||
if trainer.global_rank != 0:
|
||||
assert len(restored_model_state_dict) == 0
|
||||
|
||||
if trainer.global_rank != 0 and _TORCH_GREATER_EQUAL_2_1:
|
||||
if trainer.global_rank != 0:
|
||||
assert len(restored_optimizer_state_dict) == 0
|
||||
|
||||
if trainer.global_rank == 0:
|
||||
|
@ -936,12 +913,8 @@ def test_module_init_context(precision, expected_dtype, tmp_path):
|
|||
# Case 1: No empty init
|
||||
_run_setup_assertions(empty_init=False, expected_device=torch.device("cpu"))
|
||||
|
||||
if _TORCH_GREATER_EQUAL_2_1:
|
||||
# Case 2: Empty-init with PyTorch >= 2.1 supports meta device
|
||||
_run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))
|
||||
else:
|
||||
# Case 2: Empty-init with PyTorch < 2.1 only supports `torch.empty()`-init
|
||||
_run_setup_assertions(empty_init=True, expected_device=torch.device("cpu"))
|
||||
# Case 2: Empty-init with meta device
|
||||
_run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))
|
||||
|
||||
|
||||
@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.3.0")
|
||||
@ -339,7 +339,7 @@ def test_module_init_context(precision, expected_dtype, tmp_path):
|
|||
# Case 1: No empty init
|
||||
_run_setup_assertions(empty_init=False, expected_device=torch.device("cpu"))
|
||||
|
||||
# Case 2: Empty-init with PyTorch >= 2.1 supports meta device
|
||||
# Case 2: Empty-init with meta device
|
||||
_run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))
|
||||
|
||||
|
||||
@ -16,7 +16,6 @@ from unittest.mock import Mock
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_0
|
||||
from lightning.pytorch import Trainer
|
||||
from lightning.pytorch.demos.boring_classes import BoringModel
|
||||
from lightning.pytorch.loops import _Loop
|
||||
|
@ -81,7 +80,5 @@ def test_no_grad_context():
|
|||
f.run()
|
||||
no_grad_mock.assert_called_once_with()
|
||||
f.inference_mode = True
|
||||
with mock.patch("torch.inference_mode") as inference_mode_mock:
|
||||
with mock.patch("torch.inference_mode"):
|
||||
f.run()
|
||||
if not _TORCH_EQUAL_2_0:
|
||||
inference_mode_mock.assert_called_once_with()
|
||||
|