From 6ee1f6c4b785b8a9d0f6ac33acecad9b9e029cab Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Wed, 26 Oct 2022 18:33:22 +0200 Subject: [PATCH] New skip conditions for unpickle-patching tests (#15329) * New running conditions for tests * found one more mistake --- .../graveyard/test_legacy_import_unpickler.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/tests_pytorch/graveyard/test_legacy_import_unpickler.py b/tests/tests_pytorch/graveyard/test_legacy_import_unpickler.py index da232be005..89bbd4a4f9 100644 --- a/tests/tests_pytorch/graveyard/test_legacy_import_unpickler.py +++ b/tests/tests_pytorch/graveyard/test_legacy_import_unpickler.py @@ -4,8 +4,9 @@ import sys import pytest import torch +from lightning_utilities.core.imports import package_available +from packaging.version import Version -import pytorch_lightning # noqa: F401 from tests_pytorch.checkpointing.test_legacy_checkpoints import ( CHECKPOINT_EXTENSION, LEGACY_BACK_COMPATIBLE_PL_VERSIONS, @@ -16,7 +17,7 @@ from tests_pytorch.helpers.utils import no_warning_call @pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS) @pytest.mark.skipif( - not "pytorch_" + "lightning" in sys.modules, reason="This test is only relevant for the standalone package" + package_available("lightning.pytorch"), reason="This test is only relevant for the standalone package" ) def test_imports_standalone(pl_version: str): assert any( @@ -40,7 +41,8 @@ def test_imports_standalone(pl_version: str): @pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS) @pytest.mark.skipif( - "pytorch_" + "lightning" in sys.modules, reason="This test is only relevant for the unified package" + not package_available("lightning.pytorch"), + reason="This test is only relevant for the unified package", ) def test_imports_unified(pl_version: str): assert any( @@ -55,7 +57,12 @@ def test_imports_unified(pl_version: str): assert path_ckpts, f'No checkpoints found in folder "{path_legacy}"' path_ckpt = path_ckpts[-1] - with pytest.warns(match="Redirecting imports of"): + # only below version 1.5.0 we pickled stuff in checkpoints + if Version(pl_version) < Version("1.5.0"): + context = pytest.warns(UserWarning, match="Redirecting import of") + else: + context = no_warning_call(match="Redirecting import of*") + with context: torch.load(path_ckpt) assert any(