From 8f0a64dab6413ab495e0edce47d81afc6f14060c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 12 Sep 2022 16:15:42 +0200 Subject: [PATCH] Standalone Lite: Launchers (#14555) Co-authored-by: Jirka Borovec --- src/lightning_lite/strategies/__init__.py | 0 .../strategies/launchers/__init__.py | 0 .../strategies/launchers/base.py | 0 .../strategies/launchers/multiprocessing.py | 178 ++++++++++++++++++ .../strategies/launchers/subprocess_script.py | 167 ++++++++++++++++ .../strategies/launchers/xla.py | 121 ++++++++++++ src/lightning_lite/utilities/device_parser.py | 11 +- .../strategies/launchers/__init__.py | 2 - .../strategies/launchers/multiprocessing.py | 2 +- .../strategies/launchers/subprocess_script.py | 2 +- src/pytorch_lightning/strategies/strategy.py | 2 +- tests/tests_lite/conftest.py | 7 + tests/tests_lite/strategies/__init__.py | 0 .../strategies/launchers/__init__.py | 0 .../launchers/test_multiprocessing.py | 95 ++++++++++ .../launchers/test_subprocess_script.py | 78 ++++++++ .../strategies/launchers/test_xla.py | 39 ++++ 17 files changed, 689 insertions(+), 15 deletions(-) create mode 100644 src/lightning_lite/strategies/__init__.py create mode 100644 src/lightning_lite/strategies/launchers/__init__.py rename src/{pytorch_lightning => lightning_lite}/strategies/launchers/base.py (100%) create mode 100644 src/lightning_lite/strategies/launchers/multiprocessing.py create mode 100644 src/lightning_lite/strategies/launchers/subprocess_script.py create mode 100644 src/lightning_lite/strategies/launchers/xla.py create mode 100644 tests/tests_lite/strategies/__init__.py create mode 100644 tests/tests_lite/strategies/launchers/__init__.py create mode 100644 tests/tests_lite/strategies/launchers/test_multiprocessing.py create mode 100644 tests/tests_lite/strategies/launchers/test_subprocess_script.py create mode 100644 tests/tests_lite/strategies/launchers/test_xla.py diff --git a/src/lightning_lite/strategies/__init__.py b/src/lightning_lite/strategies/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/lightning_lite/strategies/launchers/__init__.py b/src/lightning_lite/strategies/launchers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/pytorch_lightning/strategies/launchers/base.py b/src/lightning_lite/strategies/launchers/base.py similarity index 100% rename from src/pytorch_lightning/strategies/launchers/base.py rename to src/lightning_lite/strategies/launchers/base.py diff --git a/src/lightning_lite/strategies/launchers/multiprocessing.py b/src/lightning_lite/strategies/launchers/multiprocessing.py new file mode 100644 index 0000000000..fc6dd5025f --- /dev/null +++ b/src/lightning_lite/strategies/launchers/multiprocessing.py @@ -0,0 +1,178 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os
+from dataclasses import dataclass
+from multiprocessing.queues import SimpleQueue
+from typing import Any, Callable, Dict, Optional
+
+import torch
+import torch.backends.cudnn
+import torch.multiprocessing as mp
+from typing_extensions import Literal
+
+from lightning_lite.strategies.launchers.base import _Launcher
+from lightning_lite.utilities.apply_func import move_data_to_device
+from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_11
+from lightning_lite.utilities.seed import _collect_rng_states, _set_rng_states
+
+
+class _MultiProcessingLauncher(_Launcher):
+    r"""Launches processes that run a given function in parallel, and joins them all at the end.
+
+    The main process in which this launcher is invoked creates N so-called worker processes (using
+    :func:`torch.multiprocessing.start_processes`) that run the given function.
+    Worker processes have a rank that ranges from 0 to N - 1.
+
+    Note:
+        - This launcher requires all objects to be pickleable.
+        - It is important that the entry point to the program/script is guarded by ``if __name__ == "__main__"``.
+        - With start method 'fork' the user must ensure that no CUDA context gets created in the main process before
+          the launcher is invoked. E.g., one should avoid creating CUDA tensors or calling ``torch.cuda.*`` functions
+          before calling ``Trainer.fit``.
+
+    Args:
+        strategy: A reference to the strategy that is used together with this launcher.
+        start_method: The method used to start the processes.
+            - 'spawn': The default start method. Requires all objects to be pickleable.
+            - 'fork': Preferable for IPython/Jupyter environments where 'spawn' is not available. Not available on
+              the Windows platform, for example.
+            - 'forkserver': Alternative implementation to 'fork'.
+    """
+
+    def __init__(
+        self,
+        # TODO(lite): Fix this type annotation once the strategy base class gets added to Lite
+        strategy: "Strategy",  # type: ignore[name-defined]  # noqa: F821
+        start_method: Literal["spawn", "fork", "forkserver"] = "spawn",
+    ) -> None:
+        self._strategy = strategy
+        self._start_method = start_method
+        if start_method not in mp.get_all_start_methods():
+            raise ValueError(
+                f"The start method '{self._start_method}' is not available on this platform. Available methods are:"
+                f" {', '.join(mp.get_all_start_methods())}"
+            )
+        if start_method in ("fork", "forkserver") and _is_forking_disabled():
+            raise ValueError(
+                "Forking is disabled in this environment by `PL_DISABLE_FORK=1`. Choose a different start method."
+            )
+
+    @property
+    def is_interactive_compatible(self) -> bool:
+        # The start method 'spawn' is not supported in interactive environments.
+        # The start method 'fork' is the only one supported in Jupyter environments, with constraints around CUDA
+        # initialization. For more context, see https://github.com/Lightning-AI/lightning/issues/7550
+        return self._start_method == "fork"
+
+    def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
+        """Launches processes that run the given function in parallel.
+
+        The function is allowed to have a return value. However, when all processes join, only the return value
+        of worker process 0 gets returned from this `launch` method in the main process.
+
+        Arguments:
+            function: The entry point for all launched processes.
+            *args: Optional positional arguments to be passed to the given function.
+            **kwargs: Optional keyword arguments to be passed to the given function.
+ """ + # The default cluster environment in Lightning chooses a random free port number + # This needs to be done in the main process here before starting processes to ensure each rank will connect + # through the same port + os.environ["MASTER_PORT"] = str(self._strategy.cluster_environment.main_port) + context = mp.get_context(self._start_method) + return_queue = context.SimpleQueue() + + if self._start_method == "spawn": + global_states = _GlobalStateSnapshot.capture() + process_args = [function, args, kwargs, return_queue, global_states] + else: + process_args = [function, args, kwargs, return_queue] + + mp.start_processes( + self._wrapping_function, + args=process_args, + nprocs=self._strategy.num_processes, + start_method=self._start_method, + ) + return return_queue.get() + + def _wrapping_function( + self, + process_idx: int, + function: Callable, + args: Any, + kwargs: Any, + return_queue: SimpleQueue, + global_states: Optional["_GlobalStateSnapshot"] = None, + ) -> None: + if global_states: + global_states.restore() + # TODO(lite): Update worker setup once DDPSpawn strategy is in Lite + self._strategy._worker_setup(process_idx) + results = function(*args, **kwargs) + + if self._strategy.local_rank == 0: + return_queue.put(move_data_to_device(results, "cpu")) + + +@dataclass +class _GlobalStateSnapshot: + """Captures a hand-selected set of (global) variables in modules and provides a way to restore them. + + It facilitates and encapsulates the transfer of globals like PyTorch's deterministic flags or random generator state + across process boundaries when launching processes with :func:`torch.multiprocessing.spawn`. + + Example: + + .. code-block:: python + + # in main process + snapshot = _GlobalStateSnapshot.capture() + + # in worker process + snapshot.restore() + """ + + use_deterministic_algorithms: bool + use_deterministic_algorithms_warn_only: bool + cudnn_benchmark: bool + rng_states: Dict[str, Any] + + @classmethod + def capture(cls) -> "_GlobalStateSnapshot": + """Capture a few global states from torch, numpy, etc., that we want to restore in a spawned worker + process.""" + warn_only = torch.is_deterministic_algorithms_warn_only_enabled() if _TORCH_GREATER_EQUAL_1_11 else False + return cls( + use_deterministic_algorithms=torch.are_deterministic_algorithms_enabled(), + use_deterministic_algorithms_warn_only=warn_only, + cudnn_benchmark=torch.backends.cudnn.benchmark, + rng_states=_collect_rng_states(), + ) + + def restore(self) -> None: + """Restores all globals to the values captured in the :meth:`capture` method.""" + if _TORCH_GREATER_EQUAL_1_11: + torch.use_deterministic_algorithms( + self.use_deterministic_algorithms, warn_only=self.use_deterministic_algorithms_warn_only + ) + else: + torch.use_deterministic_algorithms(self.use_deterministic_algorithms) + torch.backends.cudnn.benchmark = self.cudnn_benchmark + _set_rng_states(self.rng_states) + + +def _is_forking_disabled() -> bool: + """Returns whether forking is disabled through the environment variable ``PL_DISABLE_FORK``.""" + return bool(int(os.environ.get("PL_DISABLE_FORK", "0"))) diff --git a/src/lightning_lite/strategies/launchers/subprocess_script.py b/src/lightning_lite/strategies/launchers/subprocess_script.py new file mode 100644 index 0000000000..7f814e01e2 --- /dev/null +++ b/src/lightning_lite/strategies/launchers/subprocess_script.py @@ -0,0 +1,167 @@ +# Copyright The PyTorch Lightning team. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import subprocess
+import sys
+from time import sleep
+from typing import Any, Callable, Optional
+
+import __main__
+import numpy as np
+from lightning_utilities.core.imports import RequirementCache
+
+from lightning_lite.strategies.launchers.base import _Launcher
+
+_HYDRA_AVAILABLE = RequirementCache("hydra")
+
+
+class _SubprocessScriptLauncher(_Launcher):
+    r"""
+    A process launcher that invokes the current script as many times as desired in a single node.
+
+    This launcher needs to be invoked on each node.
+    In its default behavior, the main process in each node then spawns N-1 child processes via :func:`subprocess.Popen`,
+    where N is the number of devices (e.g. GPU) per node. It is very similar to how :mod:`torch.distributed.run`
+    launches processes.
+
+    For example, if the script gets invoked with the command
+
+    .. code-block:: bash
+
+        python train.py --devices 4
+
+    The launcher will create three additional subprocesses that get called like so:
+
+    .. code-block:: bash
+
+        LOCAL_RANK=1 python train.py --devices 4
+        LOCAL_RANK=2 python train.py --devices 4
+        LOCAL_RANK=3 python train.py --devices 4
+
+    It is implied that the main process which launched the others has ``LOCAL_RANK=0``.
+    Besides the local rank, the following other environment variables also get set, but unlike the local rank, these
+    get determined by the cluster environment:
+
+    1. `MASTER_ADDR`: The IP address of the main node.
+    2. `MASTER_PORT`: The port number of the main node through which all processes communicate.
+    3. `NODE_RANK`: The index of the node the current process is running on. Ranges from 0 to ``num_nodes - 1``.
+    4. `WORLD_SIZE`: The total number of processes across all nodes, i.e., ``num_processes * num_nodes``.
+
+    Arguments:
+        cluster_environment: A cluster environment that provides access to world size, node rank, etc.
+        num_processes: The number of processes to launch in the current node.
+        num_nodes: The total number of nodes that participate in this process group.
+    """
+
+    def __init__(
+        self,
+        # TODO(lite): Update type annotation once ClusterEnvironment has moved to Lite
+        cluster_environment: "ClusterEnvironment",  # type: ignore[name-defined]  # noqa: F821
+        num_processes: int,
+        num_nodes: int,
+    ) -> None:
+        super().__init__()
+        self.cluster_environment = cluster_environment
+        self.num_processes = num_processes
+        self.num_nodes = num_nodes
+
+    @property
+    def is_interactive_compatible(self) -> bool:
+        return False
+
+    def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
+        """Creates new processes, then calls the given function.
+
+        Arguments:
+            function: A callback function to execute after all processes have been created.
+                It is up to the implementation of this function to synchronize the processes, e.g., with barriers.
+            *args: Optional positional arguments to be passed to the given function.
+            **kwargs: Optional keyword arguments to be passed to the given function.
+        """
+        if not self.cluster_environment.creates_processes_externally:
+            self._call_children_scripts()
+        return function(*args, **kwargs)
+
+    def _call_children_scripts(self) -> None:
+        # bookkeeping of spawned processes
+        self._check_can_spawn_children()
+
+        # DDP Environment variables
+        os.environ["MASTER_ADDR"] = self.cluster_environment.main_address
+        os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)
+
+        # allow the user to pass the node rank
+        os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank())
+        os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank())
+
+        # Check if the current calling command looked like `python a/b/c.py` or `python -m a.b.c`
+        # See https://docs.python.org/3/reference/import.html#main-spec
+        if __main__.__spec__ is None:  # pragma: no-cover
+            # Script called as `python a/b/c.py`
+            if _HYDRA_AVAILABLE:
+                # when the user is using Hydra, find the absolute path
+                from hydra.utils import to_absolute_path
+
+                to_abs_path = to_absolute_path
+            else:
+                to_abs_path = os.path.abspath
+
+            # pull out the commands used to run the script and resolve the absolute file path
+            command = sys.argv
+            try:
+                full_path = to_abs_path(command[0])
+            except Exception:
+                full_path = os.path.abspath(command[0])
+
+            command[0] = full_path
+            # use the same Python interpreter that is running this script
+            command = [sys.executable] + command
+        else:  # Script called as `python -m a.b.c`
+            command = [sys.executable, "-m", __main__.__spec__.name] + sys.argv[1:]
+
+        os.environ["WORLD_SIZE"] = f"{self.num_processes * self.num_nodes}"
+
+        for local_rank in range(1, self.num_processes):
+            env_copy = os.environ.copy()
+            env_copy["LOCAL_RANK"] = f"{local_rank}"
+
+            # remove env var if global seed not set
+            if os.environ.get("PL_GLOBAL_SEED") is None and "PL_GLOBAL_SEED" in env_copy:
+                del env_copy["PL_GLOBAL_SEED"]
+
+            # start process
+            # if hydra is available and initialized, make sure to set the cwd correctly
+            cwd: Optional[str] = None
+            if _HYDRA_AVAILABLE:
+                from hydra.core.hydra_config import HydraConfig
+                from hydra.utils import get_original_cwd
+
+                if HydraConfig.initialized():
+                    cwd = get_original_cwd()
+                    os_cwd = f'"{os.getcwd()}"'
+                    command += [f"hydra.run.dir={os_cwd}", f"hydra.job.name=train_ddp_process_{local_rank}"]
+            subprocess.Popen(command, env=env_copy, cwd=cwd)
+
+            # starting all processes at once can cause issues with dataloaders;
+            # stagger the launches with a random delay between 1 and 5 seconds
+            delay = np.random.uniform(1, 5, 1)[0]
+            sleep(delay)
+
+    def _check_can_spawn_children(self) -> None:
+        if self.cluster_environment.local_rank() != 0:
+            raise RuntimeError(
+                "Lightning attempted to launch new distributed processes with `local_rank > 0`. This should not happen."
+                " Possible reasons: 1) LOCAL_RANK environment variable was incorrectly modified by the user,"
+                " 2) `ClusterEnvironment.creates_processes_externally` incorrectly implemented."
+            )
diff --git a/src/lightning_lite/strategies/launchers/xla.py b/src/lightning_lite/strategies/launchers/xla.py
new file mode 100644
index 0000000000..6580fd4a01
--- /dev/null
+++ b/src/lightning_lite/strategies/launchers/xla.py
@@ -0,0 +1,121 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import time +from functools import wraps +from multiprocessing.queues import SimpleQueue +from typing import Any, Callable, Optional, Tuple, TYPE_CHECKING + +import torch.multiprocessing as mp +from torch.multiprocessing import ProcessContext + +from lightning_lite.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher +from lightning_lite.utilities import _TPU_AVAILABLE +from lightning_lite.utilities.apply_func import move_data_to_device + +if _TPU_AVAILABLE: + import torch_xla.core.xla_model as xm + import torch_xla.distributed.xla_multiprocessing as xmp +else: + xm, xmp = None, None + +if TYPE_CHECKING: + from lightning_lite.strategies import Strategy + + +class _XLALauncher(_MultiProcessingLauncher): + r"""Launches processes that run a given function in parallel on XLA supported hardware, and joins them all at the + end. + + The main process in which this launcher is invoked creates N so-called worker processes (using the + `torch_xla` :func:`xmp.spawn`) that run the given function. + Worker processes have a rank that ranges from 0 to N - 1. + + Note: + - This launcher requires all objects to be pickleable. + - It is important that the entry point to the program/script is guarded by ``if __name__ == "__main__"``. + + Args: + strategy: A reference to the strategy that is used together with this launcher + """ + + def __init__(self, strategy: "Strategy") -> None: + super().__init__(strategy=strategy, start_method="fork") + + @property + def is_interactive_compatible(self) -> bool: + return True + + def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any: + """Launches processes that run the given function in parallel. + + The function is allowed to have a return value. However, when all processes join, only the return value + of worker process 0 gets returned from this `launch` method in the main process. + + Arguments: + function: The entry point for all launched processes. + *args: Optional positional arguments to be passed to the given function. + **kwargs: Optional keyword arguments to be passed to the given function. 
+ """ + context = mp.get_context(self._start_method) + return_queue = context.SimpleQueue() + _save_spawn( + self._wrapping_function, + args=(function, args, kwargs, return_queue), + nprocs=len(self._strategy.parallel_devices), + start_method=self._start_method, + ) + return return_queue.get() + + def _wrapping_function( + self, + process_idx: int, + function: Callable, + args: Any, + kwargs: Any, + return_queue: SimpleQueue, + global_states: Optional[_GlobalStateSnapshot] = None, + ) -> None: + # TODO(lite): Update worker setup once TPUSpawn strategy is in Lite + self._strategy._worker_setup(process_idx) + results = function(*args, **kwargs) + + if self._strategy.local_rank == 0: + return_queue.put(move_data_to_device(results, "cpu")) + + +def _save_spawn( + fn: Callable, + args: Tuple = (), + nprocs: Optional[int] = None, + join: bool = True, + daemon: bool = False, + start_method: str = "spawn", +) -> Optional[ProcessContext]: + """Wraps the :func:`torch_xla.distributed.xla_multiprocessing.spawn` with added teardown logic for the worker + processes.""" + + @wraps(fn) + def wrapped(rank: int, *_args: Any) -> None: + fn(rank, *_args) + + # Make all processes wait for each other before joining + # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542 + xm.rendezvous("end-process") + + # Ensure that the rank 0 process is the one exiting last + # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358 + if rank == 0: + time.sleep(1) + + return xmp.spawn(wrapped, args=args, nprocs=nprocs, join=join, daemon=daemon, start_method=start_method) diff --git a/src/lightning_lite/utilities/device_parser.py b/src/lightning_lite/utilities/device_parser.py index f0e5802d07..6967f7bf0a 100644 --- a/src/lightning_lite/utilities/device_parser.py +++ b/src/lightning_lite/utilities/device_parser.py @@ -1,13 +1,10 @@ import multiprocessing -import os from typing import Any, List, MutableSequence, Optional, Tuple, Union import torch from lightning_lite.plugins.environments.torchelastic_environment import TorchElasticEnvironment - -# TODO(lite): Fix the imports -# from lightning_lite.strategies.launchers.multiprocessing import _is_forking_disabled +from lightning_lite.strategies.launchers.multiprocessing import _is_forking_disabled from lightning_lite.utilities.exceptions import MisconfigurationException from lightning_lite.utilities.types import _DEVICE @@ -309,9 +306,3 @@ def is_cuda_available() -> bool: return torch.cuda.is_available() with multiprocessing.get_context("fork").Pool(1) as pool: return pool.apply(torch.cuda.is_available) - - -# TODO(lite): move this back to launchers/multiprocessing.py once launchers have moved -def _is_forking_disabled() -> bool: - """Returns whether forking is disabled through the environment variable ``PL_DISABLE_FORK``.""" - return bool(int(os.environ.get("PL_DISABLE_FORK", "0"))) diff --git a/src/pytorch_lightning/strategies/launchers/__init__.py b/src/pytorch_lightning/strategies/launchers/__init__.py index d75df88b2d..1c106cc8ff 100644 --- a/src/pytorch_lightning/strategies/launchers/__init__.py +++ b/src/pytorch_lightning/strategies/launchers/__init__.py @@ -11,13 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from pytorch_lightning.strategies.launchers.base import _Launcher from pytorch_lightning.strategies.launchers.multiprocessing import _MultiProcessingLauncher from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher from pytorch_lightning.strategies.launchers.xla import _XLALauncher __all__ = [ - "_Launcher", "_MultiProcessingLauncher", "_SubprocessScriptLauncher", "_XLALauncher", diff --git a/src/pytorch_lightning/strategies/launchers/multiprocessing.py b/src/pytorch_lightning/strategies/launchers/multiprocessing.py index fdc17f8b8d..be6a56b2e3 100644 --- a/src/pytorch_lightning/strategies/launchers/multiprocessing.py +++ b/src/pytorch_lightning/strategies/launchers/multiprocessing.py @@ -26,10 +26,10 @@ from torch import Tensor from typing_extensions import Literal import pytorch_lightning as pl +from lightning_lite.strategies.launchers.base import _Launcher from lightning_lite.utilities.apply_func import move_data_to_device from lightning_lite.utilities.seed import _collect_rng_states, _set_rng_states from lightning_lite.utilities.types import _PATH -from pytorch_lightning.strategies.launchers.base import _Launcher from pytorch_lightning.trainer.states import TrainerFn, TrainerState from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11 from pytorch_lightning.utilities.rank_zero import rank_zero_debug diff --git a/src/pytorch_lightning/strategies/launchers/subprocess_script.py b/src/pytorch_lightning/strategies/launchers/subprocess_script.py index f9e565260f..6713f636b9 100644 --- a/src/pytorch_lightning/strategies/launchers/subprocess_script.py +++ b/src/pytorch_lightning/strategies/launchers/subprocess_script.py @@ -23,7 +23,7 @@ from lightning_utilities.core.imports import RequirementCache import pytorch_lightning as pl from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment -from pytorch_lightning.strategies.launchers.base import _Launcher +from lightning_lite.strategies.launchers.base import _Launcher _HYDRA_AVAILABLE = RequirementCache("hydra") diff --git a/src/pytorch_lightning/strategies/strategy.py b/src/pytorch_lightning/strategies/strategy.py index 0f73b7b24e..dc2a5b6397 100644 --- a/src/pytorch_lightning/strategies/strategy.py +++ b/src/pytorch_lightning/strategies/strategy.py @@ -24,6 +24,7 @@ from torch.utils.data import DataLoader import pytorch_lightning as pl from lightning_lite.plugins.io.checkpoint_plugin import CheckpointIO +from lightning_lite.strategies.launchers.base import _Launcher from lightning_lite.utilities.apply_func import move_data_to_device from lightning_lite.utilities.distributed import ReduceOp from lightning_lite.utilities.optimizer import optimizer_to_device, optimizers_to_device @@ -32,7 +33,6 @@ from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers, from pytorch_lightning.plugins import TorchCheckpointIO from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin -from pytorch_lightning.strategies.launchers.base import _Launcher from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.types import ( LRSchedulerConfig, diff --git a/tests/tests_lite/conftest.py b/tests/tests_lite/conftest.py index 209d6869a1..952d32e4a9 100644 --- a/tests/tests_lite/conftest.py +++ b/tests/tests_lite/conftest.py @@ -76,6 +76,13 @@ def teardown_process_group(): torch.distributed.destroy_process_group() +@pytest.fixture +def 
reset_deterministic_algorithm(): + """Ensures that torch determinism settings are reset before the next test runs.""" + yield + torch.use_deterministic_algorithms(False) + + @pytest.fixture def caplog(caplog): """Workaround for https://github.com/pytest-dev/pytest/issues/3697. diff --git a/tests/tests_lite/strategies/__init__.py b/tests/tests_lite/strategies/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/tests_lite/strategies/launchers/__init__.py b/tests/tests_lite/strategies/launchers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/tests_lite/strategies/launchers/test_multiprocessing.py b/tests/tests_lite/strategies/launchers/test_multiprocessing.py new file mode 100644 index 0000000000..70b45763fe --- /dev/null +++ b/tests/tests_lite/strategies/launchers/test_multiprocessing.py @@ -0,0 +1,95 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from unittest import mock +from unittest.mock import ANY, Mock + +import pytest +import torch + +from lightning_lite.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher +from tests_pytorch.helpers.runif import RunIf + + +@RunIf(skip_windows=True) +@pytest.mark.parametrize("start_method", ["fork", "forkserver"]) +def test_multiprocessing_launcher_interactive_compatible(start_method): + launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method) + assert launcher.is_interactive_compatible == (start_method == "fork") + + +@mock.patch("lightning_lite.strategies.launchers.multiprocessing.mp.get_all_start_methods", return_value=[]) +def test_multiprocessing_launcher_forking_on_unsupported_platform(_): + with pytest.raises(ValueError, match="The start method 'fork' is not available on this platform"): + _MultiProcessingLauncher(strategy=Mock(), start_method="fork") + + +@RunIf(skip_windows=True) +@pytest.mark.parametrize("start_method", ["fork", "forkserver"]) +@mock.patch.dict(os.environ, {"PL_DISABLE_FORK": "1"}, clear=True) +def test_multiprocessing_launcher_disabled_forking(start_method): + with pytest.raises(ValueError, match="Forking is disabled in this environment"): + _MultiProcessingLauncher(strategy=Mock(), start_method=start_method) + + +@pytest.mark.parametrize("start_method", ["spawn", "fork"]) +@mock.patch("lightning_lite.strategies.launchers.multiprocessing.mp") +def test_multiprocessing_launcher_start_method(mp_mock, start_method): + mp_mock.get_all_start_methods.return_value = [start_method] + launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method) + launcher.launch(function=Mock()) + mp_mock.get_context.assert_called_with(start_method) + mp_mock.start_processes.assert_called_with( + ANY, + args=ANY, + nprocs=ANY, + start_method=start_method, + ) + + +@pytest.mark.parametrize("start_method", ["spawn", "fork"]) +@mock.patch("lightning_lite.strategies.launchers.multiprocessing.mp") +def test_multiprocessing_launcher_restore_globals(mp_mock, 
start_method): + """Test that we pass the global state snapshot to the worker function only if we are starting with 'spawn'.""" + mp_mock.get_all_start_methods.return_value = [start_method] + launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method) + launcher.launch(function=Mock()) + function_args = mp_mock.start_processes.call_args[1]["args"] + if start_method == "spawn": + assert len(function_args) == 5 + assert isinstance(function_args[4], _GlobalStateSnapshot) + else: + assert len(function_args) == 4 + + +@pytest.mark.usefixtures("reset_deterministic_algorithm") +def test_global_state_snapshot(): + """Test the capture() and restore() methods for the global state snapshot.""" + torch.use_deterministic_algorithms(True) + torch.backends.cudnn.benchmark = False + torch.manual_seed(123) + + # capture the state of globals + snapshot = _GlobalStateSnapshot.capture() + + # simulate there is a process boundary and flags get reset here + torch.use_deterministic_algorithms(False) + torch.backends.cudnn.benchmark = True + torch.manual_seed(321) + + # restore the state of globals + snapshot.restore() + assert torch.are_deterministic_algorithms_enabled() + assert not torch.backends.cudnn.benchmark + assert torch.initial_seed() == 123 diff --git a/tests/tests_lite/strategies/launchers/test_subprocess_script.py b/tests/tests_lite/strategies/launchers/test_subprocess_script.py new file mode 100644 index 0000000000..c9af07343b --- /dev/null +++ b/tests/tests_lite/strategies/launchers/test_subprocess_script.py @@ -0,0 +1,78 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +from unittest import mock +from unittest.mock import Mock + +import pytest + +from lightning_lite.strategies.launchers.subprocess_script import _SubprocessScriptLauncher + + +def test_subprocess_script_launcher_interactive_compatible(): + launcher = _SubprocessScriptLauncher(Mock(), num_processes=2, num_nodes=1) + assert not launcher.is_interactive_compatible + + +@mock.patch("lightning_lite.strategies.launchers.subprocess_script.subprocess.Popen") +def test_subprocess_script_launcher_error_launching_on_non_zero_rank(popen_mock): + cluster_env = Mock() + cluster_env.creates_processes_externally = False + cluster_env.local_rank.return_value = 1 + launcher = _SubprocessScriptLauncher(cluster_env, num_processes=2, num_nodes=1) + with pytest.raises(RuntimeError, match="attempted to launch new distributed processes with `local_rank > 0`"): + launcher.launch(Mock()) + + +@mock.patch("lightning_lite.strategies.launchers.subprocess_script.subprocess.Popen") +def test_subprocess_script_launcher_external_processes(popen_mock): + cluster_env = Mock() + cluster_env.creates_processes_externally = True + function = Mock() + launcher = _SubprocessScriptLauncher(cluster_env, num_processes=4, num_nodes=2) + launcher.launch(function, "positional-arg", keyword_arg=0) + function.assert_called_with("positional-arg", keyword_arg=0) + popen_mock.assert_not_called() + + +@mock.patch("lightning_lite.strategies.launchers.subprocess_script.sleep") +@mock.patch("lightning_lite.strategies.launchers.subprocess_script.subprocess.Popen") +def test_subprocess_script_launcher_launch_processes(popen_mock, _): + cluster_env = Mock() + cluster_env.creates_processes_externally = False + cluster_env.local_rank.return_value = 0 + cluster_env.main_address = "address" + cluster_env.main_port = 1234 + + function = Mock() + launcher = _SubprocessScriptLauncher(cluster_env, num_processes=4, num_nodes=2) + num_new_processes = launcher.num_processes - 1 + + # launches n-1 new processes, the current one will participate too + launcher.launch(function, "positional-arg", keyword_arg=0) + + calls = popen_mock.call_args_list + assert len(calls) == num_new_processes + + # world size in child processes + world_sizes = [int(calls[i][1]["env"]["WORLD_SIZE"]) for i in range(num_new_processes)] + assert world_sizes == [launcher.num_processes * launcher.num_nodes] * num_new_processes + + # local rank in child processes + local_ranks = [int(calls[i][1]["env"]["LOCAL_RANK"]) for i in range(num_new_processes)] + assert local_ranks == list(range(1, num_new_processes + 1)) + + # the current process + assert int(os.environ["WORLD_SIZE"]) == launcher.num_processes * launcher.num_nodes + assert int(os.environ["LOCAL_RANK"]) == 0 diff --git a/tests/tests_lite/strategies/launchers/test_xla.py b/tests/tests_lite/strategies/launchers/test_xla.py new file mode 100644 index 0000000000..0136cb6a27 --- /dev/null +++ b/tests/tests_lite/strategies/launchers/test_xla.py @@ -0,0 +1,39 @@ +from unittest import mock +from unittest.mock import ANY, Mock + +from tests_lite.helpers.runif import RunIf + +from lightning_lite.strategies.launchers.xla import _XLALauncher + + +@RunIf(skip_windows=True) +def test_xla_launcher_default_start_method(): + launcher = _XLALauncher(strategy=Mock()) + assert launcher._start_method == "fork" + + +@RunIf(skip_windows=True) +def test_xla_launcher_interactive_compatible(): + launcher = _XLALauncher(strategy=Mock()) + assert launcher.is_interactive_compatible + + +@RunIf(skip_windows=True) 
+@mock.patch("lightning_lite.strategies.launchers.xla.mp") +@mock.patch("lightning_lite.strategies.launchers.xla.xm") +@mock.patch("lightning_lite.strategies.launchers.xla.xmp") +def test_xla_launcher_xmp_spawn(xmp_mock, xm_mock, mp_mock): + strategy = Mock() + strategy.parallel_devices = [0, 1, 2, 3] + launcher = _XLALauncher(strategy=strategy) + function = Mock() + launcher.launch(function, "positional-arg", keyword_arg=0) + # mp_mock.get_context.assert_called_with(start_method) + xmp_mock.spawn.assert_called_with( + ANY, + args=(function, ("positional-arg",), {"keyword_arg": 0}, ANY), + nprocs=4, + join=True, + daemon=False, + start_method="fork", + )
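
For orientation, the sketch below shows roughly how the new `_MultiProcessingLauncher` is driven, mirroring the unit tests above. It is a minimal, hypothetical example and not part of the patch: the stub strategy and cluster-environment classes stand in for the real Lite `Strategy`/`ClusterEnvironment` (not yet migrated, per the TODOs in the patch) and expose only the attributes the launcher reads (`cluster_environment.main_port`, `num_processes`, `_worker_setup`, `local_rank`).

.. code-block:: python

    # Minimal, hypothetical usage sketch (not part of the patch): a stand-in strategy with
    # only the attributes _MultiProcessingLauncher touches, since the Lite Strategy base
    # class has not been migrated yet.
    import os

    from lightning_lite.strategies.launchers.multiprocessing import _MultiProcessingLauncher


    class _StubClusterEnvironment:
        main_port = 12910  # exported as MASTER_PORT by the launcher before starting workers


    class _StubStrategy:
        cluster_environment = _StubClusterEnvironment()
        num_processes = 2  # number of worker processes started via torch.multiprocessing

        def _worker_setup(self, process_idx: int) -> None:
            # a real strategy would initialize the process group here
            os.environ["LOCAL_RANK"] = str(process_idx)

        @property
        def local_rank(self) -> int:
            # only the rank-0 worker puts its result on the return queue
            return int(os.environ.get("LOCAL_RANK", "0"))


    def work(x):
        return x * 2


    if __name__ == "__main__":  # required guard: workers re-import this module
        launcher = _MultiProcessingLauncher(strategy=_StubStrategy(), start_method="spawn")
        print(launcher.launch(work, 21))  # only rank 0's return value is propagated -> 42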