23 lines
584 B
Python
23 lines
584 B
Python
|
import pytest
|
||
|
|
||
|
import torch.multiprocessing as mp
|
||
|
|
||
|
|
||
|
def pytest_configure(config):
    """Register the custom ``spawn`` marker with pytest.

    Adding the marker to the ``markers`` ini value makes it show up in
    ``pytest --markers`` and keeps ``--strict-markers`` from rejecting it.

    Args:
        config: the pytest ``Config`` object passed to this hook.
    """
    spawn_marker_help = "spawn: spawn test in a separate process using torch.multiprocessing.spawn"
    config.addinivalue_line("markers", spawn_marker_help)
|
||
|
|
||
|
|
||
|
def wrap(i, fn, args):
    """Trampoline used as the ``mp.spawn`` target.

    ``torch.multiprocessing.spawn`` calls its target as ``target(i, *args)``,
    where ``i`` is the process index. This function ignores ``i`` and calls
    ``fn`` with the positional arguments in ``args``.

    Args:
        i: process index supplied by ``mp.spawn`` (unused).
        fn: the callable to run in the child process.
        args: iterable of positional arguments for ``fn``.

    Returns:
        Whatever ``fn`` returns.
    """
    outcome = fn(*args)
    return outcome
|
||
|
|
||
|
|
||
|
@pytest.hookimpl(tryfirst=True)
def pytest_pyfunc_call(pyfuncitem):
    """Run tests marked ``spawn`` in a child process via ``torch.multiprocessing.spawn``.

    For an item carrying the ``spawn`` marker, the test function, bound to its
    resolved fixture/parametrize values, is handed to ``mp.spawn`` through the
    module-level ``wrap`` trampoline. With ``mp.spawn``'s defaults (one
    process, joined before returning), an exception raised in the child is
    re-raised in this process, so the test fails as usual.

    Args:
        pyfuncitem: the pytest ``Function`` item being executed.

    Returns:
        ``True`` when the item was executed here. ``pytest_pyfunc_call`` is a
        first-result hook, so this stops pytest's default implementation from
        running the test again in-process. Returns ``None`` for unmarked items
        so the default implementation runs them.
    """
    # Local import keeps this hook self-contained.
    import functools

    if pyfuncitem.get_closest_marker("spawn") is None:
        return None

    testfunction = pyfuncitem.obj
    funcargs = pyfuncitem.funcargs
    # Same (private) attribute pytest's built-in pytest_pyfunc_call uses to
    # select the arguments the test function declares. Bind them by keyword,
    # as pytest does, so keyword-only test parameters work too; a partial of a
    # module-level function with picklable values is itself picklable for spawn.
    testkwargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
    bound_test = functools.partial(testfunction, **testkwargs)
    mp.spawn(wrap, (bound_test, ()))
    return True
|