fix(tests): update tests after torch 2.4.1 (#20302)

* update

* test_loggers_pickle_all

* more...

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Jirka Borovec 2024-09-26 17:52:22 +02:00 committed by GitHub
parent e452fe83b1
commit d1ca3c6e09
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 25 additions and 20 deletions

View File

@ -76,7 +76,6 @@ ignore = [
"S108",
"E203", # conflicts with black
]
ignore-init-module-imports = true
[tool.ruff.lint.per-file-ignores]
".actions/*" = ["S101", "S310"]

View File

@ -1,5 +1,5 @@
mypy==1.11.0
torch==2.4.0
torch==2.4.1
types-Markdown
types-PyYAML

View File

@ -31,7 +31,9 @@ _IS_INTERACTIVE = hasattr(sys, "ps1") or bool(sys.flags.interactive)
_TORCH_GREATER_EQUAL_2_2 = compare_version("torch", operator.ge, "2.2.0")
_TORCH_GREATER_EQUAL_2_3 = compare_version("torch", operator.ge, "2.3.0")
_TORCH_EQUAL_2_4_0 = compare_version("torch", operator.eq, "2.4.0")
_TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0")
_TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1")
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)

View File

@ -23,7 +23,7 @@ from unittest.mock import Mock
import cloudpickle
import pytest
import torch
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel
@ -193,12 +193,12 @@ def test_pickling():
early_stopping = EarlyStopping(monitor="foo")
early_stopping_pickled = pickle.dumps(early_stopping)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
early_stopping_loaded = pickle.loads(early_stopping_pickled)
assert vars(early_stopping) == vars(early_stopping_loaded)
early_stopping_pickled = cloudpickle.dumps(early_stopping)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
early_stopping_loaded = cloudpickle.loads(early_stopping_pickled)
assert vars(early_stopping) == vars(early_stopping_loaded)

View File

@ -32,7 +32,7 @@ import torch
import yaml
from jsonargparse import ArgumentParser
from lightning.fabric.utilities.cloud_io import _load as pl_load
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel
@ -352,12 +352,12 @@ def test_pickling(tmp_path):
ckpt = ModelCheckpoint(dirpath=tmp_path)
ckpt_pickled = pickle.dumps(ckpt)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
ckpt_loaded = pickle.loads(ckpt_pickled)
assert vars(ckpt) == vars(ckpt_loaded)
ckpt_pickled = cloudpickle.dumps(ckpt)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
ckpt_loaded = cloudpickle.loads(ckpt_pickled)
assert vars(ckpt) == vars(ckpt_loaded)

View File

@ -19,7 +19,7 @@ from unittest import mock
import lightning.pytorch as pl
import pytest
import torch
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import OnExceptionCheckpoint
@ -254,7 +254,7 @@ def test_result_collection_restoration(tmp_path):
}
# make sure can be pickled
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
pickle.loads(pickle.dumps(result))
# make sure can be torch.loaded
filepath = str(tmp_path / "result")

View File

@ -17,7 +17,7 @@ from contextlib import nullcontext
import cloudpickle
import pytest
import torch
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
from tests_pytorch import _PATH_DATASETS
from tests_pytorch.helpers.datasets import MNIST, AverageDataset, TrialMNIST
@ -44,9 +44,9 @@ def test_pickling_dataset_mnist(dataset_cls, args):
mnist = dataset_cls(**args)
mnist_pickled = pickle.dumps(mnist)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
pickle.loads(mnist_pickled)
mnist_pickled = cloudpickle.dumps(mnist)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
cloudpickle.loads(mnist_pickled)

View File

@ -20,7 +20,7 @@ from unittest.mock import ANY, Mock
import pytest
import torch
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0, _TORCH_GREATER_EQUAL_2_4_1
from lightning.pytorch import Callback, Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loggers import (
@ -163,7 +163,7 @@ def test_loggers_pickle_all(tmp_path, monkeypatch, logger_class):
pytest.xfail(f"pickle test requires {logger_class.__class__} dependencies to be installed.")
def _test_loggers_pickle(tmp_path, monkeypatch, logger_class):
def _test_loggers_pickle(tmp_path, monkeypatch, logger_class: Logger):
"""Verify that pickling trainer with logger works."""
_patch_comet_atexit(monkeypatch)
@ -184,7 +184,11 @@ def _test_loggers_pickle(tmp_path, monkeypatch, logger_class):
trainer = Trainer(max_epochs=1, logger=logger)
pkl_bytes = pickle.dumps(trainer)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
with (
pytest.warns(FutureWarning, match="`weights_only=False`")
if _TORCH_EQUAL_2_4_0 or (_TORCH_GREATER_EQUAL_2_4_1 and logger_class not in (CSVLogger, TensorBoardLogger))
else nullcontext()
):
trainer2 = pickle.loads(pkl_bytes)
trainer2.logger.log_metrics({"acc": 1.0})

View File

@ -21,7 +21,7 @@ from unittest.mock import patch
import numpy as np
import pytest
import torch
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
from lightning.fabric.utilities.logger import _convert_params, _sanitize_params
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel
@ -124,7 +124,7 @@ def test_multiple_loggers_pickle(tmp_path):
trainer = Trainer(logger=[logger1, logger2])
pkl_bytes = pickle.dumps(trainer)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
trainer2 = pickle.loads(pkl_bytes)
for logger in trainer2.loggers:
logger.log_metrics({"acc": 1.0}, 0)

View File

@ -19,7 +19,7 @@ from unittest import mock
import pytest
import yaml
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.cli import LightningCLI
@ -162,7 +162,7 @@ def test_wandb_pickle(wandb_mock, tmp_path):
assert trainer.logger.experiment, "missing experiment"
assert trainer.log_dir == logger.save_dir
pkl_bytes = pickle.dumps(trainer)
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
trainer2 = pickle.loads(pkl_bytes)
assert os.environ["WANDB_MODE"] == "dryrun"