# lightning/tests/tests_lite/strategies/launchers/test_xla.py
#
# 38 lines
# 1.2 KiB
# Python

from unittest import mock
from unittest.mock import Mock
from tests_lite.helpers.runif import RunIf
from lightning_lite.strategies.launchers.xla import _XLALauncher
@RunIf(skip_windows=True)
def test_xla_launcher_default_start_method(xla_available):
    """Check that a launcher built with only a strategy uses the ``"fork"`` start method.

    ``xla_available`` is requested but never read in the body
    (presumably a fixture provided elsewhere -- confirm against conftest).
    """
    start_method = _XLALauncher(strategy=Mock())._start_method
    assert start_method == "fork"
@RunIf(skip_windows=True)
def test_xla_launcher_interactive_compatible(xla_available):
    """Check that a freshly built launcher reports ``is_interactive_compatible`` as truthy.

    ``xla_available`` is requested but never read in the body
    (presumably a fixture provided elsewhere -- confirm against conftest).
    """
    xla_launcher = _XLALauncher(strategy=Mock())
    is_compatible = xla_launcher.is_interactive_compatible
    assert is_compatible
@RunIf(skip_windows=True, tpu=True)
@mock.patch("torch_xla.distributed.xla_multiprocessing")
@mock.patch("lightning_lite.strategies.launchers.xla.get_context")
def test_xla_launcher_xmp_spawn(get_context_mock, xmp_mock):
    """Verify how ``launch`` drives the patched multiprocessing pieces.

    After ``launch`` runs, the test expects that:

    * ``get_context`` was called with ``"fork"``;
    * ``spawn`` on the patched ``xla_multiprocessing`` module received the
      launcher's ``_wrapping_function``, the packed call arguments plus the
      queue, the strategy's ``num_processes`` and ``start_method="fork"``;
    * the queue's ``get`` was called exactly once, with no arguments.
    """
    # Patch decorators are applied bottom-up, so the innermost patch
    # (``get_context``) is supplied as the first positional argument.
    mock_strategy = Mock()
    target_fn = Mock()
    xla_launcher = _XLALauncher(strategy=mock_strategy)

    xla_launcher.launch(target_fn, "positional-arg", keyword_arg=0)

    # The queue the launcher is expected to have created from the context.
    result_queue = get_context_mock.return_value.SimpleQueue.return_value
    expected_spawn_args = (
        target_fn,
        ("positional-arg",),
        {"keyword_arg": 0},
        result_queue,
    )

    get_context_mock.assert_called_with("fork")
    xmp_mock.spawn.assert_called_with(
        xla_launcher._wrapping_function,
        args=expected_spawn_args,
        nprocs=mock_strategy.num_processes,
        start_method="fork",
    )
    result_queue.get.assert_called_once_with()