21 lines
568 B
Python
21 lines
568 B
Python
from functools import wraps
|
|
|
|
import pytest
|
|
|
|
import torch.multiprocessing as mp
|
|
|
|
|
|
def pytest_configure(config):
    """Register the custom ``spawn`` marker with pytest.

    Declaring the marker lets pytest recognise ``@pytest.mark.spawn``
    (e.g. under ``--strict-markers``) and list it in ``pytest --markers``.

    Args:
        config: the pytest ``Config`` object passed to this hook.
    """
    spawn_marker_help = "spawn: spawn test in a separate process using torch.multiprocessing.spawn"
    config.addinivalue_line("markers", spawn_marker_help)
|
|
|
|
|
|
def _run_spawned_test(rank, fn, args):
    """Run ``fn(*args)``; used as the target process entry point for ``mp.spawn``.

    ``torch.multiprocessing.spawn`` calls its target as ``target(rank, *args)``,
    so the process index arrives first and is ignored. The test function is
    then invoked with its resolved fixture values. This must live at module
    level so it can be pickled for the child process.

    Args:
        rank: process index supplied by ``mp.spawn`` (unused).
        fn: the test function to execute.
        args: positional arguments (fixture values) for ``fn``.

    Returns:
        Whatever ``fn`` returns.
    """
    return fn(*args)


@pytest.hookimpl(tryfirst=True)
def pytest_pyfunc_call(pyfuncitem):
    """Run tests marked ``@pytest.mark.spawn`` in a freshly spawned process.

    Args:
        pyfuncitem: the pytest ``Function`` item about to be executed.

    Returns:
        ``True`` when the test was executed here, which tells pytest the call
        is handled. Returns ``None`` for unmarked tests so pytest's default
        implementation runs them as usual.
    """
    if pyfuncitem.get_closest_marker("spawn") is None:
        return None

    testfunction = pyfuncitem.obj
    funcargs = pyfuncitem.funcargs
    # Forward only the fixtures the test declares, in signature order.
    testargs = tuple(funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames)

    # Previously this passed ``functools.wraps`` as the target, which merely
    # built a partial and never ran the test. Use a real entry point instead.
    # NOTE(review): the test function and fixture values are sent to the child
    # process, so they presumably must be picklable; confirm for fixtures that
    # return live resources.
    # mp.spawn (join=True by default) waits for the child and re-raises its
    # failure here, so a failing test still fails the pytest run.
    mp.spawn(_run_spawned_test, (testfunction, testargs))
    return True
|