fix(tests): update tests after torch 2.4.1 (#20302)
* update * test_loggers_pickle_all * more... * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
e452fe83b1
commit
d1ca3c6e09
|
@ -76,7 +76,6 @@ ignore = [
|
|||
"S108",
|
||||
"E203", # conflicts with black
|
||||
]
|
||||
ignore-init-module-imports = true
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
".actions/*" = ["S101", "S310"]
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
mypy==1.11.0
|
||||
torch==2.4.0
|
||||
torch==2.4.1
|
||||
|
||||
types-Markdown
|
||||
types-PyYAML
|
||||
|
|
|
@ -31,7 +31,9 @@ _IS_INTERACTIVE = hasattr(sys, "ps1") or bool(sys.flags.interactive)
|
|||
|
||||
_TORCH_GREATER_EQUAL_2_2 = compare_version("torch", operator.ge, "2.2.0")
|
||||
_TORCH_GREATER_EQUAL_2_3 = compare_version("torch", operator.ge, "2.3.0")
|
||||
_TORCH_EQUAL_2_4_0 = compare_version("torch", operator.eq, "2.4.0")
|
||||
_TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0")
|
||||
_TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1")
|
||||
|
||||
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ from unittest.mock import Mock
|
|||
import cloudpickle
|
||||
import pytest
|
||||
import torch
|
||||
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
|
||||
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
|
||||
from lightning.pytorch import Trainer, seed_everything
|
||||
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
|
||||
from lightning.pytorch.demos.boring_classes import BoringModel
|
||||
|
@ -193,12 +193,12 @@ def test_pickling():
|
|||
early_stopping = EarlyStopping(monitor="foo")
|
||||
|
||||
early_stopping_pickled = pickle.dumps(early_stopping)
|
||||
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
|
||||
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
|
||||
early_stopping_loaded = pickle.loads(early_stopping_pickled)
|
||||
assert vars(early_stopping) == vars(early_stopping_loaded)
|
||||
|
||||
early_stopping_pickled = cloudpickle.dumps(early_stopping)
|
||||
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
|
||||
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
|
||||
early_stopping_loaded = cloudpickle.loads(early_stopping_pickled)
|
||||
assert vars(early_stopping) == vars(early_stopping_loaded)
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ import torch
|
|||
import yaml
|
||||
from jsonargparse import ArgumentParser
|
||||
from lightning.fabric.utilities.cloud_io import _load as pl_load
|
||||
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
|
||||
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
|
||||
from lightning.pytorch import Trainer, seed_everything
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint
|
||||
from lightning.pytorch.demos.boring_classes import BoringModel
|
||||
|
@ -352,12 +352,12 @@ def test_pickling(tmp_path):
|
|||
ckpt = ModelCheckpoint(dirpath=tmp_path)
|
||||
|
||||
ckpt_pickled = pickle.dumps(ckpt)
|
||||
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
|
||||
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
|
||||
ckpt_loaded = pickle.loads(ckpt_pickled)
|
||||
assert vars(ckpt) == vars(ckpt_loaded)
|
||||
|
||||
ckpt_pickled = cloudpickle.dumps(ckpt)
|
||||
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
|
||||
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
|
||||
ckpt_loaded = cloudpickle.loads(ckpt_pickled)
|
||||
assert vars(ckpt) == vars(ckpt_loaded)
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ from unittest import mock
|
|||
import lightning.pytorch as pl
|
||||
import pytest
|
||||
import torch
|
||||
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
|
||||
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
|
||||
from lightning.fabric.utilities.warnings import PossibleUserWarning
|
||||
from lightning.pytorch import Trainer
|
||||
from lightning.pytorch.callbacks import OnExceptionCheckpoint
|
||||
|
@ -254,7 +254,7 @@ def test_result_collection_restoration(tmp_path):
|
|||
}
|
||||
|
||||
# make sure can be pickled
|
||||
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
|
||||
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
|
||||
pickle.loads(pickle.dumps(result))
|
||||
# make sure can be torch.loaded
|
||||
filepath = str(tmp_path / "result")
|
||||
|
|
|
@ -17,7 +17,7 @@ from contextlib import nullcontext
|
|||
import cloudpickle
|
||||
import pytest
|
||||
import torch
|
||||
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
|
||||
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
|
||||
|
||||
from tests_pytorch import _PATH_DATASETS
|
||||
from tests_pytorch.helpers.datasets import MNIST, AverageDataset, TrialMNIST
|
||||
|
@ -44,9 +44,9 @@ def test_pickling_dataset_mnist(dataset_cls, args):
|
|||
mnist = dataset_cls(**args)
|
||||
|
||||
mnist_pickled = pickle.dumps(mnist)
|
||||
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
|
||||
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
|
||||
pickle.loads(mnist_pickled)
|
||||
|
||||
mnist_pickled = cloudpickle.dumps(mnist)
|
||||
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
|
||||
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
|
||||
cloudpickle.loads(mnist_pickled)
|
||||
|
|
|
@ -20,7 +20,7 @@ from unittest.mock import ANY, Mock
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
|
||||
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0, _TORCH_GREATER_EQUAL_2_4_1
|
||||
from lightning.pytorch import Callback, Trainer
|
||||
from lightning.pytorch.demos.boring_classes import BoringModel
|
||||
from lightning.pytorch.loggers import (
|
||||
|
@ -163,7 +163,7 @@ def test_loggers_pickle_all(tmp_path, monkeypatch, logger_class):
|
|||
pytest.xfail(f"pickle test requires {logger_class.__class__} dependencies to be installed.")
|
||||
|
||||
|
||||
def _test_loggers_pickle(tmp_path, monkeypatch, logger_class):
|
||||
def _test_loggers_pickle(tmp_path, monkeypatch, logger_class: Logger):
|
||||
"""Verify that pickling trainer with logger works."""
|
||||
_patch_comet_atexit(monkeypatch)
|
||||
|
||||
|
@ -184,7 +184,11 @@ def _test_loggers_pickle(tmp_path, monkeypatch, logger_class):
|
|||
trainer = Trainer(max_epochs=1, logger=logger)
|
||||
pkl_bytes = pickle.dumps(trainer)
|
||||
|
||||
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
|
||||
with (
|
||||
pytest.warns(FutureWarning, match="`weights_only=False`")
|
||||
if _TORCH_EQUAL_2_4_0 or (_TORCH_GREATER_EQUAL_2_4_1 and logger_class not in (CSVLogger, TensorBoardLogger))
|
||||
else nullcontext()
|
||||
):
|
||||
trainer2 = pickle.loads(pkl_bytes)
|
||||
trainer2.logger.log_metrics({"acc": 1.0})
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ from unittest.mock import patch
|
|||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
|
||||
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
|
||||
from lightning.fabric.utilities.logger import _convert_params, _sanitize_params
|
||||
from lightning.pytorch import Trainer
|
||||
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel
|
||||
|
@ -124,7 +124,7 @@ def test_multiple_loggers_pickle(tmp_path):
|
|||
|
||||
trainer = Trainer(logger=[logger1, logger2])
|
||||
pkl_bytes = pickle.dumps(trainer)
|
||||
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
|
||||
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
|
||||
trainer2 = pickle.loads(pkl_bytes)
|
||||
for logger in trainer2.loggers:
|
||||
logger.log_metrics({"acc": 1.0}, 0)
|
||||
|
|
|
@ -19,7 +19,7 @@ from unittest import mock
|
|||
|
||||
import pytest
|
||||
import yaml
|
||||
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
|
||||
from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0
|
||||
from lightning.pytorch import Trainer
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint
|
||||
from lightning.pytorch.cli import LightningCLI
|
||||
|
@ -162,7 +162,7 @@ def test_wandb_pickle(wandb_mock, tmp_path):
|
|||
assert trainer.logger.experiment, "missing experiment"
|
||||
assert trainer.log_dir == logger.save_dir
|
||||
pkl_bytes = pickle.dumps(trainer)
|
||||
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext():
|
||||
with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext():
|
||||
trainer2 = pickle.loads(pkl_bytes)
|
||||
|
||||
assert os.environ["WANDB_MODE"] == "dryrun"
|
||||
|
|
Loading…
Reference in New Issue