lightning/tests/conftest.py

23 lines
584 B
Python

import pytest
import torch.multiprocessing as mp
def pytest_configure(config):
config.addinivalue_line("markers", "spawn: spawn test in a separate process using torch.multiprocessing.spawn")
def wrap(i, fn, args):
return fn(*args)
@pytest.mark.tryfirst
def pytest_pyfunc_call(pyfuncitem):
if pyfuncitem.get_closest_marker("spawn"):
testfunction = pyfuncitem.obj
funcargs = pyfuncitem.funcargs
testargs = tuple([funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames])
mp.spawn(wrap, (testfunction, testargs))
return True