"""Pytest hook implementations adding a ``spawn`` marker.

Tests marked with ``@pytest.mark.spawn`` are executed in a child process
started by :func:`torch.multiprocessing.spawn` instead of inside the pytest
process itself.
"""
from functools import wraps  # noqa: F401  (no longer used; kept to avoid removing a module-level name)
import pytest
import torch.multiprocessing as mp


def pytest_configure(config):
    """Register the ``spawn`` marker so pytest recognizes it."""
    config.addinivalue_line(
        "markers",
        "spawn: spawn test in a separate process using torch.multiprocessing.spawn",
    )


def _run_spawned_test(rank, testfunction, testkwargs):
    """Run one test function inside a process started by ``mp.spawn``.

    ``torch.multiprocessing.spawn`` invokes its target as ``fn(i, *args)``,
    where ``i`` is the process index; ``rank`` receives that index and is
    unused here. This must be a module-level function so that it can be
    pickled and sent to the child process.

    Args:
        rank: Process index supplied by ``mp.spawn`` (ignored).
        testfunction: The collected test callable to execute.
        testkwargs: Fixture values keyed by the test's argument names.
    """
    testfunction(**testkwargs)


@pytest.hookimpl(tryfirst=True)
def pytest_pyfunc_call(pyfuncitem):
    """Run ``spawn``-marked tests in a separate process.

    Args:
        pyfuncitem: The pytest ``Function`` item about to be called.

    Returns:
        ``True`` when the test was run in a spawned process. This tells pytest
        that the call has been handled, so its default implementation does not
        also run the test. ``None`` for unmarked tests, so pytest runs them
        normally.

    Any exception raised in the child process is re-raised by ``mp.spawn``
    here, which makes the test fail in the parent pytest process.
    """
    if pyfuncitem.get_closest_marker("spawn") is None:
        return None

    testfunction = pyfuncitem.obj
    funcargs = pyfuncitem.funcargs
    # Select fixture values the same way pytest's built-in pytest_pyfunc_call
    # does (this relies on the private ``_fixtureinfo`` attribute). The values
    # are pickled for the child process, so they must be picklable.
    testkwargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
    # The target must actually call the test. Passing ``functools.wraps`` here
    # would only build a partial in the child and the test would never run.
    mp.spawn(_run_spawned_test, (testfunction, testkwargs))
    return True