Standalone Lite: Launchers (#14555)

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Adrian Wälchli 2022-09-12 16:15:42 +02:00 committed by GitHub
parent d8fe0cf9b5
commit 8f0a64dab6
17 changed files with 689 additions and 15 deletions

View File

@@ -0,0 +1,178 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass
from multiprocessing.queues import SimpleQueue
from typing import Any, Callable, Dict, Optional
import torch
import torch.backends.cudnn
import torch.multiprocessing as mp
from typing_extensions import Literal
from lightning_lite.strategies.launchers.base import _Launcher
from lightning_lite.utilities.apply_func import move_data_to_device
from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_11
from lightning_lite.utilities.seed import _collect_rng_states, _set_rng_states
class _MultiProcessingLauncher(_Launcher):
r"""Launches processes that run a given function in parallel, and joins them all at the end.
The main process in which this launcher is invoked creates N so-called worker processes (using
:func:`torch.multiprocessing.start_processes`) that run the given function.
Worker processes have a rank that ranges from 0 to N - 1.
Note:
- This launcher requires all objects to be pickleable.
- It is important that the entry point to the program/script is guarded by ``if __name__ == "__main__"``.
- With start method 'fork' the user must ensure that no CUDA context gets created in the main process before
the launcher is invoked. E.g., one should avoid creating cuda tensors or calling ``torch.cuda.*`` functions
before calling ``Trainer.fit``.
Args:
strategy: A reference to the strategy that is used together with this launcher.
start_method: The method used to start the processes.
- 'spawn': The default start method. Requires all objects to be pickleable.
- 'fork': Preferable for IPython/Jupyter environments where 'spawn' is not available. Not available on
the Windows platform, for example.
- 'forkserver': Alternative implementation to 'fork'.
"""
def __init__(
self,
# TODO(lite): Fix this type annotation once the strategy base class gets added to Lite
strategy: "Strategy", # type: ignore[name-defined] # noqa: F821
start_method: Literal["spawn", "fork", "forkserver"] = "spawn",
) -> None:
self._strategy = strategy
self._start_method = start_method
if start_method not in mp.get_all_start_methods():
raise ValueError(
f"The start method '{self._start_method}' is not available on this platform. Available methods are:"
f" {', '.join(mp.get_all_start_methods())}"
)
if start_method in ("fork", "forkserver") and _is_forking_disabled():
raise ValueError(
"Forking is disabled in this environment by `PL_DISABLE_FORKING=1`. Choose a different start method."
)
@property
def is_interactive_compatible(self) -> bool:
# The start method 'spawn' is not supported in interactive environments
# The start method 'fork' is the only one supported in Jupyter environments, with constraints around CUDA
# initialization. For more context, see https://github.com/Lightning-AI/lightning/issues/7550
return self._start_method == "fork"
def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
"""Launches processes that run the given function in parallel.
The function is allowed to have a return value. However, when all processes join, only the return value
of worker process 0 gets returned from this `launch` method in the main process.
Arguments:
function: The entry point for all launched processes.
*args: Optional positional arguments to be passed to the given function.
**kwargs: Optional keyword arguments to be passed to the given function.
"""
# The default cluster environment in Lightning chooses a random free port number
# This needs to be done in the main process here before starting processes to ensure each rank will connect
# through the same port
os.environ["MASTER_PORT"] = str(self._strategy.cluster_environment.main_port)
context = mp.get_context(self._start_method)
return_queue = context.SimpleQueue()
if self._start_method == "spawn":
global_states = _GlobalStateSnapshot.capture()
process_args = [function, args, kwargs, return_queue, global_states]
else:
process_args = [function, args, kwargs, return_queue]
mp.start_processes(
self._wrapping_function,
args=process_args,
nprocs=self._strategy.num_processes,
start_method=self._start_method,
)
return return_queue.get()
def _wrapping_function(
self,
process_idx: int,
function: Callable,
args: Any,
kwargs: Any,
return_queue: SimpleQueue,
global_states: Optional["_GlobalStateSnapshot"] = None,
) -> None:
if global_states:
global_states.restore()
# TODO(lite): Update worker setup once DDPSpawn strategy is in Lite
self._strategy._worker_setup(process_idx)
results = function(*args, **kwargs)
if self._strategy.local_rank == 0:
return_queue.put(move_data_to_device(results, "cpu"))
@dataclass
class _GlobalStateSnapshot:
"""Captures a hand-selected set of (global) variables in modules and provides a way to restore them.
It facilitates and encapsulates the transfer of globals like PyTorch's deterministic flags or random generator state
across process boundaries when launching processes with :func:`torch.multiprocessing.spawn`.
Example:
.. code-block:: python
# in main process
snapshot = _GlobalStateSnapshot.capture()
# in worker process
snapshot.restore()
"""
use_deterministic_algorithms: bool
use_deterministic_algorithms_warn_only: bool
cudnn_benchmark: bool
rng_states: Dict[str, Any]
@classmethod
def capture(cls) -> "_GlobalStateSnapshot":
"""Capture a few global states from torch, numpy, etc., that we want to restore in a spawned worker
process."""
warn_only = torch.is_deterministic_algorithms_warn_only_enabled() if _TORCH_GREATER_EQUAL_1_11 else False
return cls(
use_deterministic_algorithms=torch.are_deterministic_algorithms_enabled(),
use_deterministic_algorithms_warn_only=warn_only,
cudnn_benchmark=torch.backends.cudnn.benchmark,
rng_states=_collect_rng_states(),
)
def restore(self) -> None:
"""Restores all globals to the values captured in the :meth:`capture` method."""
if _TORCH_GREATER_EQUAL_1_11:
torch.use_deterministic_algorithms(
self.use_deterministic_algorithms, warn_only=self.use_deterministic_algorithms_warn_only
)
else:
torch.use_deterministic_algorithms(self.use_deterministic_algorithms)
torch.backends.cudnn.benchmark = self.cudnn_benchmark
_set_rng_states(self.rng_states)
def _is_forking_disabled() -> bool:
"""Returns whether forking is disabled through the environment variable ``PL_DISABLE_FORK``."""
return bool(int(os.environ.get("PL_DISABLE_FORK", "0")))
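
For reference, a minimal, self-contained usage sketch (not part of the diff): the stub classes below stand in for the real strategy and cluster environment, which only arrive in follow-up Lite PRs, so their names and attributes are illustrative assumptions rather than the actual API.

from lightning_lite.strategies.launchers.multiprocessing import _MultiProcessingLauncher


class _StubClusterEnvironment:
    main_port = 29500  # assumed to be a free port


class _StubStrategy:
    cluster_environment = _StubClusterEnvironment()
    num_processes = 2
    local_rank = 0  # simplification; a real strategy sets this per worker in _worker_setup

    def _worker_setup(self, process_idx: int) -> None:
        pass  # a real strategy would initialize the process group here


def work(base: int) -> int:
    # runs in every worker process; `launch` returns the result collected from rank 0
    return base + 1


if __name__ == "__main__":
    launcher = _MultiProcessingLauncher(strategy=_StubStrategy(), start_method="spawn")
    print(launcher.launch(work, 41))  # prints 42, returned from worker process 0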

View File

@@ -0,0 +1,167 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
from time import sleep
from typing import Any, Callable, Optional
import __main__
import numpy as np
from lightning_utilities.core.imports import RequirementCache
from lightning_lite.strategies.launchers.base import _Launcher
_HYDRA_AVAILABLE = RequirementCache("hydra")
class _SubprocessScriptLauncher(_Launcher):
r"""
A process launcher that invokes the current script as many times as desired in a single node.
This launcher needs to be invoked on each node.
In its default behavior, the main process in each node then spawns N-1 child processes via :func:`subprocess.Popen`,
where N is the number of devices (e.g. GPU) per node. It is very similar to how :mod:`torch.distributed.run`
launches processes.
For example, if the script gets invoked with the command
.. code-block:: bash
python train.py --devices 4
The launcher will create three additional subprocesses that get called like so:
.. code-block:: bash
LOCAL_RANK=1 python train.py --devices 4
LOCAL_RANK=2 python train.py --devices 4
LOCAL_RANK=3 python train.py --devices 4
It is implied that the main process which launched the others has ``LOCAL_RANK=0``.
Besides the local rank, the following other environment variables also get set, but unlike the local rank, these
get determined by the cluster environment:
1. `MASTER_ADDR`: The IP address of the main node.
2. `MASTER_PORT`: The port number of the main node through which all processes communicate.
3. `NODE_RANK`: The index of the node the current process is running on. Ranges from 0 to ``num_nodes - 1``.
4. `WORLD_SIZE`: The total number of processes across all nodes, i.e., ``num_processes * num_nodes``.
Arguments:
cluster_environment: A cluster environment that provides access to world size, node rank, etc.
num_processes: The number of processes to launch in the current node.
num_nodes: The total number of nodes that participate in this process group.
"""
def __init__(
self,
# TODO(lite): Update type annotation once ClusterEnvironment has moved to Lite
cluster_environment: "ClusterEnvironment", # type: ignore[name-defined] # noqa: F821
num_processes: int,
num_nodes: int,
) -> None:
super().__init__()
self.cluster_environment = cluster_environment
self.num_processes = num_processes
self.num_nodes = num_nodes
@property
def is_interactive_compatible(self) -> bool:
return False
def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
"""Creates new processes, then calls the given function.
Arguments:
function: A callback function to execute after all processes have been created.
It is up to the implementation of this function to synchronize the processes, e.g., with barriers.
*args: Optional positional arguments to be passed to the given function.
**kwargs: Optional keyword arguments to be passed to the given function.
"""
if not self.cluster_environment.creates_processes_externally:
self._call_children_scripts()
return function(*args, **kwargs)
def _call_children_scripts(self) -> None:
# bookkeeping of spawned processes
self._check_can_spawn_children()
# DDP Environment variables
os.environ["MASTER_ADDR"] = self.cluster_environment.main_address
os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)
# allow the user to pass the node rank
os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank())
os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank())
# Check if the current calling command looked like `python a/b/c.py` or `python -m a.b.c`
# See https://docs.python.org/3/reference/import.html#main-spec
if __main__.__spec__ is None: # pragma: no-cover
# Script called as `python a/b/c.py`
if _HYDRA_AVAILABLE:
# when the user is using Hydra, resolve the absolute path
from hydra.utils import to_absolute_path
to_abs_path = to_absolute_path
else:
to_abs_path = os.path.abspath
# pull out the commands used to run the script and resolve the absolute file path
command = sys.argv
try:
full_path = to_abs_path(command[0])
except Exception:
full_path = os.path.abspath(command[0])
command[0] = full_path
# use the same Python interpreter that is currently running
command = [sys.executable] + command
else: # Script called as `python -m a.b.c`
command = [sys.executable, "-m", __main__.__spec__.name] + sys.argv[1:]
os.environ["WORLD_SIZE"] = f"{self.num_processes * self.num_nodes}"
for local_rank in range(1, self.num_processes):
env_copy = os.environ.copy()
env_copy["LOCAL_RANK"] = f"{local_rank}"
# remove env var if global seed not set
if os.environ.get("PL_GLOBAL_SEED") is None and "PL_GLOBAL_SEED" in env_copy:
del env_copy["PL_GLOBAL_SEED"]
# start process
# if hydra is available and initialized, make sure to set the cwd correctly
cwd: Optional[str] = None
if _HYDRA_AVAILABLE:
from hydra.core.hydra_config import HydraConfig
from hydra.utils import get_original_cwd
if HydraConfig.initialized():
cwd = get_original_cwd()
os_cwd = f'"{os.getcwd()}"'
command += [f"hydra.run.dir={os_cwd}", f"hydra.job.name=train_ddp_process_{local_rank}"]
subprocess.Popen(command, env=env_copy, cwd=cwd)
# starting all processes at once can cause issues with dataloaders;
# stagger the launches with a random 1-5 second delay
delay = np.random.uniform(1, 5, 1)[0]
sleep(delay)
def _check_can_spawn_children(self) -> None:
if self.cluster_environment.local_rank() != 0:
raise RuntimeError(
"Lightning attempted to launch new distributed processes with `local_rank > 0`. This should not happen."
" Possible reasons: 1) LOCAL_RANK environment variable was incorrectly modified by the user,"
" 2) `ClusterEnvironment.creates_processes_externally` incorrectly implemented."
)
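
For reference, a minimal sketch (not part of the diff) of the command construction described above. It shows how the ``__main__.__spec__`` check distinguishes a ``python path/to/script.py`` invocation from ``python -m package.module`` when re-invoking the current program; the helper name ``_child_command`` is hypothetical.

import os
import sys

import __main__


def _child_command() -> list:
    # hypothetical helper mirroring the branch in _call_children_scripts above
    if __main__.__spec__ is None:
        # invoked as `python path/to/train.py ...`
        return [sys.executable, os.path.abspath(sys.argv[0])] + sys.argv[1:]
    # invoked as `python -m package.train ...`
    return [sys.executable, "-m", __main__.__spec__.name] + sys.argv[1:]


if __name__ == "__main__":
    print(_child_command())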

View File

@@ -0,0 +1,121 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from functools import wraps
from multiprocessing.queues import SimpleQueue
from typing import Any, Callable, Optional, Tuple, TYPE_CHECKING
import torch.multiprocessing as mp
from torch.multiprocessing import ProcessContext
from lightning_lite.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher
from lightning_lite.utilities import _TPU_AVAILABLE
from lightning_lite.utilities.apply_func import move_data_to_device
if _TPU_AVAILABLE:
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
else:
xm, xmp = None, None
if TYPE_CHECKING:
from lightning_lite.strategies import Strategy
class _XLALauncher(_MultiProcessingLauncher):
r"""Launches processes that run a given function in parallel on XLA supported hardware, and joins them all at the
end.
The main process in which this launcher is invoked creates N so-called worker processes (using
``torch_xla``'s :func:`xmp.spawn`) that run the given function.
Worker processes have a rank that ranges from 0 to N - 1.
Note:
- This launcher requires all objects to be pickleable.
- It is important that the entry point to the program/script is guarded by ``if __name__ == "__main__"``.
Args:
strategy: A reference to the strategy that is used together with this launcher.
"""
def __init__(self, strategy: "Strategy") -> None:
super().__init__(strategy=strategy, start_method="fork")
@property
def is_interactive_compatible(self) -> bool:
return True
def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
"""Launches processes that run the given function in parallel.
The function is allowed to have a return value. However, when all processes join, only the return value
of worker process 0 gets returned from this `launch` method in the main process.
Arguments:
function: The entry point for all launched processes.
*args: Optional positional arguments to be passed to the given function.
**kwargs: Optional keyword arguments to be passed to the given function.
"""
context = mp.get_context(self._start_method)
return_queue = context.SimpleQueue()
_save_spawn(
self._wrapping_function,
args=(function, args, kwargs, return_queue),
nprocs=len(self._strategy.parallel_devices),
start_method=self._start_method,
)
return return_queue.get()
def _wrapping_function(
self,
process_idx: int,
function: Callable,
args: Any,
kwargs: Any,
return_queue: SimpleQueue,
global_states: Optional[_GlobalStateSnapshot] = None,
) -> None:
# TODO(lite): Update worker setup once TPUSpawn strategy is in Lite
self._strategy._worker_setup(process_idx)
results = function(*args, **kwargs)
if self._strategy.local_rank == 0:
return_queue.put(move_data_to_device(results, "cpu"))
def _save_spawn(
fn: Callable,
args: Tuple = (),
nprocs: Optional[int] = None,
join: bool = True,
daemon: bool = False,
start_method: str = "spawn",
) -> Optional[ProcessContext]:
"""Wraps the :func:`torch_xla.distributed.xla_multiprocessing.spawn` with added teardown logic for the worker
processes."""
@wraps(fn)
def wrapped(rank: int, *_args: Any) -> None:
fn(rank, *_args)
# Make all processes wait for each other before joining
# https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
xm.rendezvous("end-process")
# Ensure that the rank 0 process is the one exiting last
# https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
if rank == 0:
time.sleep(1)
return xmp.spawn(wrapped, args=args, nprocs=nprocs, join=join, daemon=daemon, start_method=start_method)

View File

@@ -1,13 +1,10 @@
import multiprocessing
import os
from typing import Any, List, MutableSequence, Optional, Tuple, Union
import torch
from lightning_lite.plugins.environments.torchelastic_environment import TorchElasticEnvironment
# TODO(lite): Fix the imports
# from lightning_lite.strategies.launchers.multiprocessing import _is_forking_disabled
from lightning_lite.strategies.launchers.multiprocessing import _is_forking_disabled
from lightning_lite.utilities.exceptions import MisconfigurationException
from lightning_lite.utilities.types import _DEVICE
@@ -309,9 +306,3 @@ def is_cuda_available() -> bool:
return torch.cuda.is_available()
with multiprocessing.get_context("fork").Pool(1) as pool:
return pool.apply(torch.cuda.is_available)
# TODO(lite): move this back to launchers/multiprocessing.py once launchers have moved
def _is_forking_disabled() -> bool:
"""Returns whether forking is disabled through the environment variable ``PL_DISABLE_FORK``."""
return bool(int(os.environ.get("PL_DISABLE_FORK", "0")))
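
For reference, a minimal sketch (not part of the diff) of the fork-pool trick visible in the hunk above: the CUDA availability query runs in a short-lived forked child so the parent process never initializes a CUDA context (which would interfere with a later 'fork' launch). The helper name is hypothetical, and 'fork' must be available on the platform (it is not on Windows).

import multiprocessing

import torch


def _cuda_available_in_forked_child() -> bool:
    # hypothetical helper: query CUDA availability in a forked child so the
    # parent process itself never touches the CUDA runtime
    with multiprocessing.get_context("fork").Pool(1) as pool:
        return pool.apply(torch.cuda.is_available)


if __name__ == "__main__":
    print(_cuda_available_in_forked_child())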

View File

@@ -11,13 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.strategies.launchers.base import _Launcher
from pytorch_lightning.strategies.launchers.multiprocessing import _MultiProcessingLauncher
from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
from pytorch_lightning.strategies.launchers.xla import _XLALauncher
__all__ = [
"_Launcher",
"_MultiProcessingLauncher",
"_SubprocessScriptLauncher",
"_XLALauncher",

View File

@@ -26,10 +26,10 @@ from torch import Tensor
from typing_extensions import Literal
import pytorch_lightning as pl
from lightning_lite.strategies.launchers.base import _Launcher
from lightning_lite.utilities.apply_func import move_data_to_device
from lightning_lite.utilities.seed import _collect_rng_states, _set_rng_states
from lightning_lite.utilities.types import _PATH
from pytorch_lightning.strategies.launchers.base import _Launcher
from pytorch_lightning.trainer.states import TrainerFn, TrainerState
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11
from pytorch_lightning.utilities.rank_zero import rank_zero_debug

View File

@@ -23,7 +23,7 @@ from lightning_utilities.core.imports import RequirementCache
import pytorch_lightning as pl
from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.strategies.launchers.base import _Launcher
from lightning_lite.strategies.launchers.base import _Launcher
_HYDRA_AVAILABLE = RequirementCache("hydra")

View File

@@ -24,6 +24,7 @@ from torch.utils.data import DataLoader
import pytorch_lightning as pl
from lightning_lite.plugins.io.checkpoint_plugin import CheckpointIO
from lightning_lite.strategies.launchers.base import _Launcher
from lightning_lite.utilities.apply_func import move_data_to_device
from lightning_lite.utilities.distributed import ReduceOp
from lightning_lite.utilities.optimizer import optimizer_to_device, optimizers_to_device
@@ -32,7 +33,6 @@ from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers,
from pytorch_lightning.plugins import TorchCheckpointIO
from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.launchers.base import _Launcher
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.types import (
LRSchedulerConfig,

View File

@@ -76,6 +76,13 @@ def teardown_process_group():
torch.distributed.destroy_process_group()
@pytest.fixture
def reset_deterministic_algorithm():
"""Ensures that torch determinism settings are reset before the next test runs."""
yield
torch.use_deterministic_algorithms(False)
@pytest.fixture
def caplog(caplog):
"""Workaround for https://github.com/pytest-dev/pytest/issues/3697.

View File

View File

@@ -0,0 +1,95 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest import mock
from unittest.mock import ANY, Mock
import pytest
import torch
from lightning_lite.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher
from tests_pytorch.helpers.runif import RunIf
@RunIf(skip_windows=True)
@pytest.mark.parametrize("start_method", ["fork", "forkserver"])
def test_multiprocessing_launcher_interactive_compatible(start_method):
launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
assert launcher.is_interactive_compatible == (start_method == "fork")
@mock.patch("lightning_lite.strategies.launchers.multiprocessing.mp.get_all_start_methods", return_value=[])
def test_multiprocessing_launcher_forking_on_unsupported_platform(_):
with pytest.raises(ValueError, match="The start method 'fork' is not available on this platform"):
_MultiProcessingLauncher(strategy=Mock(), start_method="fork")
@RunIf(skip_windows=True)
@pytest.mark.parametrize("start_method", ["fork", "forkserver"])
@mock.patch.dict(os.environ, {"PL_DISABLE_FORK": "1"}, clear=True)
def test_multiprocessing_launcher_disabled_forking(start_method):
with pytest.raises(ValueError, match="Forking is disabled in this environment"):
_MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
@pytest.mark.parametrize("start_method", ["spawn", "fork"])
@mock.patch("lightning_lite.strategies.launchers.multiprocessing.mp")
def test_multiprocessing_launcher_start_method(mp_mock, start_method):
mp_mock.get_all_start_methods.return_value = [start_method]
launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
launcher.launch(function=Mock())
mp_mock.get_context.assert_called_with(start_method)
mp_mock.start_processes.assert_called_with(
ANY,
args=ANY,
nprocs=ANY,
start_method=start_method,
)
@pytest.mark.parametrize("start_method", ["spawn", "fork"])
@mock.patch("lightning_lite.strategies.launchers.multiprocessing.mp")
def test_multiprocessing_launcher_restore_globals(mp_mock, start_method):
"""Test that we pass the global state snapshot to the worker function only if we are starting with 'spawn'."""
mp_mock.get_all_start_methods.return_value = [start_method]
launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
launcher.launch(function=Mock())
function_args = mp_mock.start_processes.call_args[1]["args"]
if start_method == "spawn":
assert len(function_args) == 5
assert isinstance(function_args[4], _GlobalStateSnapshot)
else:
assert len(function_args) == 4
@pytest.mark.usefixtures("reset_deterministic_algorithm")
def test_global_state_snapshot():
"""Test the capture() and restore() methods for the global state snapshot."""
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = False
torch.manual_seed(123)
# capture the state of globals
snapshot = _GlobalStateSnapshot.capture()
# simulate there is a process boundary and flags get reset here
torch.use_deterministic_algorithms(False)
torch.backends.cudnn.benchmark = True
torch.manual_seed(321)
# restore the state of globals
snapshot.restore()
assert torch.are_deterministic_algorithms_enabled()
assert not torch.backends.cudnn.benchmark
assert torch.initial_seed() == 123

View File

@@ -0,0 +1,78 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest import mock
from unittest.mock import Mock
import pytest
from lightning_lite.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
def test_subprocess_script_launcher_interactive_compatible():
launcher = _SubprocessScriptLauncher(Mock(), num_processes=2, num_nodes=1)
assert not launcher.is_interactive_compatible
@mock.patch("lightning_lite.strategies.launchers.subprocess_script.subprocess.Popen")
def test_subprocess_script_launcher_error_launching_on_non_zero_rank(popen_mock):
cluster_env = Mock()
cluster_env.creates_processes_externally = False
cluster_env.local_rank.return_value = 1
launcher = _SubprocessScriptLauncher(cluster_env, num_processes=2, num_nodes=1)
with pytest.raises(RuntimeError, match="attempted to launch new distributed processes with `local_rank > 0`"):
launcher.launch(Mock())
@mock.patch("lightning_lite.strategies.launchers.subprocess_script.subprocess.Popen")
def test_subprocess_script_launcher_external_processes(popen_mock):
cluster_env = Mock()
cluster_env.creates_processes_externally = True
function = Mock()
launcher = _SubprocessScriptLauncher(cluster_env, num_processes=4, num_nodes=2)
launcher.launch(function, "positional-arg", keyword_arg=0)
function.assert_called_with("positional-arg", keyword_arg=0)
popen_mock.assert_not_called()
@mock.patch("lightning_lite.strategies.launchers.subprocess_script.sleep")
@mock.patch("lightning_lite.strategies.launchers.subprocess_script.subprocess.Popen")
def test_subprocess_script_launcher_launch_processes(popen_mock, _):
cluster_env = Mock()
cluster_env.creates_processes_externally = False
cluster_env.local_rank.return_value = 0
cluster_env.main_address = "address"
cluster_env.main_port = 1234
function = Mock()
launcher = _SubprocessScriptLauncher(cluster_env, num_processes=4, num_nodes=2)
num_new_processes = launcher.num_processes - 1
# launches n-1 new processes, the current one will participate too
launcher.launch(function, "positional-arg", keyword_arg=0)
calls = popen_mock.call_args_list
assert len(calls) == num_new_processes
# world size in child processes
world_sizes = [int(calls[i][1]["env"]["WORLD_SIZE"]) for i in range(num_new_processes)]
assert world_sizes == [launcher.num_processes * launcher.num_nodes] * num_new_processes
# local rank in child processes
local_ranks = [int(calls[i][1]["env"]["LOCAL_RANK"]) for i in range(num_new_processes)]
assert local_ranks == list(range(1, num_new_processes + 1))
# the current process
assert int(os.environ["WORLD_SIZE"]) == launcher.num_processes * launcher.num_nodes
assert int(os.environ["LOCAL_RANK"]) == 0

View File

@@ -0,0 +1,39 @@
from unittest import mock
from unittest.mock import ANY, Mock
from tests_lite.helpers.runif import RunIf
from lightning_lite.strategies.launchers.xla import _XLALauncher
@RunIf(skip_windows=True)
def test_xla_launcher_default_start_method():
launcher = _XLALauncher(strategy=Mock())
assert launcher._start_method == "fork"
@RunIf(skip_windows=True)
def test_xla_launcher_interactive_compatible():
launcher = _XLALauncher(strategy=Mock())
assert launcher.is_interactive_compatible
@RunIf(skip_windows=True)
@mock.patch("lightning_lite.strategies.launchers.xla.mp")
@mock.patch("lightning_lite.strategies.launchers.xla.xm")
@mock.patch("lightning_lite.strategies.launchers.xla.xmp")
def test_xla_launcher_xmp_spawn(xmp_mock, xm_mock, mp_mock):
strategy = Mock()
strategy.parallel_devices = [0, 1, 2, 3]
launcher = _XLALauncher(strategy=strategy)
function = Mock()
launcher.launch(function, "positional-arg", keyword_arg=0)
# mp_mock.get_context.assert_called_with(start_method)
xmp_mock.spawn.assert_called_with(
ANY,
args=(function, ("positional-arg",), {"keyword_arg": 0}, ANY),
nprocs=4,
join=True,
daemon=False,
start_method="fork",
)