Add testing for PyTorch 2.4 (Fabric) (#20028)
parent 37e04d075a
commit 693c21ac1b
@@ -62,6 +62,9 @@ jobs:
 "Lightning | latest":
 image: "pytorchlightning/pytorch_lightning:base-cuda-py3.11-torch2.3-cuda12.1.0"
 PACKAGE_NAME: "lightning"
+"Lightning | future":
+image: "pytorchlightning/pytorch_lightning:base-cuda-py3.11-torch2.4-cuda12.1.0"
+PACKAGE_NAME: "lightning"
 workspace:
 clean: all
 steps:
@@ -72,9 +75,12 @@ jobs:
 echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/cu${cuda_ver}/torch_stable.html"
 scope=$(python -c 'n = "$(PACKAGE_NAME)" ; print(dict(fabric="lightning_fabric").get(n, n))')
 echo "##vso[task.setvariable variable=COVERAGE_SOURCE]$scope"
+python_ver=$(python -c "import sys; print(f'{sys.version_info.major}{sys.version_info.minor}')")
+echo "##vso[task.setvariable variable=PYTHON_VERSION_MM]$python_ver"
 displayName: "set env. vars"
 - bash: |
-echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM}/torch_test.html"
+echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM}"
+echo "##vso[task.setvariable variable=TORCHVISION_URL]https://download.pytorch.org/whl/test/cu124/torchvision-0.19.0%2Bcu124-cp${PYTHON_VERSION_MM}-cp${PYTHON_VERSION_MM}-linux_x86_64.whl"
 condition: endsWith(variables['Agent.JobName'], 'future')
 displayName: "set env. vars 4 future"
@@ -103,7 +109,7 @@ jobs:

 - bash: |
 extra=$(python -c "print({'lightning': 'fabric-'}.get('$(PACKAGE_NAME)', ''))")
-pip install -e ".[${extra}dev]" pytest-timeout -U --find-links="${TORCH_URL}"
+pip install -e ".[${extra}dev]" pytest-timeout -U --find-links="${TORCH_URL}" --find-links="${TORCHVISION_URL}"
 displayName: "Install package & dependencies"

 - bash: |
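The PYTHON_VERSION_MM variable added above exists only so the "future" job can pick the torchvision test wheel that matches the interpreter. A minimal sketch of the same construction in plain Python, reusing the version and CUDA tags hard-coded in the pipeline:

import sys

python_mm = f"{sys.version_info.major}{sys.version_info.minor}"  # e.g. "311" on Python 3.11
torchvision_url = (
    "https://download.pytorch.org/whl/test/cu124/"
    f"torchvision-0.19.0%2Bcu124-cp{python_mm}-cp{python_mm}-linux_x86_64.whl"
)
print(torchvision_url)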
@@ -49,6 +49,10 @@ jobs:
 - { os: "macOS-14", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.3" }
 - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
 - { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
+- { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.4" }
+- { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.4" }
+# TODO: PyTorch 2.4 on Windows not yet working with `torch.distributed` (not compiled with libuv support)
+# - { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.4" }
 # only run PyTorch latest with Python latest, use Fabric scope to limit dependency issues
 - { os: "macOS-12", pkg-name: "fabric", python-version: "3.11", pytorch-version: "2.1" }
 - { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.11", pytorch-version: "2.1" }
@@ -79,7 +83,7 @@ jobs:
 FREEZE_REQUIREMENTS: ${{ ! (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }}
 PYPI_CACHE_DIR: "_pip-wheels"
 TORCH_URL_STABLE: "https://download.pytorch.org/whl/cpu/torch_stable.html"
-TORCH_URL_TEST: "https://download.pytorch.org/whl/test/cpu/torch_test.html"
+TORCH_URL_TEST: "https://download.pytorch.org/whl/test/cpu/torch"
 # TODO: Remove this - Enable running MPS tests on this platform
 DISABLE_MPS: ${{ matrix.os == 'macOS-14' && '1' || '0' }}
 steps:
@@ -118,7 +122,7 @@ jobs:
 - name: Env. variables
 run: |
 # Switch PyTorch URL
-python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.pytorch-version }}' == '2.3' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV
+python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.pytorch-version }}' == '2.4' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV
 # Switch coverage scope
 python -c "print('COVERAGE_SCOPE=' + str('lightning' if '${{matrix.pkg-name}}' == 'lightning' else 'lightning_fabric'))" >> $GITHUB_ENV
 # if you install mono-package set dependency only for this subpackage
@@ -2,7 +2,7 @@
 # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

 numpy >=1.21.0, <1.27.0
-torch >=2.1.0, <2.4.0
+torch >=2.1.0, <2.5.0
 fsspec[http] >=2022.5.0, <2024.4.0
 packaging >=20.0, <=23.1
 typing-extensions >=4.4.0, <4.10.0

@@ -1,6 +1,6 @@
 # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
 # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

-torchvision >=0.16.0, <0.19.0
+torchvision >=0.16.0, <0.20.0
 torchmetrics >=0.10.0, <1.3.0
 lightning-utilities >=0.8.0, <0.12.0
@@ -22,6 +22,7 @@ from typing_extensions import override

 from lightning.fabric.plugins.precision.precision import Precision
 from lightning.fabric.plugins.precision.utils import _convert_fp_tensor
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
 from lightning.fabric.utilities.types import Optimizable

@@ -39,7 +40,7 @@ class MixedPrecision(Precision):
 self,
 precision: Literal["16-mixed", "bf16-mixed"],
 device: str,
-scaler: Optional[torch.cuda.amp.GradScaler] = None,
+scaler: Optional["torch.cuda.amp.GradScaler"] = None,
 ) -> None:
 if precision not in ("16-mixed", "bf16-mixed"):
 raise ValueError(
@@ -49,7 +50,7 @@ class MixedPrecision(Precision):

 self.precision = precision
 if scaler is None and self.precision == "16-mixed":
-scaler = torch.cuda.amp.GradScaler()
+scaler = torch.amp.GradScaler(device=device) if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler()
 if scaler is not None and self.precision == "bf16-mixed":
 raise ValueError(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
 self.device = device
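PyTorch 2.4 deprecates torch.cuda.amp.GradScaler in favor of the device-agnostic torch.amp.GradScaler, which is why the type hint becomes a string and the default scaler is now version-gated. A minimal sketch of the gated selection, reusing the _TORCH_GREATER_EQUAL_2_4 flag from lightning.fabric.utilities.imports:

import torch
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4

def make_grad_scaler(device: str = "cuda"):
    # torch.amp.GradScaler takes the device as an argument on torch >= 2.4;
    # older releases only ship the CUDA-specific class.
    if _TORCH_GREATER_EQUAL_2_4:
        return torch.amp.GradScaler(device=device)
    return torch.cuda.amp.GradScaler()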
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import shutil
+import warnings
 from contextlib import ExitStack, nullcontext
 from datetime import timedelta
 from functools import partial
@@ -83,6 +84,9 @@ if TYPE_CHECKING:

 _FSDP_ALIASES = ("fsdp", "fsdp_cpu_offload")

+# TODO: Switch to new state-dict APIs
+warnings.filterwarnings("ignore", category=FutureWarning, message=".*FSDP.state_dict_type.*") # from torch >= 2.4
+

 class FSDPStrategy(ParallelStrategy, _Sharded):
 r"""Strategy for Fully Sharded Data Parallel provided by torch.distributed.
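The module-level filter above relies on the standard warnings machinery: the message argument is a regex matched against the warning text, so any FutureWarning mentioning FSDP.state_dict_type that torch >= 2.4 emits is silenced once this module is imported. A short sketch of the mechanism in isolation, with made-up warning texts:

import warnings

warnings.filterwarnings("ignore", category=FutureWarning, message=".*legacy API.*")
warnings.warn("this legacy API will go away", FutureWarning)  # silenced by the filter
warnings.warn("some unrelated change", FutureWarning)         # still reported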
@@ -70,7 +70,7 @@ class ModelParallelStrategy(ParallelStrategy):

 Currently supports up to 2D parallelism. Specifically, it supports the combination of
 Fully Sharded Data-Parallel 2 (FSDP2) with Tensor Parallelism (DTensor). These PyTorch APIs are currently still
-experimental in PyTorch. Requires PyTorch 2.3 or newer.
+experimental in PyTorch. Requires PyTorch 2.4 or newer.

 Arguments:
 parallelize_fn: A function that applies parallelisms to a module. The strategy will provide the
@@ -95,8 +95,8 @@ class ModelParallelStrategy(ParallelStrategy):
 timeout: Optional[timedelta] = default_pg_timeout,
 ) -> None:
 super().__init__()
-if not _TORCH_GREATER_EQUAL_2_3:
-raise ImportError(f"{type(self).__name__} requires PyTorch 2.3 or higher.")
+if not _TORCH_GREATER_EQUAL_2_4:
+raise ImportError(f"{type(self).__name__} requires PyTorch 2.4 or higher.")
 self._parallelize_fn = parallelize_fn
 self._data_parallel_size = data_parallel_size
 self._tensor_parallel_size = tensor_parallel_size
@@ -178,7 +178,7 @@ class ModelParallelStrategy(ParallelStrategy):
 if any(isinstance(mod, FullyShardedDataParallel) for mod in module.modules()):
 raise TypeError(
 "Found modules that are wrapped with `torch.distributed.fsdp.FullyShardedDataParallel`."
-f" The `{self.__class__.__name__}` only supports the new FSDP2 APIs in PyTorch >= 2.3."
+f" The `{self.__class__.__name__}` only supports the new FSDP2 APIs in PyTorch >= 2.4."
 )

 module = self._parallelize_fn(module, self.device_mesh)
@@ -329,10 +329,10 @@ class _FSDPNoSync(ContextManager):
 self._enabled = enabled

 def _set_requires_grad_sync(self, requires_grad_sync: bool) -> None:
-from torch.distributed._composable.fsdp import FSDP
+from torch.distributed._composable.fsdp import FSDPModule

 for mod in self._module.modules():
-if isinstance(mod, FSDP):
+if isinstance(mod, FSDPModule):
 mod.set_requires_gradient_sync(requires_grad_sync, recurse=False)

 def __enter__(self) -> None:
@@ -458,9 +458,6 @@ def _load_checkpoint(
 return metadata

 if _is_full_checkpoint(path):
-if not _TORCH_GREATER_EQUAL_2_4:
-raise ImportError("Loading a non-distributed checkpoint into a distributed model requires PyTorch >= 2.4.")
-
 checkpoint = torch.load(path, mmap=True, map_location="cpu")
 _load_raw_module_state(checkpoint.pop(module_key), module, strict=strict)

@@ -546,9 +543,6 @@ def _load_raw_module_state(
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

 if _has_dtensor_modules(module):
-if not _TORCH_GREATER_EQUAL_2_4:
-raise ImportError("Loading a non-distributed checkpoint into a distributed model requires PyTorch >= 2.4.")
-
 from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict

 state_dict_options = StateDictOptions(
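For context, the docstring above describes the FSDP2 plus tensor-parallel strategy that these checks now pin to PyTorch 2.4. A hedged usage sketch follows; the model, its layers attribute, and the mesh dimension name are illustrative, only ModelParallelStrategy(parallelize_fn=...) itself appears in this diff:

import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from lightning.fabric import Fabric
from lightning.fabric.strategies import ModelParallelStrategy

def parallelize(model: nn.Module, device_mesh) -> nn.Module:
    # shard each block, then the root module, over the data-parallel mesh dimension
    for block in model.layers:
        fully_shard(block, mesh=device_mesh["data_parallel"])
    fully_shard(model, mesh=device_mesh["data_parallel"])
    return model

fabric = Fabric(devices=2, strategy=ModelParallelStrategy(parallelize_fn=parallelize))
fabric.launch()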
@@ -28,7 +28,7 @@ _IS_INTERACTIVE = hasattr(sys, "ps1") or bool(sys.flags.interactive)

 _TORCH_GREATER_EQUAL_2_2 = compare_version("torch", operator.ge, "2.2.0")
 _TORCH_GREATER_EQUAL_2_3 = compare_version("torch", operator.ge, "2.3.0")
-_TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0", use_base_version=True)
+_TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0")

 _PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
 _PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
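Dropping use_base_version=True makes the flag stricter: nightly or pre-release builds of 2.4 (whose base version is 2.4.0) no longer satisfy it, only versions that compare at or above the final 2.4.0 do. A sketch of the underlying comparison with the packaging library, on which compare_version from lightning-utilities is, to my understanding, built:

from packaging.version import Version

installed = Version("2.4.0.dev20240601")                     # e.g. a torch nightly
print(installed >= Version("2.4.0"))                         # False: dev releases sort below the final release
print(Version(installed.base_version) >= Version("2.4.0"))   # True: the base version drops the dev suffix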
@@ -24,6 +24,7 @@ from lightning.fabric.accelerators import XLAAccelerator
 from lightning.fabric.accelerators.cuda import num_cuda_devices
 from lightning.fabric.accelerators.mps import MPSAccelerator
 from lightning.fabric.strategies.deepspeed import _DEEPSPEED_AVAILABLE
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4


 def _runif_reasons(
@@ -111,7 +112,9 @@ def _runif_reasons(
 reasons.append("Standalone execution")
 kwargs["standalone"] = True

-if deepspeed and not (_DEEPSPEED_AVAILABLE and RequirementCache(module="deepspeed.utils")):
+if deepspeed and not (
+_DEEPSPEED_AVAILABLE and not _TORCH_GREATER_EQUAL_2_4 and RequirementCache(module="deepspeed.utils")
+):
 reasons.append("Deepspeed")

 if dynamo:
@@ -104,6 +104,7 @@ def thread_police_duuu_daaa_duuu_daaa():
 elif (
 thread.name == "QueueFeederThread"  # tensorboardX
 or thread.name == "QueueManagerThread"  # torch.compile
+or "(_read_thread)" in thread.name  # torch.compile
 ):
 thread.join(timeout=20)
 elif (
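The "(_read_thread)" check relies on default thread naming from Python 3.10 onward, where a thread created without an explicit name is called "Thread-N (<target name>)"; the torch.compile worker pool spawns such a reader thread, per the comment in the diff. A small sketch of the naming behaviour, assuming Python >= 3.10:

import threading

def _read_thread():
    pass

t = threading.Thread(target=_read_thread)
print(t.name)                        # e.g. "Thread-1 (_read_thread)"
print("(_read_thread)" in t.name)    # True, which is what the fixture matches on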
@@ -17,11 +17,13 @@ from unittest.mock import Mock
 import pytest
 import torch
 from lightning.fabric.plugins.precision.amp import MixedPrecision
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4


 def test_amp_precision_default_scaler():
 precision = MixedPrecision(precision="16-mixed", device=Mock())
-assert isinstance(precision.scaler, torch.cuda.amp.GradScaler)
+scaler_cls = torch.amp.GradScaler if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler
+assert isinstance(precision.scaler, scaler_cls)


 def test_amp_precision_scaler_with_bf16():
@@ -36,7 +38,8 @@ def test_amp_precision_forward_context():
 """Test to ensure that the context manager correctly is set to bfloat16 on CPU and CUDA."""
 precision = MixedPrecision(precision="16-mixed", device="cuda")
 assert precision.device == "cuda"
-assert isinstance(precision.scaler, torch.cuda.amp.GradScaler)
+scaler_cls = torch.amp.GradScaler if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler
+assert isinstance(precision.scaler, scaler_cls)
 assert torch.get_default_dtype() == torch.float32
 with precision.forward_context():
 assert torch.get_autocast_gpu_dtype() == torch.float16
@@ -17,6 +17,7 @@ import pytest
 import torch
 import torch.nn as nn
 from lightning.fabric import Fabric, seed_everything
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4

 from tests_fabric.helpers.runif import RunIf

@@ -82,7 +83,8 @@ def test_amp_fused_optimizer_parity():
 optimizer = torch.optim.Adam(model.parameters(), lr=1.0, fused=fused)

 model, optimizer = fabric.setup(model, optimizer)
-assert isinstance(fabric._precision.scaler, torch.cuda.amp.GradScaler)
+scaler_cls = torch.amp.GradScaler if _TORCH_GREATER_EQUAL_2_4 else torch.cuda.amp.GradScaler
+assert isinstance(fabric._precision.scaler, scaler_cls)

 data = torch.randn(10, 10, device="cuda")
 target = torch.randn(10, 10, device="cuda")
@@ -93,7 +93,7 @@ def test_bitsandbytes_plugin(monkeypatch):
 precision.convert_module(model)


-@RunIf(min_cuda_gpus=1)
+@RunIf(min_cuda_gpus=1, max_torch="2.4")
 @pytest.mark.skipif(not _BITSANDBYTES_AVAILABLE, reason="bitsandbytes unavailable")
 @pytest.mark.parametrize(
 ("args", "expected"),
@@ -232,7 +232,7 @@ def test_bitsandbytes_layers_meta_device(args, expected, tmp_path):
 assert model.l.weight.dtype == expected


-@RunIf(min_cuda_gpus=1)
+@RunIf(min_cuda_gpus=1, max_torch="2.4")
 @pytest.mark.skipif(not _BITSANDBYTES_AVAILABLE, reason="bitsandbytes unavailable")
 def test_load_quantized_checkpoint(tmp_path):
 """Test that a checkpoint saved from a quantized model can be loaded back into a quantized model."""
@@ -74,6 +74,7 @@ def test_dp_module_state_dict():
 assert strategy.get_module_state_dict(wrapped_module).keys() == original_module.state_dict().keys()


+@pytest.mark.filterwarnings("ignore::FutureWarning")
 @pytest.mark.parametrize(
 "precision",
 [
@@ -118,7 +118,7 @@ class _TrainerManualWrapping(_Trainer):
 return model


-@RunIf(min_cuda_gpus=2, standalone=True)
+@RunIf(min_cuda_gpus=2, standalone=True, max_torch="2.4")
 @pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
 @pytest.mark.parametrize("manual_wrapping", [True, False])
 def test_train_save_load(tmp_path, manual_wrapping, precision):
@@ -173,6 +173,7 @@ def test_train_save_load(tmp_path, manual_wrapping, precision):
 assert state["coconut"] == 11


+@pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_cuda_gpus=2, standalone=True)
 def test_save_full_state_dict(tmp_path):
 """Test that FSDP saves the full state into a single file with `state_dict_type="full"`."""
@@ -287,6 +288,7 @@ def test_save_full_state_dict(tmp_path):
 trainer.run()


+@pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_cuda_gpus=2, standalone=True)
 def test_load_full_state_dict_into_sharded_model(tmp_path):
 """Test that the strategy can load a full-state checkpoint into a FSDP sharded model."""
@@ -469,6 +471,7 @@ def test_module_init_context(precision, expected_dtype):
 _run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))


+@pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_cuda_gpus=2, standalone=True)
 def test_save_filter(tmp_path):
 fabric = Fabric(accelerator="cuda", strategy=FSDPStrategy(state_dict_type="full"), devices=2)
@@ -602,6 +605,7 @@ def test_clip_gradients(clip_type, precision):
 optimizer.zero_grad()


+@pytest.mark.filterwarnings("ignore::FutureWarning")
 @RunIf(min_cuda_gpus=2, standalone=True, min_torch="2.3.0")
 def test_save_sharded_and_consolidate_and_load(tmp_path):
 """Test the consolidation of a FSDP-sharded checkpoint into a single file."""
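The filterwarnings marks added in this file use pytest's warning-filter syntax: "ignore::FutureWarning" ignores every warning of that category for the decorated test only, which keeps the FSDP state-dict deprecation noise from torch >= 2.4 out of these tests without touching global filters. A minimal sketch of the mark:

import warnings

import pytest

@pytest.mark.filterwarnings("ignore::FutureWarning")
def test_legacy_call_is_quiet():
    # the FutureWarning raised here is ignored for this test only
    warnings.warn("FSDP.state_dict_type is deprecated", FutureWarning)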
@@ -28,20 +28,20 @@ from torch.optim import Adam
 from tests_fabric.helpers.runif import RunIf


-@mock.patch("lightning.fabric.strategies.model_parallel._TORCH_GREATER_EQUAL_2_3", False)
-def test_torch_greater_equal_2_3():
-with pytest.raises(ImportError, match="ModelParallelStrategy requires PyTorch 2.3 or higher"):
+@mock.patch("lightning.fabric.strategies.model_parallel._TORCH_GREATER_EQUAL_2_4", False)
+def test_torch_greater_equal_2_4():
+with pytest.raises(ImportError, match="ModelParallelStrategy requires PyTorch 2.4 or higher"):
 ModelParallelStrategy(parallelize_fn=(lambda m, _: m))


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 def test_device_mesh_access():
 strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m))
 with pytest.raises(RuntimeError, match="Accessing the device mesh .* not allowed"):
 _ = strategy.device_mesh


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 @pytest.mark.parametrize(
 ("num_nodes", "devices", "invalid_dp_size", "invalid_tp_size"),
 [
@@ -70,7 +70,7 @@ def test_validate_device_mesh_dimensions(num_nodes, devices, invalid_dp_size, invalid_tp_size):
 strategy.setup_environment()


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 def test_checkpoint_io_unsupported():
 """Test that the ModelParallel strategy does not support the `CheckpointIO` plugin."""
 strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m))
@@ -81,18 +81,18 @@ def test_checkpoint_io_unsupported():
 strategy.checkpoint_io = Mock()


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 def test_fsdp_v1_modules_unsupported():
 """Test that the strategy won't allow setting up a module wrapped with the legacy FSDP API."""
 from torch.distributed.fsdp import FullyShardedDataParallel

 module = Mock(modules=Mock(return_value=[Mock(spec=FullyShardedDataParallel)]))
 strategy = ModelParallelStrategy(parallelize_fn=(lambda x, _: x))
-with pytest.raises(TypeError, match="only supports the new FSDP2 APIs in PyTorch >= 2.3"):
+with pytest.raises(TypeError, match="only supports the new FSDP2 APIs in PyTorch >= 2.4"):
 strategy.setup_module(module)


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 def test_parallelize_fn_call():
 model = nn.Linear(2, 2)
 optimizer = Adam(model.parameters())
@@ -116,15 +116,15 @@ def test_parallelize_fn_call():
 strategy.setup_module_and_optimizers(model, [optimizer])


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 def test_no_backward_sync():
 """Test that the backward sync control disables gradient sync on modules that benefit from it."""
-from torch.distributed._composable.fsdp import FSDP
+from torch.distributed._composable.fsdp import FSDPModule

 strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m))
 assert isinstance(strategy._backward_sync_control, _ParallelBackwardSyncControl)

-fsdp_layer = Mock(spec=FSDP)
+fsdp_layer = Mock(spec=FSDPModule)
 other_layer = nn.Linear(2, 2)
 module = Mock()
 module.modules = Mock(return_value=[fsdp_layer, other_layer])
@@ -138,7 +138,7 @@ def test_no_backward_sync():
 fsdp_layer.set_requires_gradient_sync.assert_called_with(False, recurse=False)


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 def test_save_checkpoint_storage_options(tmp_path):
 """Test that the strategy does not accept storage options for saving checkpoints."""
 strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m))
@@ -148,7 +148,7 @@ def test_save_checkpoint_storage_options(tmp_path):
 strategy.save_checkpoint(path=tmp_path, state=Mock(), storage_options=Mock())


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 @mock.patch("lightning.fabric.strategies.model_parallel.ModelParallelStrategy.broadcast", lambda _, x: x)
 @mock.patch("lightning.fabric.strategies.model_parallel._has_dtensor_modules", return_value=True)
 @mock.patch("torch.distributed.checkpoint.state_dict.get_model_state_dict", return_value={})
@@ -205,7 +205,7 @@ def test_save_checkpoint_path_exists(shutil_mock, torch_save_mock, _, __, ___, tmp_path):
 assert path.is_dir()


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 def test_save_checkpoint_one_dist_module_required(tmp_path):
 """Test that the ModelParallelStrategy strategy can only save one distributed model per checkpoint."""
 strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m))
@@ -226,29 +226,7 @@ def test_save_checkpoint_one_dist_module_required(tmp_path):
 strategy.save_checkpoint(path=tmp_path, state={"model1": model1, "model2": model2})


-@RunIf(min_torch="2.3")
-@mock.patch("lightning.fabric.strategies.model_parallel.torch.load", Mock())
-@mock.patch("lightning.fabric.strategies.model_parallel._TORCH_GREATER_EQUAL_2_4", False)
-def test_load_full_checkpoint_support(tmp_path):
-"""Test that loading non-distributed checkpoints into distributed models requires PyTorch >= 2.4."""
-strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m))
-model = Mock(spec=nn.Module)
-model.parameters.return_value = [torch.zeros(2, 1)]
-path = tmp_path / "full.ckpt"
-path.touch()
-
-with pytest.raises(ImportError, match="Loading .* into a distributed model requires PyTorch >= 2.4"), mock.patch(
-"lightning.fabric.strategies.model_parallel._has_dtensor_modules", return_value=True
-):
-strategy.load_checkpoint(path=path, state={"model": model})
-
-with pytest.raises(ImportError, match="Loading .* into a distributed model requires PyTorch >= 2.4"), mock.patch(
-"lightning.fabric.strategies.model_parallel._has_dtensor_modules", return_value=True
-):
-strategy.load_checkpoint(path=path, state=model)
-
-
-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 def test_load_checkpoint_no_state(tmp_path):
 """Test that the ModelParallelStrategy strategy can't load the full state without access to a model instance from
 the user."""
@@ -259,7 +237,7 @@ def test_load_checkpoint_no_state(tmp_path):
 strategy.load_checkpoint(path=tmp_path, state={})


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 @mock.patch("lightning.fabric.strategies.model_parallel.ModelParallelStrategy.broadcast", lambda _, x: x)
 @mock.patch("lightning.fabric.strategies.model_parallel.torch.load", Mock())
 def test_load_checkpoint_one_dist_module_required(tmp_path):
@@ -289,7 +267,7 @@ def test_load_checkpoint_one_dist_module_required(tmp_path):
 strategy.load_checkpoint(path=path, state=model)


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 @mock.patch("lightning.fabric.strategies.model_parallel._has_dtensor_modules", return_value=True)
 def test_load_unknown_checkpoint_type(_, tmp_path):
 """Test that the strategy validates the contents at the checkpoint path."""
@@ -301,7 +279,7 @@ def test_load_unknown_checkpoint_type(_, tmp_path):
 strategy.load_checkpoint(path=path, state={"model": model})


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 def test_load_raw_checkpoint_validate_single_file(tmp_path):
 """Test that we validate the given checkpoint is a single file when loading a raw PyTorch state-dict checkpoint."""
 strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m))
@@ -312,7 +290,7 @@ def test_load_raw_checkpoint_validate_single_file(tmp_path):
 strategy.load_checkpoint(path=path, state=model)


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 def test_load_raw_checkpoint_optimizer_unsupported(tmp_path):
 """Validate that the ModelParallelStrategy strategy does not yet support loading the raw PyTorch state-dict for an
 optimizer."""
@@ -324,7 +302,7 @@ def test_load_raw_checkpoint_optimizer_unsupported(tmp_path):
 strategy.load_checkpoint(path=tmp_path, state=optimizer)


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 @mock.patch("lightning.fabric.strategies.model_parallel._setup_device_mesh")
 @mock.patch("torch.distributed.init_process_group")
 def test_set_timeout(init_process_group_mock, _):
@@ -343,7 +321,7 @@ def test_set_timeout(init_process_group_mock, _):
 )


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 def test_meta_device_materialization():
 """Test that the `setup_module()` method materializes meta-device tensors in the module."""

@@ -80,7 +80,7 @@ def _parallelize_feed_forward_fsdp2_tp(model, device_mesh):
 return model


-@RunIf(min_torch="2.3", standalone=True, min_cuda_gpus=4)
+@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4)
 def test_setup_device_mesh():
 from torch.distributed.device_mesh import DeviceMesh

@@ -116,7 +116,7 @@ def test_setup_device_mesh():
 assert fabric.strategy.device_mesh.size(1) == 4


-@RunIf(min_torch="2.3", standalone=True, min_cuda_gpus=2)
+@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=2)
 def test_tensor_parallel():
 from torch.distributed._tensor import DTensor

@@ -160,7 +160,7 @@ def test_tensor_parallel():
 optimizer.zero_grad()


-@RunIf(min_torch="2.3", standalone=True, min_cuda_gpus=4)
+@RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4)
 def test_fsdp2_tensor_parallel():
 from torch.distributed._tensor import DTensor

@@ -237,7 +237,7 @@ def _train(fabric, model=None, optimizer=None):
 return model, optimizer


-@RunIf(min_torch="2.3", min_cuda_gpus=4, standalone=True)
+@RunIf(min_torch="2.4", min_cuda_gpus=4, standalone=True)
 @pytest.mark.parametrize(
 "precision",
 [
@@ -445,7 +445,7 @@ def test_load_full_state_dict_into_sharded_model(tmp_path):
 assert torch.equal(params_before, params_after)


-@RunIf(min_torch="2.3", min_cuda_gpus=2, skip_windows=True, standalone=True)
+@RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True)
 @pytest.mark.parametrize("move_to_device", [True, False])
 @mock.patch("lightning.fabric.wrappers._FabricModule")
 def test_setup_module_move_to_device(fabric_module_mock, move_to_device):
@@ -471,7 +471,7 @@ def test_setup_module_move_to_device(fabric_module_mock, move_to_device):
 assert fabric.device == torch.device("cuda", fabric.local_rank)


-@RunIf(min_torch="2.3", min_cuda_gpus=2, skip_windows=True, standalone=True)
+@RunIf(min_torch="2.4", min_cuda_gpus=2, skip_windows=True, standalone=True)
 @pytest.mark.parametrize(
 ("precision", "expected_dtype"),
 [
@@ -502,7 +502,7 @@ def test_module_init_context(precision, expected_dtype):
 _run_setup_assertions(empty_init=True, expected_device=torch.device("meta"))


-@RunIf(min_torch="2.3", min_cuda_gpus=2, standalone=True)
+@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
 def test_save_filter(tmp_path):
 strategy = ModelParallelStrategy(
 parallelize_fn=_parallelize_feed_forward_fsdp2,
@@ -541,7 +541,7 @@ def _parallelize_single_linear_tp_fsdp2(model, device_mesh):
 return model


-@RunIf(min_torch="2.3", min_cuda_gpus=2, standalone=True)
+@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
 @pytest.mark.parametrize(
 "precision",
 [
@@ -597,7 +597,7 @@ def test_clip_gradients(clip_type, precision):
 optimizer.zero_grad()


-@RunIf(min_torch="2.3", min_cuda_gpus=4, standalone=True)
+@RunIf(min_torch="2.4", min_cuda_gpus=4, standalone=True)
 def test_save_sharded_and_consolidate_and_load(tmp_path):
 """Test the consolidation of a distributed (DTensor) checkpoint into a single file."""
 strategy = ModelParallelStrategy(
@@ -868,7 +868,7 @@ def test_precision_selection_amp_ddp(strategy, devices, is_custom_plugin, plugin_cls):
 assert isinstance(connector.precision, plugin_cls)


-@RunIf(min_torch="2.3")
+@RunIf(min_torch="2.4")
 @pytest.mark.parametrize(
 ("precision", "raises"),
 [("32-true", False), ("16-true", False), ("bf16-true", False), ("16-mixed", True), ("bf16-mixed", False)],
@@ -685,7 +685,7 @@ def test_unwrap_compiled():
 assert unwrapped is compiled._orig_mod
 assert compile_kwargs == {"fullgraph": True, "dynamic": True, "disable": False}

-del compiled._compile_kwargs
+compiled._compile_kwargs = None
 with pytest.raises(RuntimeError, match="Failed to determine the arguments that were used to compile the module"):
 _unwrap_compiled(compiled)
@@ -174,27 +174,6 @@ def test_save_checkpoint_path_exists(shutil_mock, torch_save_mock, tmp_path):
 assert path.is_dir()


-@RunIf(min_torch="2.3")
-@mock.patch("lightning.fabric.strategies.model_parallel._TORCH_GREATER_EQUAL_2_4", False)
-def test_load_full_checkpoint_support(tmp_path):
-"""Test that loading non-distributed checkpoints into distributed models requires PyTorch >= 2.4."""
-strategy = ModelParallelStrategy()
-strategy.model = Mock()
-strategy._lightning_module = Mock(strict_loading=True)
-path = tmp_path / "full.ckpt"
-path.touch()
-
-with pytest.raises(ImportError, match="Loading .* into a distributed model requires PyTorch >= 2.4"), mock.patch(
-"lightning.fabric.strategies.model_parallel._has_dtensor_modules", return_value=True
-):
-strategy.load_checkpoint(checkpoint_path=path)
-
-with pytest.raises(ImportError, match="Loading .* into a distributed model requires PyTorch >= 2.4"), mock.patch(
-"lightning.fabric.strategies.model_parallel._has_dtensor_modules", return_value=True
-):
-strategy.load_checkpoint(checkpoint_path=path)
-
-
 @RunIf(min_torch="2.3")
 @mock.patch("lightning.fabric.strategies.model_parallel._has_dtensor_modules", return_value=True)
 def test_load_unknown_checkpoint_type(_, tmp_path):