add skipif wrapper (#6258)

parent 651c25feb6
commit 352e8f0d28

@@ -22,15 +22,15 @@ from pytorch_lightning.metrics.functional.mean_relative_error import mean_relati
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.datamodules import RegressDataModule
 from tests.helpers.simple_models import RegressionModel
-from tests.helpers.skipif import skipif_args
+from tests.helpers.skipif import SkipIf


 @pytest.mark.parametrize(
     "observe",
-    ['average', pytest.param('histogram', marks=pytest.mark.skipif(**skipif_args(min_torch="1.5")))]
+    ['average', pytest.param('histogram', marks=SkipIf(min_torch="1.5"))]
 )
 @pytest.mark.parametrize("fuse", [True, False])
-@pytest.mark.skipif(**skipif_args(quant_available=True))
+@SkipIf(quantization=True)
 def test_quantization(tmpdir, observe, fuse):
     """Parity test for quant model"""
     seed_everything(42)

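For orientation, the shape of the migration in this hunk: the old helper returned a dict of kwargs that had to be splatted into the native marker, while the new wrapper evaluates to the mark itself and is applied directly. A minimal hedged sketch, with an illustrative test name and body, showing the keyword rename quant_available -> quantization::

    from tests.helpers.skipif import SkipIf  # the new import from this hunk

    # old form, removed above:
    #   @pytest.mark.skipif(**skipif_args(quant_available=True))
    # new form, the wrapper is the mark:
    @SkipIf(quantization=True)
    def test_illustrative_quant():
        assert True
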
@@ -65,7 +65,7 @@ def test_quantization(tmpdir, observe, fuse):
     assert torch.allclose(org_score, quant_score, atol=0.45)


-@pytest.mark.skipif(**skipif_args(quant_available=True))
+@SkipIf(quantization=True)
 def test_quantize_torchscript(tmpdir):
     """Test converting to torchscript"""
     dm = RegressDataModule()

@@ -81,7 +81,7 @@ def test_quantize_torchscript(tmpdir):
     tsmodel(tsmodel.quant(batch[0]))


-@pytest.mark.skipif(**skipif_args(quant_available=True))
+@SkipIf(quantization=True)
 def test_quantization_exceptions(tmpdir):
     """Test wrong fuse layers"""
     with pytest.raises(MisconfigurationException, match='Unsupported qconfig'):

@@ -124,7 +124,7 @@ def custom_trigger_last(trainer):
         (custom_trigger_last, 2),
     ]
 )
-@pytest.mark.skipif(**skipif_args(quant_available=True))
+@SkipIf(quantization=True)
 def test_quantization_triggers(tmpdir, trigger_fn, expected_count):
     """Test how many times the quant is called"""
     dm = RegressDataModule()

@@ -26,7 +26,7 @@ from pytorch_lightning import Trainer
 from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.trainer.states import TrainerState
 from tests.helpers import BoringDataModule, BoringModel
-from tests.helpers.skipif import skipif_args
+from tests.helpers.skipif import SkipIf


 def _setup_ddp(rank, worldsize):

@@ -72,7 +72,7 @@ def test_result_reduce_ddp(result_cls):
         pytest.param(5, False, 0, id='nested_list_predictions'),
         pytest.param(6, False, 0, id='dict_list_predictions'),
         pytest.param(7, True, 0, id='write_dict_predictions'),
-        pytest.param(0, True, 1, id='full_loop_single_gpu', marks=pytest.mark.skipif(**skipif_args(min_gpus=1)))
+        pytest.param(0, True, 1, id='full_loop_single_gpu', marks=SkipIf(min_gpus=1))
     ]
 )
 def test_result_obj_predictions(tmpdir, test_option, do_train, gpus):

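Because SkipIf(...) evaluates to an ordinary pytest mark, it also drops into pytest.param(..., marks=...), as this hunk shows. A hedged sketch with illustrative parameter names and values::

    import pytest

    from tests.helpers.skipif import SkipIf


    @pytest.mark.parametrize(
        "gpus",
        [
            0,  # always runs
            pytest.param(1, marks=SkipIf(min_gpus=1)),  # skipped when no GPU is visible
        ],
    )
    def test_gpu_variants(gpus):
        assert gpus >= 0
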
@@ -21,52 +21,64 @@ from pkg_resources import get_distribution
 from pytorch_lightning.utilities import _TORCH_QUANTIZE_AVAILABLE


-def skipif_args(
-    min_gpus: int = 0,
-    min_torch: Optional[str] = None,
-    quant_available: bool = False,
-) -> dict:
-    """ Creating aggregated arguments for standard pytest skipif, so the use case is::
-
-        @pytest.mark.skipif(**skipif_args(min_torch="99"))
-        def test_any_func(...):
-            ...
-
-    >>> from pprint import pprint
-    >>> pprint(skipif_args(min_torch="99", min_gpus=0))
-    {'condition': True, 'reason': 'Required: [torch>=99]'}
-    >>> pprint(skipif_args(min_torch="0.0", min_gpus=0))  # doctest: +NORMALIZE_WHITESPACE
-    {'condition': False, 'reason': 'Conditions satisfied, going ahead with the test.'}
+class SkipIf:
     """
-    conditions = []
-    reasons = []
+    SkipIf wrapper for simple marking of specific cases, fully compatible with pytest.mark::

-    if min_gpus:
-        conditions.append(torch.cuda.device_count() < min_gpus)
-        reasons.append(f"GPUs>={min_gpus}")
+        @SkipIf(min_torch="0.0")
+        @pytest.mark.parametrize("arg1", [1, 2.0])
+        def test_wrapper(arg1):
+            assert arg1 > 0.0
+    """

-    if min_torch:
-        torch_version = LooseVersion(get_distribution("torch").version)
-        conditions.append(torch_version < LooseVersion(min_torch))
-        reasons.append(f"torch>={min_torch}")
+    def __new__(
+        self,
+        *args,
+        min_gpus: int = 0,
+        min_torch: Optional[str] = None,
+        quantization: bool = False,
+        **kwargs
+    ):
+        """
+        Args:
+            args: native pytest.mark.skipif arguments
+            min_gpus: minimum number of GPUs required to run the test
+            min_torch: minimum PyTorch version required to run the test
+            quantization: whether the `torch.quantization` package is required to run the test
+            kwargs: native pytest.mark.skipif keyword arguments
+        """
+        conditions = []
+        reasons = []

-    if quant_available:
-        _miss_default = 'fbgemm' not in torch.backends.quantized.supported_engines
-        conditions.append(not _TORCH_QUANTIZE_AVAILABLE or _miss_default)
-        reasons.append("PyTorch quantization is available")
+        if min_gpus:
+            conditions.append(torch.cuda.device_count() < min_gpus)
+            reasons.append(f"GPUs>={min_gpus}")

-    if not any(conditions):
-        return dict(condition=False, reason="Conditions satisfied, going ahead with the test.")
+        if min_torch:
+            torch_version = LooseVersion(get_distribution("torch").version)
+            conditions.append(torch_version < LooseVersion(min_torch))
+            reasons.append(f"torch>={min_torch}")

-    reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
-    return dict(condition=any(conditions), reason=f"Required: [{' + '.join(reasons)}]",)
+        if quantization:
+            _miss_default = 'fbgemm' not in torch.backends.quantized.supported_engines
+            conditions.append(not _TORCH_QUANTIZE_AVAILABLE or _miss_default)
+            reasons.append("missing PyTorch quantization")
+
+        reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
+        return pytest.mark.skipif(
+            *args,
+            condition=any(conditions),
+            reason=f"Requires: [{' + '.join(reasons)}]",
+            **kwargs,
+        )


-@pytest.mark.skipif(**skipif_args(min_torch="99"))
+@SkipIf(min_torch="99")
 def test_always_skip():
     exit(1)


-@pytest.mark.skipif(**skipif_args(min_torch="0.0"))
-def test_always_pass():
-    assert True
+@pytest.mark.parametrize("arg1", [0.5, 1.0, 2.0])
+@SkipIf(min_torch="0.0")
+def test_wrapper(arg1):
+    assert arg1 > 0.0

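The trick that makes the wrapper compatible with pytest.mark is that __new__ returns a ready-made pytest.mark.skipif MarkDecorator instead of a SkipIf instance. A hedged, self-contained sketch of the same pattern; NeedsEnv and its environment-variable condition are illustrative, not part of this PR::

    import os

    import pytest


    class NeedsEnv:

        def __new__(cls, var: str):
            # returning the mark, not a NeedsEnv instance, makes
            # @NeedsEnv(...) behave exactly like a native pytest mark
            return pytest.mark.skipif(
                condition=var not in os.environ,
                reason=f"Requires: [env {var}]",
            )


    @NeedsEnv("MY_OPTIONAL_VAR")  # applied directly, like @SkipIf above
    def test_env_dependent():
        assert os.environ["MY_OPTIONAL_VAR"] != ""

Since only the reasons whose conditions hold are kept, a combined call such as SkipIf(min_gpus=2, min_torch="1.8") (illustrative values) reports an aggregated reason like Requires: [GPUs>=2 + torch>=1.8] when both requirements are unmet.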