diff --git a/.github/checkgroup.yml b/.github/checkgroup.yml index 22b2dee5dd..79b65664d2 100644 --- a/.github/checkgroup.yml +++ b/.github/checkgroup.yml @@ -19,29 +19,23 @@ subprojects: - "!*.md" - "!**/*.md" checks: - - "pl-cpu (macOS-13, lightning, 3.8, 2.0, oldest)" - - "pl-cpu (macOS-14, lightning, 3.10, 2.0)" + - "pl-cpu (macOS-13, lightning, 3.8, 2.1, oldest)" - "pl-cpu (macOS-14, lightning, 3.10, 2.1)" - "pl-cpu (macOS-14, lightning, 3.10, 2.2)" - "pl-cpu (macOS-14, lightning, 3.10, 2.3)" - - "pl-cpu (ubuntu-20.04, lightning, 3.8, 2.0, oldest)" - - "pl-cpu (ubuntu-20.04, lightning, 3.10, 2.0)" + - "pl-cpu (ubuntu-20.04, lightning, 3.8, 2.1, oldest)" - "pl-cpu (ubuntu-20.04, lightning, 3.10, 2.1)" - "pl-cpu (ubuntu-20.04, lightning, 3.10, 2.2)" - "pl-cpu (ubuntu-20.04, lightning, 3.10, 2.3)" - - "pl-cpu (windows-2022, lightning, 3.8, 2.0, oldest)" - - "pl-cpu (windows-2022, lightning, 3.10, 2.0)" + - "pl-cpu (windows-2022, lightning, 3.8, 2.1, oldest)" - "pl-cpu (windows-2022, lightning, 3.10, 2.1)" - "pl-cpu (windows-2022, lightning, 3.10, 2.2)" - "pl-cpu (windows-2022, lightning, 3.10, 2.3)" - - "pl-cpu (macOS-14, pytorch, 3.8, 2.0)" - - "pl-cpu (ubuntu-20.04, pytorch, 3.8, 2.0)" - - "pl-cpu (windows-2022, pytorch, 3.8, 2.0)" - - "pl-cpu (macOS-12, pytorch, 3.11, 2.0)" + - "pl-cpu (macOS-14, pytorch, 3.8, 2.1)" + - "pl-cpu (ubuntu-20.04, pytorch, 3.8, 2.1)" + - "pl-cpu (windows-2022, pytorch, 3.8, 2.1)" - "pl-cpu (macOS-12, pytorch, 3.11, 2.1)" - - "pl-cpu (ubuntu-22.04, pytorch, 3.11, 2.0)" - "pl-cpu (ubuntu-22.04, pytorch, 3.11, 2.1)" - - "pl-cpu (windows-2022, pytorch, 3.11, 2.0)" - "pl-cpu (windows-2022, pytorch, 3.11, 2.1)" - id: "pytorch_lightning: Azure GPU" @@ -144,13 +138,11 @@ subprojects: - "!*.md" - "!**/*.md" checks: - - "build-cuda (3.10, 2.0, 11.8.0)" - "build-cuda (3.10, 2.1, 12.1.0)" - "build-cuda (3.10, 2.2, 12.1.0)" - "build-cuda (3.11, 2.1, 12.1.0)" - "build-cuda (3.11, 2.2, 12.1.0)" #- "build-NGC" - - "build-pl (3.10, 2.0, 11.8.0)" - "build-pl (3.10, 2.1, 12.1.0)" - "build-pl (3.10, 2.2, 12.1.0)" - "build-pl (3.11, 2.1, 12.1.0)" @@ -171,29 +163,23 @@ subprojects: - "!*.md" - "!**/*.md" checks: - - "fabric-cpu (macOS-13, lightning, 3.8, 2.0, oldest)" - - "fabric-cpu (macOS-14, lightning, 3.10, 2.0)" + - "fabric-cpu (macOS-13, lightning, 3.8, 2.1, oldest)" - "fabric-cpu (macOS-14, lightning, 3.11, 2.1)" - "fabric-cpu (macOS-14, lightning, 3.11, 2.2)" - "fabric-cpu (macOS-14, lightning, 3.10, 2.3)" - - "fabric-cpu (ubuntu-20.04, lightning, 3.8, 2.0, oldest)" - - "fabric-cpu (ubuntu-20.04, lightning, 3.10, 2.0)" + - "fabric-cpu (ubuntu-20.04, lightning, 3.8, 2.1, oldest)" - "fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.1)" - "fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.2)" - "fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.3)" - - "fabric-cpu (windows-2022, lightning, 3.8, 2.0, oldest)" - - "fabric-cpu (windows-2022, lightning, 3.10, 2.0)" + - "fabric-cpu (windows-2022, lightning, 3.8, 2.1, oldest)" - "fabric-cpu (windows-2022, lightning, 3.11, 2.1)" - "fabric-cpu (windows-2022, lightning, 3.11, 2.2)" - "fabric-cpu (windows-2022, lightning, 3.11, 2.3)" - - "fabric-cpu (macOS-14, fabric, 3.8, 2.0)" - - "fabric-cpu (ubuntu-20.04, fabric, 3.8, 2.0)" - - "fabric-cpu (windows-2022, fabric, 3.8, 2.0)" - - "fabric-cpu (macOS-12, fabric, 3.11, 2.0)" + - "fabric-cpu (macOS-14, fabric, 3.8, 2.1)" + - "fabric-cpu (ubuntu-20.04, fabric, 3.8, 2.1)" + - "fabric-cpu (windows-2022, fabric, 3.8, 2.1)" - "fabric-cpu (macOS-12, fabric, 3.11, 2.1)" - - "fabric-cpu (ubuntu-22.04, 
fabric, 3.11, 2.0)" - "fabric-cpu (ubuntu-22.04, fabric, 3.11, 2.1)" - - "fabric-cpu (windows-2022, fabric, 3.11, 2.0)" - "fabric-cpu (windows-2022, fabric, 3.11, 2.1)" - id: "lightning_fabric: Azure GPU" diff --git a/.github/workflows/ci-tests-fabric.yml b/.github/workflows/ci-tests-fabric.yml index 009e03f38c..2c0d8d16b8 100644 --- a/.github/workflows/ci-tests-fabric.yml +++ b/.github/workflows/ci-tests-fabric.yml @@ -39,9 +39,6 @@ jobs: fail-fast: false matrix: include: - - { os: "macOS-14", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" } - - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" } - - { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" } # only run PyTorch latest - { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.1" } - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.1" } @@ -53,32 +50,29 @@ jobs: - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" } - { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" } # only run PyTorch latest with Python latest, use Fabric scope to limit dependency issues - - { os: "macOS-12", pkg-name: "fabric", python-version: "3.11", pytorch-version: "2.0" } - - { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.11", pytorch-version: "2.0" } - - { os: "windows-2022", pkg-name: "fabric", python-version: "3.11", pytorch-version: "2.0" } - { os: "macOS-12", pkg-name: "fabric", python-version: "3.11", pytorch-version: "2.1" } - { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.11", pytorch-version: "2.1" } - { os: "windows-2022", pkg-name: "fabric", python-version: "3.11", pytorch-version: "2.1" } # "oldest" versions tests, only on minimum Python - - { os: "macOS-13", pkg-name: "lightning", python-version: "3.8", pytorch-version: "2.0", requires: "oldest" } + - { os: "macOS-13", pkg-name: "lightning", python-version: "3.8", pytorch-version: "2.1", requires: "oldest" } - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.8", - pytorch-version: "2.0", + pytorch-version: "2.1", requires: "oldest", } - { os: "windows-2022", pkg-name: "lightning", python-version: "3.8", - pytorch-version: "2.0", + pytorch-version: "2.1", requires: "oldest", } # "fabric" installs the standalone package - - { os: "macOS-14", pkg-name: "fabric", python-version: "3.8", pytorch-version: "2.0" } - - { os: "ubuntu-20.04", pkg-name: "fabric", python-version: "3.8", pytorch-version: "2.0" } - - { os: "windows-2022", pkg-name: "fabric", python-version: "3.8", pytorch-version: "2.0" } + - { os: "macOS-14", pkg-name: "fabric", python-version: "3.8", pytorch-version: "2.1" } + - { os: "ubuntu-20.04", pkg-name: "fabric", python-version: "3.8", pytorch-version: "2.1" } + - { os: "windows-2022", pkg-name: "fabric", python-version: "3.8", pytorch-version: "2.1" } timeout-minutes: 25 # because of building grpcio on Mac env: PACKAGE_NAME: ${{ matrix.pkg-name }} diff --git a/.github/workflows/ci-tests-pytorch.yml b/.github/workflows/ci-tests-pytorch.yml index 967369976e..b75b6e73d9 100644 --- a/.github/workflows/ci-tests-pytorch.yml +++ b/.github/workflows/ci-tests-pytorch.yml @@ -43,9 +43,6 @@ jobs: fail-fast: false matrix: include: - - { os: "macOS-14", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" } - - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", 
pytorch-version: "2.0" } - - { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" } # only run PyTorch latest - { os: "macOS-14", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" } - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" } @@ -57,32 +54,29 @@ jobs: - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.3" } - { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.3" } # only run PyTorch latest with Python latest, use PyTorch scope to limit dependency issues - - { os: "macOS-12", pkg-name: "pytorch", python-version: "3.11", pytorch-version: "2.0" } - - { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.11", pytorch-version: "2.0" } - - { os: "windows-2022", pkg-name: "pytorch", python-version: "3.11", pytorch-version: "2.0" } - { os: "macOS-12", pkg-name: "pytorch", python-version: "3.11", pytorch-version: "2.1" } - { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.11", pytorch-version: "2.1" } - { os: "windows-2022", pkg-name: "pytorch", python-version: "3.11", pytorch-version: "2.1" } # "oldest" versions tests, only on minimum Python - - { os: "macOS-13", pkg-name: "lightning", python-version: "3.8", pytorch-version: "2.0", requires: "oldest" } + - { os: "macOS-13", pkg-name: "lightning", python-version: "3.8", pytorch-version: "2.1", requires: "oldest" } - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.8", - pytorch-version: "2.0", + pytorch-version: "2.1", requires: "oldest", } - { os: "windows-2022", pkg-name: "lightning", python-version: "3.8", - pytorch-version: "2.0", + pytorch-version: "2.1", requires: "oldest", } # "pytorch" installs the standalone package - - { os: "macOS-14", pkg-name: "pytorch", python-version: "3.8", pytorch-version: "2.0" } - - { os: "ubuntu-20.04", pkg-name: "pytorch", python-version: "3.8", pytorch-version: "2.0" } - - { os: "windows-2022", pkg-name: "pytorch", python-version: "3.8", pytorch-version: "2.0" } + - { os: "macOS-14", pkg-name: "pytorch", python-version: "3.8", pytorch-version: "2.1" } + - { os: "ubuntu-20.04", pkg-name: "pytorch", python-version: "3.8", pytorch-version: "2.1" } + - { os: "windows-2022", pkg-name: "pytorch", python-version: "3.8", pytorch-version: "2.1" } timeout-minutes: 50 env: PACKAGE_NAME: ${{ matrix.pkg-name }} diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index d917ebc407..0891205421 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -43,7 +43,6 @@ jobs: include: # We only release one docker image per PyTorch version. # Make sure the matrix here matches the one below. - - { python_version: "3.10", pytorch_version: "2.0", cuda_version: "11.8.0" } - { python_version: "3.10", pytorch_version: "2.1", cuda_version: "12.1.0" } - { python_version: "3.10", pytorch_version: "2.2", cuda_version: "12.1.0" } - { python_version: "3.11", pytorch_version: "2.1", cuda_version: "12.1.0" } @@ -104,7 +103,6 @@ jobs: include: # These are the base images for PL release docker images. # Make sure the matrix here matches the one above. 
- - { python_version: "3.10", pytorch_version: "2.0", cuda_version: "11.8.0" } - { python_version: "3.10", pytorch_version: "2.1", cuda_version: "12.1.0" } - { python_version: "3.10", pytorch_version: "2.2", cuda_version: "12.1.0" } - { python_version: "3.11", pytorch_version: "2.1", cuda_version: "12.1.0" } diff --git a/docs/source-fabric/advanced/compile.rst b/docs/source-fabric/advanced/compile.rst index a8e1cc2db2..ed46a1f822 100644 --- a/docs/source-fabric/advanced/compile.rst +++ b/docs/source-fabric/advanced/compile.rst @@ -5,10 +5,6 @@ Speed up models by compiling them Compiling your PyTorch model can result in significant speedups, especially on the latest generations of GPUs. This guide shows you how to apply `torch.compile `_ correctly in your code. -.. note:: - - This requires PyTorch >= 2.0. - ---- diff --git a/docs/source-fabric/advanced/model_init.rst b/docs/source-fabric/advanced/model_init.rst index 4b31df036f..f5f76e8aa0 100644 --- a/docs/source-fabric/advanced/model_init.rst +++ b/docs/source-fabric/advanced/model_init.rst @@ -81,4 +81,4 @@ When training distributed models with :doc:`FSDP/TP ` or D .. note:: Empty-init is experimental and the behavior may change in the future. - For distributed models on PyTorch 2.1+, it is required that all user-defined modules that manage parameters implement a ``reset_parameters()`` method (all PyTorch built-in modules have this too). + For distributed models, it is required that all user-defined modules that manage parameters implement a ``reset_parameters()`` method (all PyTorch built-in modules have this too). diff --git a/docs/source-pytorch/advanced/compile.rst b/docs/source-pytorch/advanced/compile.rst index 73d5f4fbc2..484559e111 100644 --- a/docs/source-pytorch/advanced/compile.rst +++ b/docs/source-pytorch/advanced/compile.rst @@ -5,10 +5,6 @@ Speed up models by compiling them Compiling your LightningModule can result in significant speedups, especially on the latest generations of GPUs. This guide shows you how to apply `torch.compile `_ correctly in your code. -.. note:: - - This requires PyTorch >= 2.0. 
- ---- diff --git a/requirements/fabric/base.txt b/requirements/fabric/base.txt index aac884d9c6..7ca4556821 100644 --- a/requirements/fabric/base.txt +++ b/requirements/fabric/base.txt @@ -1,8 +1,8 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment -numpy >=1.17.2, <1.27.0 -torch >=2.0.0, <2.4.0 +numpy >=1.21.0, <1.27.0 +torch >=2.1.0, <2.4.0 fsspec[http] >=2022.5.0, <2024.4.0 packaging >=20.0, <=23.1 typing-extensions >=4.4.0, <4.10.0 diff --git a/requirements/fabric/examples.txt b/requirements/fabric/examples.txt index 0e2feb97ec..49ffde9d0f 100644 --- a/requirements/fabric/examples.txt +++ b/requirements/fabric/examples.txt @@ -1,6 +1,6 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment -torchvision >=0.15.0, <0.19.0 +torchvision >=0.16.0, <0.19.0 torchmetrics >=0.10.0, <1.3.0 lightning-utilities >=0.8.0, <0.12.0 diff --git a/requirements/pytorch/base.txt b/requirements/pytorch/base.txt index 6372357b6d..cd71466551 100644 --- a/requirements/pytorch/base.txt +++ b/requirements/pytorch/base.txt @@ -1,8 +1,8 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment -numpy >=1.17.2, <1.27.0 -torch >=2.0.0, <2.4.0 +numpy >=1.21.0, <1.27.0 +torch >=2.1.0, <2.4.0 tqdm >=4.57.0, <4.67.0 PyYAML >=5.4, <6.1.0 fsspec[http] >=2022.5.0, <2024.4.0 diff --git a/requirements/pytorch/examples.txt b/requirements/pytorch/examples.txt index 55b85025bd..e4b1bc31e9 100644 --- a/requirements/pytorch/examples.txt +++ b/requirements/pytorch/examples.txt @@ -2,7 +2,7 @@ # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment requests <2.32.0 -torchvision >=0.15.0, <0.19.0 +torchvision >=0.16.0, <0.19.0 ipython[all] <8.15.0 torchmetrics >=0.10.0, <1.3.0 lightning-utilities >=0.8.0, <0.12.0 diff --git a/requirements/pytorch/test.txt b/requirements/pytorch/test.txt index 94a06630df..472c4157df 100644 --- a/requirements/pytorch/test.txt +++ b/requirements/pytorch/test.txt @@ -8,8 +8,8 @@ pytest-random-order ==1.1.0 # needed in tests cloudpickle >=1.3, <2.3.0 scikit-learn >0.22.1, <1.4.0 -onnx >=0.14.0, <1.15.0 -onnxruntime >=0.15.0, <1.17.0 +onnx >=1.12.0, <1.15.0 +onnxruntime >=1.12.0, <1.17.0 psutil <5.9.6 # for `DeviceStatsMonitor` pandas >1.0, <2.2.0 # needed in benchmarks fastapi # for `ServableModuleValidator` # not setting version as re-defined in App diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 6155644ed6..9857598307 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -27,7 +27,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
### Removed -- +- Removed support for PyTorch 2.0 ([#20009](https://github.com/Lightning-AI/lightning/pull/20009)) + - diff --git a/src/lightning/fabric/__init__.py b/src/lightning/fabric/__init__.py index 75752d8b94..26f01aad64 100644 --- a/src/lightning/fabric/__init__.py +++ b/src/lightning/fabric/__init__.py @@ -21,7 +21,7 @@ if not _root_logger.hasHandlers(): _logger.propagate = False -# In PyTorch 2.0+, setting this variable will force `torch.cuda.is_available()` and `torch.cuda.device_count()` +# Setting this variable will force `torch.cuda.is_available()` and `torch.cuda.device_count()` # to use an NVML-based implementation that doesn't poison forks. # https://github.com/pytorch/pytorch/issues/83973 os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] = "1" diff --git a/src/lightning/fabric/fabric.py b/src/lightning/fabric/fabric.py index 71d8f623dc..b9032fe7a9 100644 --- a/src/lightning/fabric/fabric.py +++ b/src/lightning/fabric/fabric.py @@ -225,8 +225,7 @@ class Fabric: _reapply_compile: If ``True`` (default), and the model was ``torch.compile``d before, the corresponding :class:`~torch._dynamo.OptimizedModule` wrapper will be removed and reapplied with the same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP, - FSDP etc.). Only applies on PyTorch >= 2.1. Set it to ``False`` if compiling DDP/FSDP is causing - issues. + FSDP etc.). Set it to ``False`` if compiling DDP/FSDP is causing issues. Returns: The tuple containing wrapped module and the optimizers, in the same order they were passed in. @@ -292,8 +291,7 @@ class Fabric: _reapply_compile: If ``True`` (default), and the model was ``torch.compile``d before, the corresponding :class:`~torch._dynamo.OptimizedModule` wrapper will be removed and reapplied with the same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP, - FSDP etc.). Only applies on PyTorch >= 2.1. Set it to ``False`` if compiling DDP/FSDP is causing - issues. + FSDP etc.). Set it to ``False`` if compiling DDP/FSDP is causing issues. Returns: The wrapped model. diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py index 9a711b8449..b8a3a26847 100644 --- a/src/lightning/fabric/strategies/fsdp.py +++ b/src/lightning/fabric/strategies/fsdp.py @@ -63,11 +63,10 @@ from lightning.fabric.utilities.distributed import ( ) from lightning.fabric.utilities.distributed import group as _group from lightning.fabric.utilities.imports import ( - _TORCH_GREATER_EQUAL_2_1, _TORCH_GREATER_EQUAL_2_2, _TORCH_GREATER_EQUAL_2_3, ) -from lightning.fabric.utilities.init import _EmptyInit, _has_meta_device_parameters_or_buffers +from lightning.fabric.utilities.init import _has_meta_device_parameters_or_buffers from lightning.fabric.utilities.load import _METADATA_FILENAME, _lazy_load, _materialize_tensors, _move_state_into from lightning.fabric.utilities.rank_zero import rank_zero_deprecation, rank_zero_only, rank_zero_warn from lightning.fabric.utilities.seed import reset_seed @@ -325,7 +324,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded): if self._fsdp_kwargs.get("use_orig_params"): return super().setup_optimizer(optimizer) if not _optimizer_has_flat_params(optimizer): - # We avoid this limitation in PyTorch >= 2.0 by setting `use_orig_params=True` + # We avoid this limitation by setting `use_orig_params=True` raise ValueError( "The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the optimizer" " after setting up the model."
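For reference, the call order that the `HINT` in the hunk above asks for looks roughly like this with Fabric. This is a minimal sketch only; the toy model, optimizer choice, and launcher/device settings are illustrative assumptions, not part of this PR:

```python
import torch
from lightning.fabric import Fabric

# Assumes the script is launched on 2 GPUs (e.g. via `fabric run` or `torchrun`).
fabric = Fabric(accelerator="cuda", devices=2, strategy="fsdp")
fabric.launch()

model = torch.nn.Linear(32, 2)  # hypothetical toy model

# Set up (wrap and shard) the module first ...
model = fabric.setup_module(model)

# ... then build the optimizer from the wrapped module's parameters and set it up.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
optimizer = fabric.setup_optimizers(optimizer)
```

Creating the optimizer before `setup_module()` would leave it referencing the original, unsharded parameters, which is exactly the case the `ValueError` above guards against.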
@@ -340,15 +339,12 @@ class FSDPStrategy(ParallelStrategy, _Sharded): def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManager: precision_init_ctx = self.precision.module_init_context() module_sharded_ctx = self.module_sharded_context() - empty_ctx = _EmptyInit(enabled=bool(empty_init)) stack = ExitStack() - if _TORCH_GREATER_EQUAL_2_1 and empty_init: + if empty_init: # Materialization happens in `setup`. When modules get wrapped by FSDP, the sequence of operations is: # 1) materialize module 2) call `reset_parameters()` 3) shard the module. # These operations are applied to each submodule 'bottom up' in the module hierarchy. stack.enter_context(torch.device("meta")) - else: - stack.enter_context(empty_ctx) stack.enter_context(precision_init_ctx) stack.enter_context(module_sharded_ctx) return stack @@ -697,18 +693,13 @@ def _activation_checkpointing_kwargs( classes = tuple(activation_checkpointing) else: classes = (activation_checkpointing,) - if _TORCH_GREATER_EQUAL_2_1: - rank_zero_deprecation( - f"`FSDPStrategy(activation_checkpointing={activation_checkpointing})` is deprecated, use " - f"`FSDPStrategy(activation_checkpointing_policy={set(classes)})` instead." - ) + rank_zero_deprecation( + f"`FSDPStrategy(activation_checkpointing={activation_checkpointing})` is deprecated, use " + f"`FSDPStrategy(activation_checkpointing_policy={set(classes)})` instead." + ) return {"check_fn": lambda submodule: isinstance(submodule, classes)} if isinstance(activation_checkpointing_policy, set): - if _TORCH_GREATER_EQUAL_2_1: - return _auto_wrap_policy_kwargs(activation_checkpointing_policy, {}) - return {"check_fn": lambda submodule: isinstance(submodule, tuple(activation_checkpointing_policy))} - if not _TORCH_GREATER_EQUAL_2_1: - raise ValueError("`activation_checkpointing_policy` requires torch >= 2.1.0. 
HINT: `pip install -U torch`") + return _auto_wrap_policy_kwargs(activation_checkpointing_policy, {}) return {"auto_wrap_policy": activation_checkpointing_policy} @@ -716,15 +707,10 @@ def _auto_wrap_policy_kwargs(policy: Optional["_POLICY"], kwargs: Dict) -> Dict: if policy is None: return kwargs if isinstance(policy, set): - if _TORCH_GREATER_EQUAL_2_1: - from torch.distributed.fsdp.wrap import ModuleWrapPolicy + from torch.distributed.fsdp.wrap import ModuleWrapPolicy - policy = ModuleWrapPolicy(policy) - else: - from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy + policy = ModuleWrapPolicy(policy) - # this is not transformer specific despite the name - policy = partial(transformer_auto_wrap_policy, transformer_layer_cls=policy) kwargs["auto_wrap_policy"] = policy return kwargs @@ -829,11 +815,8 @@ def _get_full_state_dict_context( from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.api import FullOptimStateDictConfig - # In PyTorch < 2.1, offload to CPU in combination with `world_size=1` is not possible - offload_to_cpu = world_size > 1 or _TORCH_GREATER_EQUAL_2_1 - state_dict_config = FullStateDictConfig(offload_to_cpu=offload_to_cpu, rank0_only=rank0_only) - - optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=offload_to_cpu, rank0_only=rank0_only) + state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=rank0_only) + optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=rank0_only) state_dict_type_context = FSDP.state_dict_type( module=module, state_dict_type=StateDictType.FULL_STATE_DICT, diff --git a/src/lightning/fabric/utilities/imports.py b/src/lightning/fabric/utilities/imports.py index 46374e23ad..fc40175ff5 100644 --- a/src/lightning/fabric/utilities/imports.py +++ b/src/lightning/fabric/utilities/imports.py @@ -26,13 +26,10 @@ _IS_WINDOWS = platform.system() == "Windows" # 2. 
The inspection mode via `python -i`: https://stackoverflow.com/a/6879085/1162383 _IS_INTERACTIVE = hasattr(sys, "ps1") or bool(sys.flags.interactive) -_TORCH_GREATER_EQUAL_2_1 = compare_version("torch", operator.ge, "2.1.0") _TORCH_GREATER_EQUAL_2_2 = compare_version("torch", operator.ge, "2.2.0") _TORCH_GREATER_EQUAL_2_3 = compare_version("torch", operator.ge, "2.3.0") _TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0", use_base_version=True) -_TORCH_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0") and not _TORCH_GREATER_EQUAL_2_1 - _PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8) _PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10) diff --git a/src/lightning/fabric/utilities/init.py b/src/lightning/fabric/utilities/init.py index fccdce7aa8..c92dfd8c2e 100644 --- a/src/lightning/fabric/utilities/init.py +++ b/src/lightning/fabric/utilities/init.py @@ -20,7 +20,6 @@ from torch.optim import Optimizer from torch.overrides import TorchFunctionMode from typing_extensions import override -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from lightning.fabric.utilities.rank_zero import rank_zero_warn from lightning.fabric.utilities.types import _DEVICE @@ -61,8 +60,6 @@ class _EmptyInit(TorchFunctionMode): def _materialize(module: Module, device: _DEVICE) -> None: """Materialize a module.""" - if not _TORCH_GREATER_EQUAL_2_1: - raise RuntimeError("recurse=False requires torch 2.1") module.to_empty(device=device, recurse=False) if not hasattr(module, "reset_parameters"): raise TypeError( diff --git a/src/lightning/fabric/utilities/testing/_runif.py b/src/lightning/fabric/utilities/testing/_runif.py index 9a6f5554ba..f282651048 100644 --- a/src/lightning/fabric/utilities/testing/_runif.py +++ b/src/lightning/fabric/utilities/testing/_runif.py @@ -24,7 +24,6 @@ from lightning.fabric.accelerators import XLAAccelerator from lightning.fabric.accelerators.cuda import num_cuda_devices from lightning.fabric.accelerators.mps import MPSAccelerator from lightning.fabric.strategies.deepspeed import _DEEPSPEED_AVAILABLE -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1 def _runif_reasons( @@ -116,13 +115,9 @@ def _runif_reasons( reasons.append("Deepspeed") if dynamo: - if _TORCH_GREATER_EQUAL_2_1: - from torch._dynamo.eval_frame import is_dynamo_supported + from torch._dynamo.eval_frame import is_dynamo_supported - cond = not is_dynamo_supported() - else: - cond = sys.platform == "win32" or sys.version_info >= (3, 11) - if cond: + if not is_dynamo_supported(): reasons.append("torch.dynamo") return reasons, kwargs diff --git a/src/lightning/fabric/utilities/throughput.py b/src/lightning/fabric/utilities/throughput.py index f483c274c3..6743da7b34 100644 --- a/src/lightning/fabric/utilities/throughput.py +++ b/src/lightning/fabric/utilities/throughput.py @@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Ty import torch from typing_extensions import override -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from lightning.fabric.utilities.rank_zero import rank_zero_only, rank_zero_warn if TYPE_CHECKING: @@ -292,8 +291,6 @@ def measure_flops( FLOPs will be included in the result. 
""" - if not _TORCH_GREATER_EQUAL_2_1: - raise ImportError("`measure_flops` requires PyTorch >= 2.1.") from torch.utils.flop_counter import FlopCounterMode flop_counter = FlopCounterMode(display=False) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 39dd56e8e7..8e026f485f 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -27,7 +27,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Removed -- +- Removed support for PyTorch 2.1 ([#20009](https://github.com/Lightning-AI/lightning/pull/20009)) + - diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index 68395ce97d..c78cb87bb9 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -51,7 +51,6 @@ from lightning.fabric.loggers import Logger as FabricLogger from lightning.fabric.utilities.apply_func import convert_to_tensors from lightning.fabric.utilities.cloud_io import get_filesystem from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin -from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_1 from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH from lightning.fabric.wrappers import _FabricOptimizer from lightning.pytorch.callbacks.callback import Callback @@ -67,7 +66,7 @@ from lightning.pytorch.utilities import GradClipAlgorithmType from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_9_1 from lightning.pytorch.utilities.model_helpers import _restricted_classmethod -from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_debug, rank_zero_warn +from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_warn from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature from lightning.pytorch.utilities.types import ( _METRIC, @@ -140,7 +139,6 @@ class LightningModule( self._current_fx_name: Optional[str] = None self._param_requires_grad_state: Dict[str, bool] = {} self._metric_attributes: Optional[Dict[int, str]] = None - self._register_sharded_tensor_state_dict_hooks_if_available() self._compiler_ctx: Optional[Dict[str, Any]] = None # attributes only used when using fabric @@ -1390,9 +1388,7 @@ class LightningModule( """ if not _ONNX_AVAILABLE: - raise ModuleNotFoundError( - f"`torch>=2.0` requires `onnx` to be installed to use `{type(self).__name__}.to_onnx()`" - ) + raise ModuleNotFoundError(f"`{type(self).__name__}.to_onnx()` requires `onnx` to be installed.") mode = self.training @@ -1599,24 +1595,6 @@ class LightningModule( state["_trainer"] = None return state - def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None: - """Adds ShardedTensor state dict hooks if ShardedTensors are supported. - - These hooks ensure that ShardedTensors are included when saving, and are loaded the LightningModule correctly. 
- - """ - if _TORCH_GREATER_EQUAL_2_1: - # ShardedTensor is deprecated in favor of DistributedTensor - return - if _IS_WINDOWS or not torch.distributed.is_available(): - rank_zero_debug("Could not register sharded tensor state dict hooks") - return - - from torch.distributed._shard.sharded_tensor import pre_load_state_dict_hook, state_dict_hook - - self._register_state_dict_hook(state_dict_hook) - self._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True) - @contextmanager def _jit_is_scripting() -> Generator: diff --git a/src/lightning/pytorch/demos/transformer.py b/src/lightning/pytorch/demos/transformer.py index 833c15d91c..3f5bcb696a 100644 --- a/src/lightning/pytorch/demos/transformer.py +++ b/src/lightning/pytorch/demos/transformer.py @@ -85,8 +85,7 @@ class PositionalEncoding(nn.Module): if self.pe is None: # 1) can't use buffer, see https://github.com/pytorch/pytorch/issues/68407 # 2) can't use parameter becauses pe gets sliced and DDP requires all params to participate in forward - # 3) can't make it a `requires_grad=False` parameter because FSDP in PyTorch < 2.1 needs all params to - # require grad + # TODO: Could make this a `nn.Parameter` with `requires_grad=False` self.pe = self._init_pos_encoding(device=x.device) x + self.pe[: x.size(0), :] diff --git a/src/lightning/pytorch/loops/utilities.py b/src/lightning/pytorch/loops/utilities.py index 8ca54184b4..99ea5c4254 100644 --- a/src/lightning/pytorch/loops/utilities.py +++ b/src/lightning/pytorch/loops/utilities.py @@ -21,7 +21,6 @@ from torch import Tensor import lightning.pytorch as pl from lightning.fabric.utilities.distributed import _distributed_is_initialized -from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_0 from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch.accelerators.xla import XLAAccelerator from lightning.pytorch.callbacks.timer import Timer @@ -171,9 +170,6 @@ def _no_grad_context(loop_run: Callable) -> Callable: elif isinstance(self.trainer.strategy, FSDPStrategy): # https://github.com/pytorch/pytorch/issues/95957 context_manager = torch.no_grad - elif _TORCH_EQUAL_2_0 and self.trainer.lightning_module._compiler_ctx is not None: - # avoid: `RuntimeError: Inference tensors do not track version counter` fixed in v2.1 - context_manager = torch.no_grad elif self.inference_mode: context_manager = torch.inference_mode else: diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py index 90f6c1febd..ab6e579c30 100644 --- a/src/lightning/pytorch/strategies/fsdp.py +++ b/src/lightning/pytorch/strategies/fsdp.py @@ -67,11 +67,8 @@ from lightning.fabric.utilities.distributed import ( _sync_ddp_if_available, ) from lightning.fabric.utilities.distributed import group as _group -from lightning.fabric.utilities.imports import ( - _TORCH_GREATER_EQUAL_2_1, - _TORCH_GREATER_EQUAL_2_2, -) -from lightning.fabric.utilities.init import _EmptyInit, _has_meta_device_parameters_or_buffers +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 +from lightning.fabric.utilities.init import _has_meta_device_parameters_or_buffers from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors from lightning.fabric.utilities.optimizer import _optimizers_to_device from lightning.fabric.utilities.seed import reset_seed @@ -368,8 +365,8 @@ class FSDPStrategy(ParallelStrategy): invalid_params_error = False try: - # In PyTorch < 2.0, or if `use_orig_params=False` the user needs to do access - # 
`self.trainer.model.parameters()` in configure_optimizers() + # If `use_orig_params=False` the user needs to access `self.trainer.model.parameters()` in + `configure_optimizers()` super().setup_optimizers(trainer) except ValueError as ex: if "optimizer got an empty parameter list" not in str(ex): @@ -377,7 +374,7 @@ invalid_params_error = True if invalid_params_error or any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers): - # We avoid this limitation in PyTorch >= 2.0 by setting `use_orig_params=True` + # We avoid this limitation by setting `use_orig_params=True` raise ValueError( "The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the" " optimizer after setting up the model by referencing `self.trainer.model.parameters()` in the" @@ -393,14 +390,10 @@ @contextmanager @override def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[None, None, None]: - empty_init_context: Union[torch.device, _EmptyInit, nullcontext] - if _TORCH_GREATER_EQUAL_2_1 and empty_init: - # Materialization happens in `setup`. When modules get wrapped by FSDP, the sequence of operations is: - # 1) materialize module 2) call `reset_parameters()` 3) shard the module. - # These operations are applied to each submodule 'bottom up' in the module hierarchy. - empty_init_context = torch.device("meta") - else: - empty_init_context = _EmptyInit(enabled=bool(empty_init)) + # Materialization happens in `setup`. When modules get wrapped by FSDP, the sequence of operations is: + # 1) materialize module 2) call `reset_parameters()` 3) shard the module. + # These operations are applied to each submodule 'bottom up' in the module hierarchy.
+ empty_init_context = torch.device("meta") if empty_init else nullcontext() with empty_init_context, self.precision_plugin.tensor_init_context(): yield diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py index d7320c2c2e..2f4ad406f4 100644 --- a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py +++ b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py @@ -24,7 +24,6 @@ from typing_extensions import TypedDict, override from lightning.fabric.utilities import move_data_to_device from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars from lightning.fabric.utilities.distributed import _distributed_is_initialized -from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_0 from lightning.pytorch.utilities.data import extract_batch_size from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_1_0_0 @@ -112,7 +111,7 @@ class _Metadata: on_step: bool = False on_epoch: bool = True # https://github.com/pytorch/pytorch/issues/96197 - reduce_fx: Callable = "mean" if _TORCH_EQUAL_2_0 else torch.mean # type: ignore[assignment] + reduce_fx: Callable = torch.mean enable_graph: bool = False add_dataloader_idx: bool = True dataloader_idx: Optional[int] = None @@ -362,7 +361,7 @@ class _ResultCollection(dict): on_step: bool = False, on_epoch: bool = True, # https://github.com/pytorch/pytorch/issues/96197 - reduce_fx: Callable = "mean" if _TORCH_EQUAL_2_0 else torch.mean, # type: ignore[assignment] + reduce_fx: Callable = torch.mean, enable_graph: bool = False, sync_dist: bool = False, sync_dist_fn: Callable = _Sync.no_op, diff --git a/src/lightning/pytorch/utilities/compile.py b/src/lightning/pytorch/utilities/compile.py index 7c5a806774..cb2433e04b 100644 --- a/src/lightning/pytorch/utilities/compile.py +++ b/src/lightning/pytorch/utilities/compile.py @@ -17,7 +17,6 @@ import torch from torch._dynamo import OptimizedModule import lightning.pytorch as pl -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from lightning.pytorch.strategies import DDPStrategy, DeepSpeedStrategy, FSDPStrategy, SingleDeviceStrategy, Strategy from lightning.pytorch.utilities.model_helpers import _check_mixed_imports @@ -56,11 +55,7 @@ def from_compiled(model: OptimizedModule) -> "pl.LightningModule": } orig_module.forward = model.dynamo_ctx(orig_module.forward) # type: ignore[method-assign] - if not _TORCH_GREATER_EQUAL_2_1: # https://github.com/pytorch/pytorch/issues/95630 - orig_module.forward._torchdynamo_inline = orig_module.forward orig_module.training_step = model.dynamo_ctx(orig_module.training_step) # type: ignore[method-assign] - if not _TORCH_GREATER_EQUAL_2_1: # https://github.com/pytorch/pytorch/issues/95630 - orig_module.training_step._torchdynamo_inline = orig_module.training_step orig_module.validation_step = model.dynamo_ctx(orig_module.validation_step) # type: ignore[method-assign] orig_module.test_step = model.dynamo_ctx(orig_module.test_step) # type: ignore[method-assign] orig_module.predict_step = model.dynamo_ctx(orig_module.predict_step) # type: ignore[method-assign] diff --git a/tests/tests_fabric/conftest.py b/tests/tests_fabric/conftest.py index c927548338..8b0d83d7f2 100644 --- a/tests/tests_fabric/conftest.py +++ b/tests/tests_fabric/conftest.py @@ -101,7 +101,10 @@ def thread_police_duuu_daaa_duuu_daaa(): assert not 
thread.is_alive() elif isinstance(thread, _ChildProcessObserver): thread.join(timeout=10) - elif thread.name == "QueueFeederThread": # tensorboardX + elif ( + thread.name == "QueueFeederThread" # tensorboardX + or thread.name == "QueueManagerThread" # torch.compile + ): thread.join(timeout=20) elif ( sys.version_info >= (3, 9) diff --git a/tests/tests_fabric/plugins/precision/test_bitsandbytes.py b/tests/tests_fabric/plugins/precision/test_bitsandbytes.py index a88e7c2be7..c45adef192 100644 --- a/tests/tests_fabric/plugins/precision/test_bitsandbytes.py +++ b/tests/tests_fabric/plugins/precision/test_bitsandbytes.py @@ -148,7 +148,7 @@ def test_bitsandbytes_layers(args, expected): assert model.l.weight.dtype == expected -@RunIf(min_cuda_gpus=1, min_torch="2.1") +@RunIf(min_cuda_gpus=1) @pytest.mark.skipif(not _BITSANDBYTES_AVAILABLE, reason="bitsandbytes unavailable") @pytest.mark.parametrize( ("args", "expected"), @@ -232,7 +232,7 @@ def test_bitsandbytes_layers_meta_device(args, expected, tmp_path): assert model.l.weight.dtype == expected -@RunIf(min_cuda_gpus=1, min_torch="2.1") +@RunIf(min_cuda_gpus=1) @pytest.mark.skipif(not _BITSANDBYTES_AVAILABLE, reason="bitsandbytes unavailable") def test_load_quantized_checkpoint(tmp_path): """Test that a checkpoint saved from a quantized model can be loaded back into a quantized model.""" diff --git a/tests/tests_fabric/strategies/test_ddp_integration.py b/tests/tests_fabric/strategies/test_ddp_integration.py index 281f0d47ba..a7ed09b00b 100644 --- a/tests/tests_fabric/strategies/test_ddp_integration.py +++ b/tests/tests_fabric/strategies/test_ddp_integration.py @@ -75,7 +75,7 @@ def _run_ddp_save_load(fabric, tmp_path): assert_params_equal(params_before, wrapped_model.parameters()) -@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.1.0", dynamo=True) +@RunIf(min_cuda_gpus=2, standalone=True, dynamo=True) @mock.patch("lightning.fabric.wrappers.torch.compile", Mock(wraps=torch.compile)) @mock.patch.dict(os.environ, {}) def test_reapply_compile(): diff --git a/tests/tests_fabric/strategies/test_fsdp.py b/tests/tests_fabric/strategies/test_fsdp.py index 1cf2a4d2f1..0c46e7ac17 100644 --- a/tests/tests_fabric/strategies/test_fsdp.py +++ b/tests/tests_fabric/strategies/test_fsdp.py @@ -16,7 +16,6 @@ from re import escape from unittest import mock from unittest.mock import ANY, MagicMock, Mock -import lightning.fabric import pytest import torch import torch.nn as nn @@ -28,8 +27,9 @@ from lightning.fabric.strategies.fsdp import ( _get_full_state_dict_context, _is_sharded_checkpoint, ) -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1, _TORCH_GREATER_EQUAL_2_2 +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision +from torch.distributed.fsdp.wrap import ModuleWrapPolicy from torch.optim import Adam @@ -147,13 +147,6 @@ def test_no_backward_sync(): module.no_sync.assert_called_once() -def test_activation_checkpointing_support(monkeypatch): - """Test that we error out if activation checkpointing requires a newer PyTorch version.""" - monkeypatch.setattr(lightning.fabric.strategies.fsdp, "_TORCH_GREATER_EQUAL_2_1", False) - with pytest.raises(ValueError, match="activation_checkpointing_policy` requires torch >= 2.1.0"): - FSDPStrategy(activation_checkpointing_policy=Mock()) - - def test_activation_checkpointing(): """Test that the FSDP strategy can apply activation checkpointing to the given 
layers.""" @@ -170,28 +163,13 @@ def test_activation_checkpointing(): self.layer1 = Block2(2, 2) self.layer2 = nn.Linear(3, 3) - if _TORCH_GREATER_EQUAL_2_1: - from torch.distributed.fsdp.wrap import ModuleWrapPolicy + strategy = FSDPStrategy(activation_checkpointing_policy={Block1}) + assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"} + assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy) - strategy = FSDPStrategy(activation_checkpointing_policy={Block1}) - assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"} - assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy) - - strategy = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({Block1, Block2})) - assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"} - assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy) - else: - strategy = FSDPStrategy(activation_checkpointing=Block1) - assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"} - - strategy = FSDPStrategy(activation_checkpointing=[Block1, Block2]) - assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"} - - strategy = FSDPStrategy(activation_checkpointing_policy={Block1}) - assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"} - - strategy = FSDPStrategy(activation_checkpointing_policy={Block1, Block2}) - assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"} + strategy = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({Block1, Block2})) + assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"} + assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy) strategy._parallel_devices = [torch.device("cuda", 0)] with mock.patch("torch.distributed.fsdp.FullyShardedDataParallel", new=MagicMock), mock.patch( @@ -401,15 +379,13 @@ def test_set_timeout(init_process_group_mock): ) -@pytest.mark.parametrize("torch_ge_2_1", [True, False]) @mock.patch("torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel.set_state_dict_type") -def test_get_full_state_dict_context_offload(set_type_mock, monkeypatch, torch_ge_2_1): - """Test that the state dict context manager handles CPU offloading depending on the PyTorch version.""" - monkeypatch.setattr("lightning.fabric.strategies.fsdp._TORCH_GREATER_EQUAL_2_1", torch_ge_2_1) +def test_get_full_state_dict_context_offload(set_type_mock, monkeypatch): + """Test that the state dict context manager handles CPU offloading.""" with _get_full_state_dict_context(module=Mock(spec=FullyShardedDataParallel), world_size=1): - assert set_type_mock.call_args_list[0][0][2].offload_to_cpu is torch_ge_2_1 # model config - assert set_type_mock.call_args_list[0][0][3].offload_to_cpu is torch_ge_2_1 # optim config + assert set_type_mock.call_args_list[0][0][2].offload_to_cpu # model config + assert set_type_mock.call_args_list[0][0][3].offload_to_cpu # optim config set_type_mock.reset_mock() diff --git a/tests/tests_fabric/strategies/test_fsdp_integration.py b/tests/tests_fabric/strategies/test_fsdp_integration.py index 16e4910c7e..e324ab5056 100644 --- a/tests/tests_fabric/strategies/test_fsdp_integration.py +++ b/tests/tests_fabric/strategies/test_fsdp_integration.py @@ -23,7 +23,6 @@ import torch.nn as nn from lightning.fabric import Fabric from lightning.fabric.plugins import FSDPPrecision from 
lightning.fabric.strategies import FSDPStrategy -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from lightning.fabric.utilities.load import _load_distributed_checkpoint from lightning.fabric.wrappers import _FabricOptimizer from torch._dynamo import OptimizedModule @@ -400,7 +399,7 @@ def test_setup_with_orig_params_and_multiple_param_groups(): assert not isinstance(layer.weight, FlatParameter) -@RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.1.0", dynamo=True, skip_windows=True) +@RunIf(min_cuda_gpus=2, standalone=True, dynamo=True, skip_windows=True) @mock.patch("lightning.fabric.wrappers.torch.compile", Mock(wraps=torch.compile)) @mock.patch.dict(os.environ, {}) def test_reapply_compile(): @@ -466,12 +465,8 @@ def test_module_init_context(precision, expected_dtype): # Case 1: No empty init _run_setup_assertions(empty_init=False, expected_device=torch.device("cpu")) - if _TORCH_GREATER_EQUAL_2_1: - # Case 2: Empty-init with PyTorch >= 2.1 supports meta device - _run_setup_assertions(empty_init=True, expected_device=torch.device("meta")) - else: - # Case 2: Empty-init with PyTorch < 2.1 only supports `torch.empty()`-init - _run_setup_assertions(empty_init=True, expected_device=torch.device("cpu")) + # Case 2: Empty-init with meta device + _run_setup_assertions(empty_init=True, expected_device=torch.device("meta")) @RunIf(min_cuda_gpus=2, standalone=True) @@ -538,9 +533,6 @@ def test_rewrap_warnings(): assert not isinstance(model._forward_module, FullyShardedDataParallel) assert isinstance(model._forward_module[2], FullyShardedDataParallel) - if not _TORCH_GREATER_EQUAL_2_1: - return - with fabric.init_module(empty_init=True): model = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.ReLU(), wrap(torch.nn.Linear(1, 1))) assert model[0].weight.is_meta diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py index f76a846e80..70d04d5431 100644 --- a/tests/tests_fabric/test_fabric.py +++ b/tests/tests_fabric/test_fabric.py @@ -289,7 +289,7 @@ def test_setup_optimizers_not_supported(strategy_cls): fabric.setup_optimizers(optimizer) -@RunIf(min_cuda_gpus=1, min_torch="2.1") +@RunIf(min_cuda_gpus=1) def test_setup_optimizer_on_meta_device(): """Test that the setup-methods validate that the optimizer doesn't have references to meta-device parameters.""" fabric = Fabric(strategy="fsdp", devices=1) @@ -867,8 +867,6 @@ def test_init_module_context(monkeypatch): def test_init_tensor_context(monkeypatch): - """Test that `.init_tensor()` warns if using PyTorch < 2.0.""" - fabric = Fabric(accelerator="cpu") strategy = SingleDeviceStrategy(device=torch.device("cuda")) strategy.tensor_init_context = Mock(wraps=strategy.tensor_init_context) diff --git a/tests/tests_fabric/test_wrappers.py b/tests/tests_fabric/test_wrappers.py index 91f516d03a..b89a536aff 100644 --- a/tests/tests_fabric/test_wrappers.py +++ b/tests/tests_fabric/test_wrappers.py @@ -19,7 +19,6 @@ import torch from lightning.fabric.fabric import Fabric from lightning.fabric.plugins import Precision from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from lightning.fabric.wrappers import ( _FabricDataLoader, _FabricModule, @@ -268,14 +267,13 @@ def test_fabric_module_state_dict_access(): assert torch.equal(fabric_module.layer.weight, weight) assert torch.equal(fabric_module.layer.bias, bias) - if _TORCH_GREATER_EQUAL_2_1: - # Can use additional `assign` argument in PyTorch >= 2.1 - 
with torch.device("meta"): - original_module = OriginalModule() - fabric_module = _FabricModule(wrapped_module, Mock(), original_module=original_module) - assert fabric_module.layer.weight.is_meta - fabric_module.load_state_dict({"layer.weight": weight, "layer.bias": bias}, assign=True) - assert not fabric_module.layer.weight.is_meta + # Can use additional `assign` argument + with torch.device("meta"): + original_module = OriginalModule() + fabric_module = _FabricModule(wrapped_module, Mock(), original_module=original_module) + assert fabric_module.layer.weight.is_meta + fabric_module.load_state_dict({"layer.weight": weight, "layer.bias": bias}, assign=True) + assert not fabric_module.layer.weight.is_meta @pytest.mark.parametrize( diff --git a/tests/tests_fabric/utilities/test_init.py b/tests/tests_fabric/utilities/test_init.py index bdbca90495..dd08dec020 100644 --- a/tests/tests_fabric/utilities/test_init.py +++ b/tests/tests_fabric/utilities/test_init.py @@ -58,7 +58,6 @@ def test_empty_init_speed(): assert normal_init_time > 2 * empty_init_time -@RunIf(min_torch="2.1") def test_materialize_meta_tensors(): class Submodule(torch.nn.Module): def __init__(self): diff --git a/tests/tests_fabric/utilities/test_throughput.py b/tests/tests_fabric/utilities/test_throughput.py index f2c3de30a3..eefadb285a 100644 --- a/tests/tests_fabric/utilities/test_throughput.py +++ b/tests/tests_fabric/utilities/test_throughput.py @@ -13,11 +13,9 @@ from lightning.fabric.utilities.throughput import ( measure_flops, ) -from tests_fabric.helpers.runif import RunIf from tests_fabric.test_fabric import BoringModel -@RunIf(min_torch="2.1") def test_measure_flops(): with torch.device("meta"): model = BoringModel() diff --git a/tests/tests_pytorch/callbacks/test_throughput_monitor.py b/tests/tests_pytorch/callbacks/test_throughput_monitor.py index 9467e45e2f..a74efba758 100644 --- a/tests/tests_pytorch/callbacks/test_throughput_monitor.py +++ b/tests/tests_pytorch/callbacks/test_throughput_monitor.py @@ -8,10 +8,7 @@ from lightning.pytorch import Trainer from lightning.pytorch.callbacks.throughput_monitor import ThroughputMonitor from lightning.pytorch.demos.boring_classes import BoringModel -from tests_pytorch.helpers.runif import RunIf - -@RunIf(min_torch="2.1") def test_measure_flops(): with torch.device("meta"): model = BoringModel() diff --git a/tests/tests_pytorch/conftest.py b/tests/tests_pytorch/conftest.py index c0319e873b..97c17c4e46 100644 --- a/tests/tests_pytorch/conftest.py +++ b/tests/tests_pytorch/conftest.py @@ -153,7 +153,10 @@ def thread_police_duuu_daaa_duuu_daaa(): assert not thread.is_alive() elif isinstance(thread, _ChildProcessObserver): thread.join(timeout=10) - elif thread.name == "QueueFeederThread": # tensorboardX + elif ( + thread.name == "QueueFeederThread" # tensorboardX + or thread.name == "QueueManagerThread" # torch.compile + ): thread.join(timeout=20) elif isinstance(thread, TMonitor): thread.exit() diff --git a/tests/tests_pytorch/deprecated_api/test_no_removal_version.py b/tests/tests_pytorch/deprecated_api/test_no_removal_version.py index d12254c679..e6da72c777 100644 --- a/tests/tests_pytorch/deprecated_api/test_no_removal_version.py +++ b/tests/tests_pytorch/deprecated_api/test_no_removal_version.py @@ -35,14 +35,10 @@ def test_ddp_is_distributed(): _ = strategy.is_distributed -def test_fsdp_activation_checkpointing(monkeypatch): +def test_fsdp_activation_checkpointing(): with pytest.raises(ValueError, match="cannot set both `activation_checkpointing"): 
FSDPStrategy(activation_checkpointing=torch.nn.Linear, activation_checkpointing_policy=lambda *_: True) - monkeypatch.setattr(lightning.fabric.strategies.fsdp, "_TORCH_GREATER_EQUAL_2_1", True) - with pytest.deprecated_call(match=r"use `FSDPStrategy\(activation_checkpointing_policy"): - FSDPStrategy(activation_checkpointing=torch.nn.Linear) - def test_double_precision_wrapper(): with pytest.deprecated_call(match=r"The `LightningDoublePrecisionModule` is deprecated and no longer needed"): diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py index 04eeabbbd7..fe36eb0a03 100644 --- a/tests/tests_pytorch/strategies/test_fsdp.py +++ b/tests/tests_pytorch/strategies/test_fsdp.py @@ -14,7 +14,7 @@ import torch import torch.nn as nn from lightning.fabric.plugins.environments import LightningEnvironment from lightning.fabric.strategies.fsdp import _is_sharded_checkpoint -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1, _TORCH_GREATER_EQUAL_2_2 +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 from lightning.fabric.utilities.load import _load_distributed_checkpoint from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelCheckpoint @@ -334,10 +334,9 @@ def test_strategy_full_state_dict(tmp_path, wrap_min_params): TestFSDPModelAutoWrapped(), FSDPStrategy, { - "auto_wrap_policy": ModuleWrapPolicy({nn.Linear}) if _TORCH_GREATER_EQUAL_2_1 else None, + "auto_wrap_policy": ModuleWrapPolicy({nn.Linear}), "use_orig_params": True, }, - marks=RunIf(min_torch="2.1.0"), id="autowrap_use_orig_params", ), ], @@ -380,19 +379,12 @@ def test_invalid_parameters_in_optimizer(use_orig_params): fast_dev_run=1, ) - error_context = ( - nullcontext() - if _TORCH_GREATER_EQUAL_2_1 or use_orig_params is not False - else pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters") - ) - class EmptyParametersModel(BoringModel): def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=1e-2) model = EmptyParametersModel() - with error_context: - trainer.fit(model) + trainer.fit(model) class NoFlatParametersModel(BoringModel): def configure_optimizers(self): @@ -435,28 +427,13 @@ def test_activation_checkpointing(): self.layer1 = Block2(2, 2) self.layer2 = nn.Linear(3, 3) - if _TORCH_GREATER_EQUAL_2_1: - from torch.distributed.fsdp.wrap import ModuleWrapPolicy + strategy = FSDPStrategy(activation_checkpointing_policy={Block1}) + assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"} + assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy) - strategy = FSDPStrategy(activation_checkpointing_policy={Block1}) - assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"} - assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy) - - strategy = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({Block1, Block2})) - assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"} - assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy) - else: - strategy = FSDPStrategy(activation_checkpointing=Block1) - assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"} - - strategy = FSDPStrategy(activation_checkpointing=[Block1, Block2]) - assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"} - - strategy = 
FSDPStrategy(activation_checkpointing_policy={Block1}) - assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"} - - strategy = FSDPStrategy(activation_checkpointing_policy={Block1, Block2}) - assert set(strategy._activation_checkpointing_kwargs) == {"check_fn"} + strategy = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({Block1, Block2})) + assert set(strategy._activation_checkpointing_kwargs) == {"auto_wrap_policy"} + assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy) model = Model() strategy._parallel_devices = [torch.device("cuda", 0)] @@ -608,7 +585,7 @@ def test_strategy_save_optimizer_states(tmp_path, wrap_min_params): if trainer.global_rank != 0: assert len(model_state_dict) == 0 - if trainer.global_rank != 0 and _TORCH_GREATER_EQUAL_2_1: + if trainer.global_rank != 0: assert len(optimizer_state_dict) == 0 # restore model to ddp @@ -679,7 +656,7 @@ def test_strategy_load_optimizer_states(wrap_min_params, tmp_path): if trainer.global_rank != 0: assert len(restored_model_state_dict) == 0 - if trainer.global_rank != 0 and _TORCH_GREATER_EQUAL_2_1: + if trainer.global_rank != 0: assert len(restored_optimizer_state_dict) == 0 if trainer.global_rank == 0: @@ -936,12 +913,8 @@ def test_module_init_context(precision, expected_dtype, tmp_path): # Case 1: No empty init _run_setup_assertions(empty_init=False, expected_device=torch.device("cpu")) - if _TORCH_GREATER_EQUAL_2_1: - # Case 2: Empty-init with PyTorch >= 2.1 supports meta device - _run_setup_assertions(empty_init=True, expected_device=torch.device("meta")) - else: - # Case 2: Empty-init with PyTorch < 2.1 only supports `torch.empty()`-init - _run_setup_assertions(empty_init=True, expected_device=torch.device("cpu")) + # Case 2: Empty-init with meta device + _run_setup_assertions(empty_init=True, expected_device=torch.device("meta")) @RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.3.0") diff --git a/tests/tests_pytorch/strategies/test_model_parallel_integration.py b/tests/tests_pytorch/strategies/test_model_parallel_integration.py index bb8d7c719f..3f09db0568 100644 --- a/tests/tests_pytorch/strategies/test_model_parallel_integration.py +++ b/tests/tests_pytorch/strategies/test_model_parallel_integration.py @@ -339,7 +339,7 @@ def test_module_init_context(precision, expected_dtype, tmp_path): # Case 1: No empty init _run_setup_assertions(empty_init=False, expected_device=torch.device("cpu")) - # Case 2: Empty-init with PyTorch >= 2.1 supports meta device + # Case 2: Empty-init with meta device _run_setup_assertions(empty_init=True, expected_device=torch.device("meta")) diff --git a/tests/tests_pytorch/trainer/flags/test_inference_mode.py b/tests/tests_pytorch/trainer/flags/test_inference_mode.py index c262f0ca33..bae7b66dbb 100644 --- a/tests/tests_pytorch/trainer/flags/test_inference_mode.py +++ b/tests/tests_pytorch/trainer/flags/test_inference_mode.py @@ -16,7 +16,6 @@ from unittest.mock import Mock import pytest import torch -from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_0 from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loops import _Loop @@ -81,7 +80,5 @@ def test_no_grad_context(): f.run() no_grad_mock.assert_called_once_with() f.inference_mode = True - with mock.patch("torch.inference_mode") as inference_mode_mock: + with mock.patch("torch.inference_mode"): f.run() - if not _TORCH_EQUAL_2_0: - inference_mode_mock.assert_called_once_with()
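The updated tests above all pin down the same, now-unconditional behavior: a set of module classes passed to `activation_checkpointing_policy` is normalized into a `ModuleWrapPolicy`, and a ready-made policy object is forwarded unchanged. A minimal sketch of that behavior; the `Block` class is a hypothetical placeholder for any user-defined submodule:

```python
import torch.nn as nn
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

from lightning.fabric.strategies import FSDPStrategy


class Block(nn.Module):  # hypothetical submodule to checkpoint around
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)


# A set of module classes is translated into a `ModuleWrapPolicy` ...
strategy = FSDPStrategy(activation_checkpointing_policy={Block})
assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)

# ... and an explicit policy object is passed through as-is.
strategy = FSDPStrategy(activation_checkpointing_policy=ModuleWrapPolicy({Block}))
assert isinstance(strategy._activation_checkpointing_kwargs["auto_wrap_policy"], ModuleWrapPolicy)
```

The legacy `activation_checkpointing=...` argument still works, but it now always emits the deprecation message in `_activation_checkpointing_kwargs`, since the PyTorch 2.1 availability check around it is gone.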