Standalone Lite: Launchers (#14555)
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in: parent d8fe0cf9b5, commit 8f0a64dab6

@@ -0,0 +1,178 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass
from multiprocessing.queues import SimpleQueue
from typing import Any, Callable, Dict, Optional

import torch
import torch.backends.cudnn
import torch.multiprocessing as mp
from typing_extensions import Literal

from lightning_lite.strategies.launchers.base import _Launcher
from lightning_lite.utilities.apply_func import move_data_to_device
from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_11
from lightning_lite.utilities.seed import _collect_rng_states, _set_rng_states


class _MultiProcessingLauncher(_Launcher):
    r"""Launches processes that run a given function in parallel, and joins them all at the end.

    The main process in which this launcher is invoked creates N so-called worker processes (using
    :func:`torch.multiprocessing.start_processes`) that run the given function.
    Worker processes have a rank that ranges from 0 to N - 1.

    Note:
        - This launcher requires all objects to be pickleable.
        - It is important that the entry point to the program/script is guarded by ``if __name__ == "__main__"``.
        - With start method 'fork' the user must ensure that no CUDA context gets created in the main process before
          the launcher is invoked. E.g., one should avoid creating CUDA tensors or calling ``torch.cuda.*`` functions
          before calling ``Trainer.fit``.

    Args:
        strategy: A reference to the strategy that is used together with this launcher.
        start_method: The method used to start the processes.

            - 'spawn': The default start method. Requires all objects to be pickleable.
            - 'fork': Preferable for IPython/Jupyter environments where 'spawn' is not available. Not available on
              the Windows platform, for example.
            - 'forkserver': Alternative implementation to 'fork'.
    """

    def __init__(
        self,
        # TODO(lite): Fix this type annotation once the strategy base class gets added to Lite
        strategy: "Strategy",  # type: ignore[name-defined]  # noqa: F821
        start_method: Literal["spawn", "fork", "forkserver"] = "spawn",
    ) -> None:
        self._strategy = strategy
        self._start_method = start_method
        if start_method not in mp.get_all_start_methods():
            raise ValueError(
                f"The start method '{self._start_method}' is not available on this platform. Available methods are:"
                f" {', '.join(mp.get_all_start_methods())}"
            )
        if start_method in ("fork", "forkserver") and _is_forking_disabled():
            raise ValueError(
                "Forking is disabled in this environment by `PL_DISABLE_FORK=1`. Choose a different start method."
            )

    @property
    def is_interactive_compatible(self) -> bool:
        # The start method 'spawn' is not supported in interactive environments.
        # The start method 'fork' is the only one supported in Jupyter environments, with constraints around CUDA
        # initialization. For more context, see https://github.com/Lightning-AI/lightning/issues/7550
        return self._start_method == "fork"

    def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
        """Launches processes that run the given function in parallel.

        The function is allowed to have a return value. However, when all processes join, only the return value
        of worker process 0 gets returned from this `launch` method in the main process.

        Arguments:
            function: The entry point for all launched processes.
            *args: Optional positional arguments to be passed to the given function.
            **kwargs: Optional keyword arguments to be passed to the given function.
        """
        # The default cluster environment in Lightning chooses a random free port number.
        # This needs to be done in the main process here before starting processes to ensure each rank will connect
        # through the same port.
        os.environ["MASTER_PORT"] = str(self._strategy.cluster_environment.main_port)
        context = mp.get_context(self._start_method)
        return_queue = context.SimpleQueue()

        if self._start_method == "spawn":
            global_states = _GlobalStateSnapshot.capture()
            process_args = [function, args, kwargs, return_queue, global_states]
        else:
            process_args = [function, args, kwargs, return_queue]

        mp.start_processes(
            self._wrapping_function,
            args=process_args,
            nprocs=self._strategy.num_processes,
            start_method=self._start_method,
        )
        return return_queue.get()

    def _wrapping_function(
        self,
        process_idx: int,
        function: Callable,
        args: Any,
        kwargs: Any,
        return_queue: SimpleQueue,
        global_states: Optional["_GlobalStateSnapshot"] = None,
    ) -> None:
        if global_states:
            global_states.restore()
        # TODO(lite): Update worker setup once DDPSpawn strategy is in Lite
        self._strategy._worker_setup(process_idx)
        results = function(*args, **kwargs)

        if self._strategy.local_rank == 0:
            return_queue.put(move_data_to_device(results, "cpu"))


@dataclass
class _GlobalStateSnapshot:
    """Captures a hand-selected set of (global) variables in modules and provides a way to restore them.

    It facilitates and encapsulates the transfer of globals like PyTorch's deterministic flags or random generator
    state across process boundaries when launching processes with :func:`torch.multiprocessing.spawn`.

    Example:

        .. code-block:: python

            # in main process
            snapshot = _GlobalStateSnapshot.capture()

            # in worker process
            snapshot.restore()
    """

    use_deterministic_algorithms: bool
    use_deterministic_algorithms_warn_only: bool
    cudnn_benchmark: bool
    rng_states: Dict[str, Any]

    @classmethod
    def capture(cls) -> "_GlobalStateSnapshot":
        """Capture a few global states from torch, numpy, etc., that we want to restore in a spawned worker
        process."""
        warn_only = torch.is_deterministic_algorithms_warn_only_enabled() if _TORCH_GREATER_EQUAL_1_11 else False
        return cls(
            use_deterministic_algorithms=torch.are_deterministic_algorithms_enabled(),
            use_deterministic_algorithms_warn_only=warn_only,
            cudnn_benchmark=torch.backends.cudnn.benchmark,
            rng_states=_collect_rng_states(),
        )

    def restore(self) -> None:
        """Restores all globals to the values captured in the :meth:`capture` method."""
        if _TORCH_GREATER_EQUAL_1_11:
            torch.use_deterministic_algorithms(
                self.use_deterministic_algorithms, warn_only=self.use_deterministic_algorithms_warn_only
            )
        else:
            torch.use_deterministic_algorithms(self.use_deterministic_algorithms)
        torch.backends.cudnn.benchmark = self.cudnn_benchmark
        _set_rng_states(self.rng_states)


def _is_forking_disabled() -> bool:
    """Returns whether forking is disabled through the environment variable ``PL_DISABLE_FORK``."""
    return bool(int(os.environ.get("PL_DISABLE_FORK", "0")))
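
The launcher above follows a spawn-and-collect pattern: set `MASTER_PORT` in the main process, capture global state when using 'spawn', start N workers with `torch.multiprocessing.start_processes`, and hand rank 0's result back through a `SimpleQueue`. Below is a minimal, self-contained sketch of that return-queue pattern using plain `torch.multiprocessing`; the `train_fn`, its argument, and the world size of 2 are illustrative placeholders, not part of this commit.

# Minimal sketch of the spawn-and-return-queue pattern used above, with plain
# torch.multiprocessing and no Lightning strategy object. `train_fn` and the
# world size of 2 are illustrative placeholders.
import torch.multiprocessing as mp


def train_fn(x):
    return x * 2


def _worker(rank, fn, args, return_queue):
    result = fn(*args)
    if rank == 0:
        # only rank 0 reports its result back to the main process
        return_queue.put(result)


if __name__ == "__main__":  # guard required by the 'spawn' start method
    context = mp.get_context("spawn")
    return_queue = context.SimpleQueue()
    mp.start_processes(_worker, args=(train_fn, (21,), return_queue), nprocs=2, start_method="spawn")
    print(return_queue.get())  # prints 42, the return value of rank 0
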
@@ -0,0 +1,167 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
from time import sleep
from typing import Any, Callable, Optional

import __main__
import numpy as np
from lightning_utilities.core.imports import RequirementCache

from lightning_lite.strategies.launchers.base import _Launcher

_HYDRA_AVAILABLE = RequirementCache("hydra")


class _SubprocessScriptLauncher(_Launcher):
    r"""
    A process launcher that invokes the current script as many times as desired on a single node.

    This launcher needs to be invoked on each node.
    In its default behavior, the main process in each node then spawns N-1 child processes via
    :func:`subprocess.Popen`, where N is the number of devices (e.g. GPU) per node. It is very similar to how
    :mod:`torch.distributed.run` launches processes.

    For example, if the script gets invoked with the command

    .. code-block:: bash

        python train.py --devices 4

    the launcher will create three additional subprocesses that get called like so:

    .. code-block:: bash

        LOCAL_RANK=1 python train.py --devices 4
        LOCAL_RANK=2 python train.py --devices 4
        LOCAL_RANK=3 python train.py --devices 4

    It is implied that the main process which launched the others has ``LOCAL_RANK=0``.
    Besides the local rank, the following other environment variables also get set, but unlike the local rank, these
    get determined by the cluster environment:

    1. `MASTER_ADDR`: The IP address of the main node.
    2. `MASTER_PORT`: The port number of the main node through which all processes communicate.
    3. `NODE_RANK`: The index of the node the current process is running on. Ranges from 0 to ``num_nodes - 1``.
    4. `WORLD_SIZE`: The total number of processes across all nodes, i.e., ``num_processes * num_nodes``.

    Arguments:
        cluster_environment: A cluster environment that provides access to world size, node rank, etc.
        num_processes: The number of processes to launch in the current node.
        num_nodes: The total number of nodes that participate in this process group.
    """

    def __init__(
        self,
        # TODO(lite): Update type annotation once ClusterEnvironment has moved to Lite
        cluster_environment: "ClusterEnvironment",  # type: ignore[name-defined]  # noqa: F821
        num_processes: int,
        num_nodes: int,
    ) -> None:
        super().__init__()
        self.cluster_environment = cluster_environment
        self.num_processes = num_processes
        self.num_nodes = num_nodes

    @property
    def is_interactive_compatible(self) -> bool:
        return False

    def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
        """Creates new processes, then calls the given function.

        Arguments:
            function: A callback function to execute after all processes have been created.
                It is up to the implementation of this function to synchronize the processes, e.g., with barriers.
            *args: Optional positional arguments to be passed to the given function.
            **kwargs: Optional keyword arguments to be passed to the given function.
        """
        if not self.cluster_environment.creates_processes_externally:
            self._call_children_scripts()
        return function(*args, **kwargs)

    def _call_children_scripts(self) -> None:
        # bookkeeping of spawned processes
        self._check_can_spawn_children()

        # DDP environment variables
        os.environ["MASTER_ADDR"] = self.cluster_environment.main_address
        os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)

        # allow the user to pass the node rank
        os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank())
        os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank())

        # Check if the current calling command looked like `python a/b/c.py` or `python -m a.b.c`
        # See https://docs.python.org/3/reference/import.html#main-spec
        if __main__.__spec__ is None:  # pragma: no-cover
            # Script called as `python a/b/c.py`
            if _HYDRA_AVAILABLE:
                # when the user is using Hydra, resolve the absolute path
                from hydra.utils import to_absolute_path

                to_abs_path = to_absolute_path
            else:
                to_abs_path = os.path.abspath

            # pull out the commands used to run the script and resolve the absolute file path
            command = sys.argv
            try:
                full_path = to_abs_path(command[0])
            except Exception:
                full_path = os.path.abspath(command[0])

            command[0] = full_path
            # use the same Python interpreter as the current process to run the script
            command = [sys.executable] + command
        else:  # Script called as `python -m a.b.c`
            command = [sys.executable, "-m", __main__.__spec__.name] + sys.argv[1:]

        os.environ["WORLD_SIZE"] = f"{self.num_processes * self.num_nodes}"

        for local_rank in range(1, self.num_processes):
            env_copy = os.environ.copy()
            env_copy["LOCAL_RANK"] = f"{local_rank}"

            # remove env var if global seed not set
            if os.environ.get("PL_GLOBAL_SEED") is None and "PL_GLOBAL_SEED" in env_copy:
                del env_copy["PL_GLOBAL_SEED"]

            # start process
            # if Hydra is available and initialized, make sure to set the cwd correctly
            cwd: Optional[str] = None
            if _HYDRA_AVAILABLE:
                from hydra.core.hydra_config import HydraConfig
                from hydra.utils import get_original_cwd

                if HydraConfig.initialized():
                    cwd = get_original_cwd()
                    os_cwd = f'"{os.getcwd()}"'
                    command += [f"hydra.run.dir={os_cwd}", f"hydra.job.name=train_ddp_process_{local_rank}"]
            subprocess.Popen(command, env=env_copy, cwd=cwd)

            # starting all processes at once can cause issues with dataloaders;
            # stagger the launches with a random 1-5 second delay
            delay = np.random.uniform(1, 5, 1)[0]
            sleep(delay)

    def _check_can_spawn_children(self) -> None:
        if self.cluster_environment.local_rank() != 0:
            raise RuntimeError(
                "Lightning attempted to launch new distributed processes with `local_rank > 0`. This should not happen."
                " Possible reasons: 1) LOCAL_RANK environment variable was incorrectly modified by the user,"
                " 2) `ClusterEnvironment.creates_processes_externally` incorrectly implemented."
            )
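
`_call_children_scripts` above re-invokes the running script once per additional local rank and passes rank and rendezvous information through environment variables. The following is a rough standalone sketch of that re-launch idea; the address, port, and process count are made-up placeholders, and the Hydra and `python -m` handling shown above is omitted.

# Rough sketch of the script re-launch idea: the main process (LOCAL_RANK=0)
# spawns N-1 copies of the current script, each with its own LOCAL_RANK.
# The address/port values here are placeholders, not taken from a real cluster.
import os
import subprocess
import sys

num_processes = 4  # hypothetical --devices value

if __name__ == "__main__" and os.environ.get("LOCAL_RANK", "0") == "0":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    os.environ["WORLD_SIZE"] = str(num_processes)
    for local_rank in range(1, num_processes):
        env = os.environ.copy()
        env["LOCAL_RANK"] = str(local_rank)
        # re-run the same script with the same interpreter and arguments
        subprocess.Popen([sys.executable] + sys.argv, env=env)

# every process (parent and children) continues here with its own LOCAL_RANK
print("running with LOCAL_RANK =", os.environ.get("LOCAL_RANK", "0"))
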
@@ -0,0 +1,121 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from functools import wraps
from multiprocessing.queues import SimpleQueue
from typing import Any, Callable, Optional, Tuple, TYPE_CHECKING

import torch.multiprocessing as mp
from torch.multiprocessing import ProcessContext

from lightning_lite.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher
from lightning_lite.utilities import _TPU_AVAILABLE
from lightning_lite.utilities.apply_func import move_data_to_device

if _TPU_AVAILABLE:
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp
else:
    xm, xmp = None, None

if TYPE_CHECKING:
    from lightning_lite.strategies import Strategy


class _XLALauncher(_MultiProcessingLauncher):
    r"""Launches processes that run a given function in parallel on XLA-supported hardware, and joins them all at
    the end.

    The main process in which this launcher is invoked creates N so-called worker processes (using
    :func:`torch_xla.distributed.xla_multiprocessing.spawn`) that run the given function.
    Worker processes have a rank that ranges from 0 to N - 1.

    Note:
        - This launcher requires all objects to be pickleable.
        - It is important that the entry point to the program/script is guarded by ``if __name__ == "__main__"``.

    Args:
        strategy: A reference to the strategy that is used together with this launcher.
    """

    def __init__(self, strategy: "Strategy") -> None:
        super().__init__(strategy=strategy, start_method="fork")

    @property
    def is_interactive_compatible(self) -> bool:
        return True

    def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
        """Launches processes that run the given function in parallel.

        The function is allowed to have a return value. However, when all processes join, only the return value
        of worker process 0 gets returned from this `launch` method in the main process.

        Arguments:
            function: The entry point for all launched processes.
            *args: Optional positional arguments to be passed to the given function.
            **kwargs: Optional keyword arguments to be passed to the given function.
        """
        context = mp.get_context(self._start_method)
        return_queue = context.SimpleQueue()
        _save_spawn(
            self._wrapping_function,
            args=(function, args, kwargs, return_queue),
            nprocs=len(self._strategy.parallel_devices),
            start_method=self._start_method,
        )
        return return_queue.get()

    def _wrapping_function(
        self,
        process_idx: int,
        function: Callable,
        args: Any,
        kwargs: Any,
        return_queue: SimpleQueue,
        global_states: Optional[_GlobalStateSnapshot] = None,
    ) -> None:
        # TODO(lite): Update worker setup once TPUSpawn strategy is in Lite
        self._strategy._worker_setup(process_idx)
        results = function(*args, **kwargs)

        if self._strategy.local_rank == 0:
            return_queue.put(move_data_to_device(results, "cpu"))


def _save_spawn(
    fn: Callable,
    args: Tuple = (),
    nprocs: Optional[int] = None,
    join: bool = True,
    daemon: bool = False,
    start_method: str = "spawn",
) -> Optional[ProcessContext]:
    """Wraps :func:`torch_xla.distributed.xla_multiprocessing.spawn` with added teardown logic for the worker
    processes."""

    @wraps(fn)
    def wrapped(rank: int, *_args: Any) -> None:
        fn(rank, *_args)

        # Make all processes wait for each other before joining
        # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
        xm.rendezvous("end-process")

        # Ensure that the rank 0 process is the one exiting last
        # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
        if rank == 0:
            time.sleep(1)

    return xmp.spawn(wrapped, args=args, nprocs=nprocs, join=join, daemon=daemon, start_method=start_method)
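
`_save_spawn` above decorates the worker function with `functools.wraps` so that every worker reaches a rendezvous before joining and rank 0 exits last. A generic sketch of that wrap-with-teardown pattern follows; since `torch_xla` is assumed unavailable here, a `multiprocessing` barrier stands in for `xm.rendezvous("end-process")`, and the 'fork' start method mirrors the launcher's default (not available on Windows).

# Generic sketch of the wrap-with-teardown pattern from _save_spawn. A
# multiprocessing.Barrier stands in for xm.rendezvous("end-process") because
# torch_xla is assumed not to be installed; the worker body is illustrative.
import time
from functools import wraps

import torch.multiprocessing as mp


def _with_teardown(fn, barrier):
    @wraps(fn)
    def wrapped(rank, *args):
        fn(rank, *args)
        barrier.wait()  # all workers wait for each other before joining
        if rank == 0:
            time.sleep(1)  # let rank 0 be the last process to exit
    return wrapped


def worker(rank, message):
    print(f"rank {rank}: {message}")


if __name__ == "__main__":
    nprocs = 2
    ctx = mp.get_context("fork")  # 'fork' keeps the closure usable in workers (not available on Windows)
    barrier = ctx.Barrier(nprocs)
    mp.start_processes(_with_teardown(worker, barrier), args=("hello",), nprocs=nprocs, start_method="fork")
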
@@ -1,13 +1,10 @@
import multiprocessing
import os
from typing import Any, List, MutableSequence, Optional, Tuple, Union

import torch

from lightning_lite.plugins.environments.torchelastic_environment import TorchElasticEnvironment

# TODO(lite): Fix the imports
# from lightning_lite.strategies.launchers.multiprocessing import _is_forking_disabled
from lightning_lite.strategies.launchers.multiprocessing import _is_forking_disabled
from lightning_lite.utilities.exceptions import MisconfigurationException
from lightning_lite.utilities.types import _DEVICE


@@ -309,9 +306,3 @@ def is_cuda_available() -> bool:
    return torch.cuda.is_available()
    with multiprocessing.get_context("fork").Pool(1) as pool:
        return pool.apply(torch.cuda.is_available)


# TODO(lite): move this back to launchers/multiprocessing.py once launchers have moved
def _is_forking_disabled() -> bool:
    """Returns whether forking is disabled through the environment variable ``PL_DISABLE_FORK``."""
    return bool(int(os.environ.get("PL_DISABLE_FORK", "0")))
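
The second hunk above probes CUDA availability inside a short-lived forked worker so that the main process never initializes a CUDA context, which would otherwise break the 'fork' start method later. A self-contained sketch of the same trick, assuming a platform where the 'fork' context exists:

# Sketch of the forked CUDA-availability probe from the hunk above: the check
# runs in a throwaway child process, so the parent never creates a CUDA
# context. Assumes a platform where the "fork" start method exists.
import multiprocessing

import torch


def is_cuda_available() -> bool:
    with multiprocessing.get_context("fork").Pool(1) as pool:
        return pool.apply(torch.cuda.is_available)


if __name__ == "__main__":
    print("CUDA available:", is_cuda_available())
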
@@ -11,13 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.strategies.launchers.base import _Launcher
from pytorch_lightning.strategies.launchers.multiprocessing import _MultiProcessingLauncher
from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
from pytorch_lightning.strategies.launchers.xla import _XLALauncher

__all__ = [
    "_Launcher",
    "_MultiProcessingLauncher",
    "_SubprocessScriptLauncher",
    "_XLALauncher",
@@ -26,10 +26,10 @@ from torch import Tensor
from typing_extensions import Literal

import pytorch_lightning as pl
from lightning_lite.strategies.launchers.base import _Launcher
from lightning_lite.utilities.apply_func import move_data_to_device
from lightning_lite.utilities.seed import _collect_rng_states, _set_rng_states
from lightning_lite.utilities.types import _PATH
from pytorch_lightning.strategies.launchers.base import _Launcher
from pytorch_lightning.trainer.states import TrainerFn, TrainerState
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11
from pytorch_lightning.utilities.rank_zero import rank_zero_debug
@@ -23,7 +23,7 @@ from lightning_utilities.core.imports import RequirementCache

import pytorch_lightning as pl
from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.strategies.launchers.base import _Launcher
from lightning_lite.strategies.launchers.base import _Launcher

_HYDRA_AVAILABLE = RequirementCache("hydra")
@@ -24,6 +24,7 @@ from torch.utils.data import DataLoader

import pytorch_lightning as pl
from lightning_lite.plugins.io.checkpoint_plugin import CheckpointIO
from lightning_lite.strategies.launchers.base import _Launcher
from lightning_lite.utilities.apply_func import move_data_to_device
from lightning_lite.utilities.distributed import ReduceOp
from lightning_lite.utilities.optimizer import optimizer_to_device, optimizers_to_device

@@ -32,7 +33,6 @@ from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers,
from pytorch_lightning.plugins import TorchCheckpointIO
from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.launchers.base import _Launcher
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.types import (
    LRSchedulerConfig,
@@ -76,6 +76,13 @@ def teardown_process_group():
    torch.distributed.destroy_process_group()


@pytest.fixture
def reset_deterministic_algorithm():
    """Ensures that torch determinism settings are reset before the next test runs."""
    yield
    torch.use_deterministic_algorithms(False)


@pytest.fixture
def caplog(caplog):
    """Workaround for https://github.com/pytest-dev/pytest/issues/3697.
@@ -0,0 +1,95 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest import mock
from unittest.mock import ANY, Mock

import pytest
import torch

from lightning_lite.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher
from tests_pytorch.helpers.runif import RunIf


@RunIf(skip_windows=True)
@pytest.mark.parametrize("start_method", ["fork", "forkserver"])
def test_multiprocessing_launcher_interactive_compatible(start_method):
    launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
    assert launcher.is_interactive_compatible == (start_method == "fork")


@mock.patch("lightning_lite.strategies.launchers.multiprocessing.mp.get_all_start_methods", return_value=[])
def test_multiprocessing_launcher_forking_on_unsupported_platform(_):
    with pytest.raises(ValueError, match="The start method 'fork' is not available on this platform"):
        _MultiProcessingLauncher(strategy=Mock(), start_method="fork")


@RunIf(skip_windows=True)
@pytest.mark.parametrize("start_method", ["fork", "forkserver"])
@mock.patch.dict(os.environ, {"PL_DISABLE_FORK": "1"}, clear=True)
def test_multiprocessing_launcher_disabled_forking(start_method):
    with pytest.raises(ValueError, match="Forking is disabled in this environment"):
        _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)


@pytest.mark.parametrize("start_method", ["spawn", "fork"])
@mock.patch("lightning_lite.strategies.launchers.multiprocessing.mp")
def test_multiprocessing_launcher_start_method(mp_mock, start_method):
    mp_mock.get_all_start_methods.return_value = [start_method]
    launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
    launcher.launch(function=Mock())
    mp_mock.get_context.assert_called_with(start_method)
    mp_mock.start_processes.assert_called_with(
        ANY,
        args=ANY,
        nprocs=ANY,
        start_method=start_method,
    )


@pytest.mark.parametrize("start_method", ["spawn", "fork"])
@mock.patch("lightning_lite.strategies.launchers.multiprocessing.mp")
def test_multiprocessing_launcher_restore_globals(mp_mock, start_method):
    """Test that we pass the global state snapshot to the worker function only if we are starting with 'spawn'."""
    mp_mock.get_all_start_methods.return_value = [start_method]
    launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
    launcher.launch(function=Mock())
    function_args = mp_mock.start_processes.call_args[1]["args"]
    if start_method == "spawn":
        assert len(function_args) == 5
        assert isinstance(function_args[4], _GlobalStateSnapshot)
    else:
        assert len(function_args) == 4


@pytest.mark.usefixtures("reset_deterministic_algorithm")
def test_global_state_snapshot():
    """Test the capture() and restore() methods for the global state snapshot."""
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(123)

    # capture the state of globals
    snapshot = _GlobalStateSnapshot.capture()

    # simulate there is a process boundary and flags get reset here
    torch.use_deterministic_algorithms(False)
    torch.backends.cudnn.benchmark = True
    torch.manual_seed(321)

    # restore the state of globals
    snapshot.restore()
    assert torch.are_deterministic_algorithms_enabled()
    assert not torch.backends.cudnn.benchmark
    assert torch.initial_seed() == 123
@@ -0,0 +1,78 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest import mock
from unittest.mock import Mock

import pytest

from lightning_lite.strategies.launchers.subprocess_script import _SubprocessScriptLauncher


def test_subprocess_script_launcher_interactive_compatible():
    launcher = _SubprocessScriptLauncher(Mock(), num_processes=2, num_nodes=1)
    assert not launcher.is_interactive_compatible


@mock.patch("lightning_lite.strategies.launchers.subprocess_script.subprocess.Popen")
def test_subprocess_script_launcher_error_launching_on_non_zero_rank(popen_mock):
    cluster_env = Mock()
    cluster_env.creates_processes_externally = False
    cluster_env.local_rank.return_value = 1
    launcher = _SubprocessScriptLauncher(cluster_env, num_processes=2, num_nodes=1)
    with pytest.raises(RuntimeError, match="attempted to launch new distributed processes with `local_rank > 0`"):
        launcher.launch(Mock())


@mock.patch("lightning_lite.strategies.launchers.subprocess_script.subprocess.Popen")
def test_subprocess_script_launcher_external_processes(popen_mock):
    cluster_env = Mock()
    cluster_env.creates_processes_externally = True
    function = Mock()
    launcher = _SubprocessScriptLauncher(cluster_env, num_processes=4, num_nodes=2)
    launcher.launch(function, "positional-arg", keyword_arg=0)
    function.assert_called_with("positional-arg", keyword_arg=0)
    popen_mock.assert_not_called()


@mock.patch("lightning_lite.strategies.launchers.subprocess_script.sleep")
@mock.patch("lightning_lite.strategies.launchers.subprocess_script.subprocess.Popen")
def test_subprocess_script_launcher_launch_processes(popen_mock, _):
    cluster_env = Mock()
    cluster_env.creates_processes_externally = False
    cluster_env.local_rank.return_value = 0
    cluster_env.main_address = "address"
    cluster_env.main_port = 1234

    function = Mock()
    launcher = _SubprocessScriptLauncher(cluster_env, num_processes=4, num_nodes=2)
    num_new_processes = launcher.num_processes - 1

    # launches n-1 new processes, the current one will participate too
    launcher.launch(function, "positional-arg", keyword_arg=0)

    calls = popen_mock.call_args_list
    assert len(calls) == num_new_processes

    # world size in child processes
    world_sizes = [int(calls[i][1]["env"]["WORLD_SIZE"]) for i in range(num_new_processes)]
    assert world_sizes == [launcher.num_processes * launcher.num_nodes] * num_new_processes

    # local rank in child processes
    local_ranks = [int(calls[i][1]["env"]["LOCAL_RANK"]) for i in range(num_new_processes)]
    assert local_ranks == list(range(1, num_new_processes + 1))

    # the current process
    assert int(os.environ["WORLD_SIZE"]) == launcher.num_processes * launcher.num_nodes
    assert int(os.environ["LOCAL_RANK"]) == 0
@@ -0,0 +1,39 @@
from unittest import mock
from unittest.mock import ANY, Mock

from tests_lite.helpers.runif import RunIf

from lightning_lite.strategies.launchers.xla import _XLALauncher


@RunIf(skip_windows=True)
def test_xla_launcher_default_start_method():
    launcher = _XLALauncher(strategy=Mock())
    assert launcher._start_method == "fork"


@RunIf(skip_windows=True)
def test_xla_launcher_interactive_compatible():
    launcher = _XLALauncher(strategy=Mock())
    assert launcher.is_interactive_compatible


@RunIf(skip_windows=True)
@mock.patch("lightning_lite.strategies.launchers.xla.mp")
@mock.patch("lightning_lite.strategies.launchers.xla.xm")
@mock.patch("lightning_lite.strategies.launchers.xla.xmp")
def test_xla_launcher_xmp_spawn(xmp_mock, xm_mock, mp_mock):
    strategy = Mock()
    strategy.parallel_devices = [0, 1, 2, 3]
    launcher = _XLALauncher(strategy=strategy)
    function = Mock()
    launcher.launch(function, "positional-arg", keyword_arg=0)
    # mp_mock.get_context.assert_called_with(start_method)
    xmp_mock.spawn.assert_called_with(
        ANY,
        args=(function, ("positional-arg",), {"keyword_arg": 0}, ANY),
        nprocs=4,
        join=True,
        daemon=False,
        start_method="fork",
    )