# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from typing import Optional

import pytest
import torch
from packaging.version import Version
from pkg_resources import get_distribution


# TODO(lite): Add all RunIf conditions once the relevant utilities have moved to lite source dir
class RunIf:
    """RunIf wrapper for simple marking specific cases, fully compatible with pytest.mark::

    @RunIf(min_torch="0.0")
    @pytest.mark.parametrize("arg1", [1, 2.0])
    def test_wrapper(arg1):
        assert arg1 > 0.0
    """

    def __new__(
        # NOTE: `__new__` is an implicit static method that receives the class,
        # not an instance — `cls` (not `self`) is the correct parameter name.
        cls,
        *args,
        min_cuda_gpus: int = 0,
        min_torch: Optional[str] = None,
        max_torch: Optional[str] = None,
        min_python: Optional[str] = None,
        skip_windows: bool = False,
        standalone: bool = False,
        **kwargs,
    ):
        """Build and return a :func:`pytest.mark.skipif` marker from the requested conditions.

        Args:
            *args: Any :class:`pytest.mark.skipif` arguments.
            min_cuda_gpus: Require this number of gpus and that the ``PL_RUN_CUDA_TESTS=1`` environment variable
                is set.
            min_torch: Require that PyTorch is greater or equal than this version.
            max_torch: Require that PyTorch is less than this version.
            min_python: Require that Python is greater or equal than this version.
            skip_windows: Skip for Windows platform.
            standalone: Mark the test as standalone, our CI will run it in a separate process.
                This requires that the ``PL_RUN_STANDALONE_TESTS=1`` environment variable is set.
            **kwargs: Any :class:`pytest.mark.skipif` keyword arguments.
        """
        # Each requested requirement appends one (condition, reason) pair;
        # a True condition means the requirement is NOT met and the test skips.
        conditions = []
        reasons = []

        if min_cuda_gpus:
            conditions.append(torch.cuda.device_count() < min_cuda_gpus)
            reasons.append(f"GPUs>={min_cuda_gpus}")
            # used in conftest.py::pytest_collection_modifyitems
            kwargs["min_cuda_gpus"] = True

        if min_torch:
            torch_version = get_distribution("torch").version
            conditions.append(Version(torch_version) < Version(min_torch))
            reasons.append(f"torch>={min_torch}")

        if max_torch:
            torch_version = get_distribution("torch").version
            conditions.append(Version(torch_version) >= Version(max_torch))
            reasons.append(f"torch<{max_torch}")

        if min_python:
            py_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
            conditions.append(Version(py_version) < Version(min_python))
            reasons.append(f"python>={min_python}")

        if skip_windows:
            conditions.append(sys.platform == "win32")
            reasons.append("unimplemented on Windows")

        if standalone:
            env_flag = os.getenv("PL_RUN_STANDALONE_TESTS", "0")
            conditions.append(env_flag != "1")
            reasons.append("Standalone execution")
            # used in conftest.py::pytest_collection_modifyitems
            kwargs["standalone"] = True

        # Keep only the reasons whose condition actually triggered a skip, so the
        # skip message lists the unmet requirements and nothing else.
        reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
        return pytest.mark.skipif(
            *args, condition=any(conditions), reason=f"Requires: [{' + '.join(reasons)}]", **kwargs
        )


@RunIf(min_torch="99")
def test_always_skip():
    # Deliberately unsatisfiable torch requirement: this body must never run.
    # If it did, exit(1) would abort the test session — proving the skip works.
    exit(1)


@pytest.mark.parametrize("arg1", [0.5, 1.0, 2.0])
@RunIf(min_torch="0.0")
def test_wrapper(arg1: float):
    # Trivially satisfiable requirement: the test runs and RunIf composes with parametrize.
    assert arg1 > 0.0