update no_warning_call utility in tests (#11557)

This commit is contained in:
Rohit Gupta 2022-02-03 04:13:13 +05:30 committed by GitHub
parent 115a5d08e8
commit e76e1e7018
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 16 additions and 31 deletions

View File

@ -14,9 +14,9 @@
"""Test deprecated functionality which will be removed in vX.Y.Z."""
import sys
from contextlib import contextmanager
from typing import Optional, Type
from typing import Optional
import pytest
from tests.helpers.utils import no_warning_call
def _soft_unimport_module(str_module):
@ -25,28 +25,6 @@ def _soft_unimport_module(str_module):
del sys.modules[str_module]
@contextmanager
def no_warning_call(expected_warning: Type[Warning] = UserWarning, match: Optional[str] = None):
with pytest.warns(None) as record:
yield
if match is None:
try:
w = record.pop(expected_warning)
except AssertionError:
# no warning raised
return
else:
for w in record.list:
if w.category is expected_warning and match in w.message.args[0]:
break
else:
return
msg = "A warning" if expected_warning is None else f"`{expected_warning.__name__}`"
raise AssertionError(f"{msg} was raised: {w}")
@contextmanager
def no_deprecated_call(match: Optional[str] = None):
with no_warning_call(expected_warning=DeprecationWarning, match=match):

View File

@ -15,7 +15,7 @@ import functools
import os
import traceback
from contextlib import contextmanager
from typing import Optional
from typing import Optional, Type
import pytest
@ -114,15 +114,22 @@ def pl_multi_process_test(func):
@contextmanager
def no_warning_call(warning_type, match: Optional[str] = None):
def no_warning_call(expected_warning: Type[Warning] = UserWarning, match: Optional[str] = None):
with pytest.warns(None) as record:
yield
if match is None:
try:
w = record.pop(warning_type)
if not (match and match in str(w.message)):
return
w = record.pop(expected_warning)
except AssertionError:
# no warning raised
return
raise AssertionError(f"`{warning_type}` was raised: {w}")
else:
for w in record.list:
if w.category is expected_warning and match in w.message.args[0]:
break
else:
return
msg = "A warning" if expected_warning is None else f"`{expected_warning.__name__}`"
raise AssertionError(f"{msg} was raised: {w}")

View File

@ -16,8 +16,8 @@ from pytorch_lightning.utilities.data import (
warning_cache,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.deprecated_api import no_warning_call
from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset
from tests.helpers.utils import no_warning_call
def test_extract_batch_size():