diff --git a/tests/deprecated_api/__init__.py b/tests/deprecated_api/__init__.py index 91c7ef1c1f..5480f398da 100644 --- a/tests/deprecated_api/__init__.py +++ b/tests/deprecated_api/__init__.py @@ -14,9 +14,9 @@ """Test deprecated functionality which will be removed in vX.Y.Z.""" import sys from contextlib import contextmanager -from typing import Optional, Type +from typing import Optional -import pytest +from tests.helpers.utils import no_warning_call def _soft_unimport_module(str_module): @@ -25,28 +25,6 @@ def _soft_unimport_module(str_module): del sys.modules[str_module] -@contextmanager -def no_warning_call(expected_warning: Type[Warning] = UserWarning, match: Optional[str] = None): - with pytest.warns(None) as record: - yield - - if match is None: - try: - w = record.pop(expected_warning) - except AssertionError: - # no warning raised - return - else: - for w in record.list: - if w.category is expected_warning and match in w.message.args[0]: - break - else: - return - - msg = "A warning" if expected_warning is None else f"`{expected_warning.__name__}`" - raise AssertionError(f"{msg} was raised: {w}") - - @contextmanager def no_deprecated_call(match: Optional[str] = None): with no_warning_call(expected_warning=DeprecationWarning, match=match): diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 61ed359536..05dc566949 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -15,7 +15,7 @@ import functools import os import traceback from contextlib import contextmanager -from typing import Optional +from typing import Optional, Type import pytest @@ -114,15 +114,22 @@ def pl_multi_process_test(func): @contextmanager -def no_warning_call(warning_type, match: Optional[str] = None): +def no_warning_call(expected_warning: Type[Warning] = UserWarning, match: Optional[str] = None): with pytest.warns(None) as record: yield + if match is None: try: - w = record.pop(warning_type) - if not (match and match in str(w.message)): - return + w = record.pop(expected_warning) except AssertionError: # no warning raised return - raise AssertionError(f"`{warning_type}` was raised: {w}") + else: + for w in record.list: + if w.category is expected_warning and match in w.message.args[0]: + break + else: + return + + msg = "A warning" if expected_warning is None else f"`{expected_warning.__name__}`" + raise AssertionError(f"{msg} was raised: {w}") diff --git a/tests/utilities/test_data.py b/tests/utilities/test_data.py index 0d874e81d8..ae454080a3 100644 --- a/tests/utilities/test_data.py +++ b/tests/utilities/test_data.py @@ -16,8 +16,8 @@ from pytorch_lightning.utilities.data import ( warning_cache, ) from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.deprecated_api import no_warning_call from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset +from tests.helpers.utils import no_warning_call def test_extract_batch_size():