diff --git a/pyproject.toml b/pyproject.toml index 6edd6d1a8f..da4cd7f197 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,7 +76,6 @@ ignore = [ "S108", "E203", # conflicts with black ] -ignore-init-module-imports = true [tool.ruff.lint.per-file-ignores] ".actions/*" = ["S101", "S310"] diff --git a/requirements/typing.txt b/requirements/typing.txt index 9f1952605b..0323edfd60 100644 --- a/requirements/typing.txt +++ b/requirements/typing.txt @@ -1,5 +1,5 @@ mypy==1.11.0 -torch==2.4.0 +torch==2.4.1 types-Markdown types-PyYAML diff --git a/src/lightning/fabric/utilities/imports.py b/src/lightning/fabric/utilities/imports.py index 4dbd57e531..a1c5a6f6dc 100644 --- a/src/lightning/fabric/utilities/imports.py +++ b/src/lightning/fabric/utilities/imports.py @@ -31,7 +31,9 @@ _IS_INTERACTIVE = hasattr(sys, "ps1") or bool(sys.flags.interactive) _TORCH_GREATER_EQUAL_2_2 = compare_version("torch", operator.ge, "2.2.0") _TORCH_GREATER_EQUAL_2_3 = compare_version("torch", operator.ge, "2.3.0") +_TORCH_EQUAL_2_4_0 = compare_version("torch", operator.eq, "2.4.0") _TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0") +_TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1") _PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10) diff --git a/tests/tests_pytorch/callbacks/test_early_stopping.py b/tests/tests_pytorch/callbacks/test_early_stopping.py index 633c1dc085..b7e52ee549 100644 --- a/tests/tests_pytorch/callbacks/test_early_stopping.py +++ b/tests/tests_pytorch/callbacks/test_early_stopping.py @@ -23,7 +23,7 @@ from unittest.mock import Mock import cloudpickle import pytest import torch -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 +from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0 from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel @@ -193,12 +193,12 @@ def test_pickling(): early_stopping = EarlyStopping(monitor="foo") early_stopping_pickled = pickle.dumps(early_stopping) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): + with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): early_stopping_loaded = pickle.loads(early_stopping_pickled) assert vars(early_stopping) == vars(early_stopping_loaded) early_stopping_pickled = cloudpickle.dumps(early_stopping) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): + with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): early_stopping_loaded = cloudpickle.loads(early_stopping_pickled) assert vars(early_stopping) == vars(early_stopping_loaded) diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index 8ef78a742f..97d8d3c4d0 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -32,7 +32,7 @@ import torch import yaml from jsonargparse import ArgumentParser from lightning.fabric.utilities.cloud_io import _load as pl_load -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 +from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0 from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel @@ -352,12 +352,12 @@ def test_pickling(tmp_path): ckpt = ModelCheckpoint(dirpath=tmp_path) ckpt_pickled = pickle.dumps(ckpt) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): + with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): ckpt_loaded = pickle.loads(ckpt_pickled) assert vars(ckpt) == vars(ckpt_loaded) ckpt_pickled = cloudpickle.dumps(ckpt) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): + with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): ckpt_loaded = cloudpickle.loads(ckpt_pickled) assert vars(ckpt) == vars(ckpt_loaded) diff --git a/tests/tests_pytorch/core/test_metric_result_integration.py b/tests/tests_pytorch/core/test_metric_result_integration.py index 9818f9807a..ef340d1e17 100644 --- a/tests/tests_pytorch/core/test_metric_result_integration.py +++ b/tests/tests_pytorch/core/test_metric_result_integration.py @@ -19,7 +19,7 @@ from unittest import mock import lightning.pytorch as pl import pytest import torch -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 +from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0 from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch import Trainer from lightning.pytorch.callbacks import OnExceptionCheckpoint @@ -254,7 +254,7 @@ def test_result_collection_restoration(tmp_path): } # make sure can be pickled - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): + with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): pickle.loads(pickle.dumps(result)) # make sure can be torch.loaded filepath = str(tmp_path / "result") diff --git a/tests/tests_pytorch/helpers/test_datasets.py b/tests/tests_pytorch/helpers/test_datasets.py index ddc20c29e6..98d77a6d9a 100644 --- a/tests/tests_pytorch/helpers/test_datasets.py +++ b/tests/tests_pytorch/helpers/test_datasets.py @@ -17,7 +17,7 @@ from contextlib import nullcontext import cloudpickle import pytest import torch -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 +from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0 from tests_pytorch import _PATH_DATASETS from tests_pytorch.helpers.datasets import MNIST, AverageDataset, TrialMNIST @@ -44,9 +44,9 @@ def test_pickling_dataset_mnist(dataset_cls, args): mnist = dataset_cls(**args) mnist_pickled = pickle.dumps(mnist) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): + with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): pickle.loads(mnist_pickled) mnist_pickled = cloudpickle.dumps(mnist) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): + with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): cloudpickle.loads(mnist_pickled) diff --git a/tests/tests_pytorch/loggers/test_all.py b/tests/tests_pytorch/loggers/test_all.py index 503e49fe6c..c5b07562af 100644 --- a/tests/tests_pytorch/loggers/test_all.py +++ b/tests/tests_pytorch/loggers/test_all.py @@ -20,7 +20,7 @@ from unittest.mock import ANY, Mock import pytest import torch -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 +from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0, _TORCH_GREATER_EQUAL_2_4_1 from lightning.pytorch import Callback, Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers import ( @@ -163,7 +163,7 @@ def test_loggers_pickle_all(tmp_path, monkeypatch, logger_class): pytest.xfail(f"pickle test requires {logger_class.__class__} dependencies to be installed.") -def _test_loggers_pickle(tmp_path, monkeypatch, logger_class): +def _test_loggers_pickle(tmp_path, monkeypatch, logger_class: Logger): """Verify that pickling trainer with logger works.""" _patch_comet_atexit(monkeypatch) @@ -184,7 +184,11 @@ def _test_loggers_pickle(tmp_path, monkeypatch, logger_class): trainer = Trainer(max_epochs=1, logger=logger) pkl_bytes = pickle.dumps(trainer) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): + with ( + pytest.warns(FutureWarning, match="`weights_only=False`") + if _TORCH_EQUAL_2_4_0 or (_TORCH_GREATER_EQUAL_2_4_1 and logger_class not in (CSVLogger, TensorBoardLogger)) + else nullcontext() + ): trainer2 = pickle.loads(pkl_bytes) trainer2.logger.log_metrics({"acc": 1.0}) diff --git a/tests/tests_pytorch/loggers/test_logger.py b/tests/tests_pytorch/loggers/test_logger.py index 7b384890f6..de0028000c 100644 --- a/tests/tests_pytorch/loggers/test_logger.py +++ b/tests/tests_pytorch/loggers/test_logger.py @@ -21,7 +21,7 @@ from unittest.mock import patch import numpy as np import pytest import torch -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 +from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0 from lightning.fabric.utilities.logger import _convert_params, _sanitize_params from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel @@ -124,7 +124,7 @@ def test_multiple_loggers_pickle(tmp_path): trainer = Trainer(logger=[logger1, logger2]) pkl_bytes = pickle.dumps(trainer) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): + with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): trainer2 = pickle.loads(pkl_bytes) for logger in trainer2.loggers: logger.log_metrics({"acc": 1.0}, 0) diff --git a/tests/tests_pytorch/loggers/test_wandb.py b/tests/tests_pytorch/loggers/test_wandb.py index e9195f6283..4e3fbb287a 100644 --- a/tests/tests_pytorch/loggers/test_wandb.py +++ b/tests/tests_pytorch/loggers/test_wandb.py @@ -19,7 +19,7 @@ from unittest import mock import pytest import yaml -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 +from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0 from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.cli import LightningCLI @@ -162,7 +162,7 @@ def test_wandb_pickle(wandb_mock, tmp_path): assert trainer.logger.experiment, "missing experiment" assert trainer.log_dir == logger.save_dir pkl_bytes = pickle.dumps(trainer) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): + with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): trainer2 = pickle.loads(pkl_bytes) assert os.environ["WANDB_MODE"] == "dryrun"