update no_warning_call utility in tests (#11557)
This commit is contained in:
parent
115a5d08e8
commit
e76e1e7018
|
@ -14,9 +14,9 @@
|
|||
"""Test deprecated functionality which will be removed in vX.Y.Z."""
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Type
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from tests.helpers.utils import no_warning_call
|
||||
|
||||
|
||||
def _soft_unimport_module(str_module):
|
||||
|
@ -25,28 +25,6 @@ def _soft_unimport_module(str_module):
|
|||
del sys.modules[str_module]
|
||||
|
||||
|
||||
@contextmanager
|
||||
def no_warning_call(expected_warning: Type[Warning] = UserWarning, match: Optional[str] = None):
|
||||
with pytest.warns(None) as record:
|
||||
yield
|
||||
|
||||
if match is None:
|
||||
try:
|
||||
w = record.pop(expected_warning)
|
||||
except AssertionError:
|
||||
# no warning raised
|
||||
return
|
||||
else:
|
||||
for w in record.list:
|
||||
if w.category is expected_warning and match in w.message.args[0]:
|
||||
break
|
||||
else:
|
||||
return
|
||||
|
||||
msg = "A warning" if expected_warning is None else f"`{expected_warning.__name__}`"
|
||||
raise AssertionError(f"{msg} was raised: {w}")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def no_deprecated_call(match: Optional[str] = None):
|
||||
with no_warning_call(expected_warning=DeprecationWarning, match=match):
|
||||
|
|
|
@ -15,7 +15,7 @@ import functools
|
|||
import os
|
||||
import traceback
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional
|
||||
from typing import Optional, Type
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -114,15 +114,22 @@ def pl_multi_process_test(func):
|
|||
|
||||
|
||||
@contextmanager
|
||||
def no_warning_call(warning_type, match: Optional[str] = None):
|
||||
def no_warning_call(expected_warning: Type[Warning] = UserWarning, match: Optional[str] = None):
|
||||
with pytest.warns(None) as record:
|
||||
yield
|
||||
|
||||
if match is None:
|
||||
try:
|
||||
w = record.pop(warning_type)
|
||||
if not (match and match in str(w.message)):
|
||||
return
|
||||
w = record.pop(expected_warning)
|
||||
except AssertionError:
|
||||
# no warning raised
|
||||
return
|
||||
raise AssertionError(f"`{warning_type}` was raised: {w}")
|
||||
else:
|
||||
for w in record.list:
|
||||
if w.category is expected_warning and match in w.message.args[0]:
|
||||
break
|
||||
else:
|
||||
return
|
||||
|
||||
msg = "A warning" if expected_warning is None else f"`{expected_warning.__name__}`"
|
||||
raise AssertionError(f"{msg} was raised: {w}")
|
||||
|
|
|
@ -16,8 +16,8 @@ from pytorch_lightning.utilities.data import (
|
|||
warning_cache,
|
||||
)
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from tests.deprecated_api import no_warning_call
|
||||
from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset
|
||||
from tests.helpers.utils import no_warning_call
|
||||
|
||||
|
||||
def test_extract_batch_size():
|
||||
|
|
Loading…
Reference in New Issue