76 lines
2.5 KiB
Python
76 lines
2.5 KiB
Python
|
# Copyright The PyTorch Lightning team.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
import os
|
||
|
|
||
|
import pytest
|
||
|
import torch
|
||
|
|
||
|
|
||
|
# TODO(lite): Add all RunIf conditions once the relevant utilities have moved to lite source dir
|
||
|
class RunIf:
    """RunIf wrapper for simple marking specific cases, fully compatible with pytest.mark::

        @RunIf(min_torch="0.0")
        @pytest.mark.parametrize("arg1", [1, 2.0])
        def test_wrapper(arg1):
            assert arg1 > 0.0

    Instantiating the class does not produce a ``RunIf`` object: ``__new__`` returns a
    ``pytest.mark.skipif`` decorator whose condition is true when any requested requirement is unmet.
    """

    def __new__(
        cls,
        *args,
        min_cuda_gpus: int = 0,
        min_torch: str = "",
        standalone: bool = False,
        **kwargs,
    ):
        """Build the ``skipif`` marker for the requested requirements.

        Args:
            *args: Any :class:`pytest.mark.skipif` arguments.
            min_cuda_gpus: Require at least this number of CUDA devices, as reported by
                ``torch.cuda.device_count()``. The flag is also forwarded as a ``min_cuda_gpus=True`` marker
                keyword for ``conftest.py::pytest_collection_modifyitems``. The ``PL_RUN_CUDA_TESTS=1``
                environment variable is not checked here; presumably the conftest hook does that.
            min_torch: Require at least this PyTorch release, e.g. ``"1.12"``. Only the numeric release
                segment is compared; pre-release and local build suffixes are ignored.
            standalone: Mark the test as standalone, our CI will run it in a separate process.
                This requires that the ``PL_RUN_STANDALONE_TESTS=1`` environment variable is set.
            **kwargs: Any :class:`pytest.mark.skipif` keyword arguments.

        Returns:
            The ``pytest.mark.skipif`` decorator. Its reason lists only the requirements that are unmet.

        Raises:
            ValueError: If ``min_torch`` does not start with a numeric release segment.
        """
        # the two lists are kept index-aligned: reasons[i] describes conditions[i]
        conditions = []
        reasons = []

        if min_cuda_gpus:
            conditions.append(torch.cuda.device_count() < min_cuda_gpus)
            reasons.append(f"GPUs>={min_cuda_gpus}")
            # used in conftest.py::pytest_collection_modifyitems
            kwargs["min_cuda_gpus"] = True

        if min_torch:
            # previously this keyword fell through to ``**kwargs`` and never influenced the skip condition
            conditions.append(cls._version_lt(torch.__version__, min_torch))
            reasons.append(f"torch>={min_torch}")

        if standalone:
            env_flag = os.getenv("PL_RUN_STANDALONE_TESTS", "0")
            conditions.append(env_flag != "1")
            reasons.append("Standalone execution")
            # used in conftest.py::pytest_collection_modifyitems
            kwargs["standalone"] = True

        # report only the requirements that actually failed
        unmet = [reason for failed, reason in zip(conditions, reasons) if failed]
        return pytest.mark.skipif(
            *args, condition=any(conditions), reason=f"Requires: [{' + '.join(unmet)}]", **kwargs
        )

    @staticmethod
    def _release_tuple(version: str) -> tuple:
        """Return the leading numeric release segment of ``version`` as a tuple of ints.

        Local build tags (``+cu117``) and pre-release suffixes (``a0``, ``rc1``, ``.dev1``) are dropped, so
        ``"1.13.0a0+git123"`` becomes ``(1, 13, 0)``. An unparsable string yields ``()``.
        """
        release = []
        for chunk in str(version).split("+", 1)[0].split("."):
            digits = chunk[: len(chunk) - len(chunk.lstrip("0123456789"))]
            if not digits:
                break
            release.append(int(digits))
            if digits != chunk:
                # a suffix such as "0a0" or "0rc1" ends the numeric release segment
                break
        return tuple(release)

    @staticmethod
    def _version_lt(current: str, required: str) -> bool:
        """Return whether the release number of ``current`` is strictly lower than that of ``required``."""
        have = RunIf._release_tuple(current)
        need = RunIf._release_tuple(required)
        if not need:
            raise ValueError(f"Cannot parse a release number from the required version {required!r}")
        # pad with zeros so that "2.0" and "2.0.0" compare equal
        width = max(len(have), len(need))
        have += (0,) * (width - len(have))
        need += (0,) * (width - len(need))
        return have < need
|
||
|
|
||
|
|
||
|
@RunIf(min_torch="99")
def test_always_skip():
    """Sentinel test that should never execute; the ``RunIf`` marker above is meant to skip it.

    If the body is reached anyway, it raises ``SystemExit`` with exit status 1.
    """
    # raise directly instead of calling ``exit()``: that helper is injected by the ``site`` module for
    # interactive sessions and is missing when the interpreter runs without it (e.g. ``python -S``)
    raise SystemExit(1)
|
||
|
|
||
|
|
||
|
@pytest.mark.parametrize("arg1", [0.5, 1.0, 2.0])
@RunIf(min_torch="0.0")
def test_wrapper(arg1: float):
    """Check that a ``RunIf``-decorated test stacks with ``parametrize`` and receives each positive value."""
    assert 0.0 < arg1
|