# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from collections import UserList
from multiprocessing.queues import SimpleQueue
from typing import Any, Callable, NamedTuple, Optional

import numpy as np
import torch
import torch.multiprocessing as mp

import pytorch_lightning as pl
from pytorch_lightning.strategies.launchers.base import _Launcher
from pytorch_lightning.strategies.strategy import Strategy
from pytorch_lightning.trainer.states import TrainerFn, TrainerState
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_debug
from pytorch_lightning.utilities.types import _PATH


class _SpawnLauncher(_Launcher):
    r"""Spawns processes that run a given function in parallel, and joins them all at the end.

    The main process in which this launcher is invoked creates N so-called worker processes (using
    :func:`torch.multiprocessing.spawn`) that run the given function.
    Worker processes have a rank that ranges from 0 to N - 1.

    Note:
        - This launcher requires all objects to be pickleable.
        - It is important that the entry point to the program/script is guarded by ``if __name__ == "__main__"``.

    Args:
        strategy: A reference to the strategy that is used together with this launcher.
    """

    def __init__(self, strategy: Strategy) -> None:
        self._strategy = strategy

    def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
        """Spawns processes that run the given function in parallel.

        The function is allowed to have a return value. However, when all processes join, only the return value
        of worker process 0 gets returned from this `launch` method in the main process.

        Arguments:
            function: The entry point for all spawned processes.
            *args: Optional positional arguments to be passed to the given function.
            **kwargs: Optional keyword arguments to be passed to the given function.
                If a keyword argument named `trainer` is present and is an instance of
                :class:`~pytorch_lightning.trainer.trainer.Trainer`, a selected set of attributes from the trainer get
                restored in the main process after processes join. The `trainer` keyword argument will NOT be passed
                into the function.

        Returns:
            Whatever the local-rank-0 worker placed on the return queue when no ``trainer`` was given, otherwise
            the ``trainer_results`` field of the :class:`_SpawnOutput` produced by that worker.
        """
        # `trainer` is consumed here and deliberately not forwarded to `function`.
        trainer = kwargs.pop("trainer", None)

        # Export the port in the main process before spawning, so that it is part of the environment the
        # worker processes are started with.
        os.environ["MASTER_PORT"] = str(self._strategy.cluster_environment.main_port)

        context = mp.get_context("spawn")
        return_queue = context.SimpleQueue()
        # `mp.spawn` passes the process index as the first positional argument to `_wrapping_function`.
        mp.spawn(
            self._wrapping_function,
            args=(trainer, function, args, kwargs, return_queue),
            nprocs=self._strategy.num_processes,
        )

        spawn_output = return_queue.get()
        if trainer is None:
            return spawn_output

        self._recover_results_in_main_process(spawn_output, trainer)
        return spawn_output.trainer_results

    def _wrapping_function(
        self,
        process_idx: int,
        trainer: Optional["pl.Trainer"],
        function: Callable,
        args: Any,
        kwargs: Any,
        return_queue: SimpleQueue,
    ) -> None:
        """Entry point executed inside every spawned worker process.

        Sets the worker up through the strategy, runs ``function`` and, on the worker with local rank 0, sends the
        (CPU-moved) results back to the main process through ``return_queue``.

        Args:
            process_idx: Index of the spawned process, forwarded to ``Strategy._worker_setup``.
            trainer: The trainer whose state should be collected after ``function`` returns, or ``None``.
            function: The user entry point to run.
            args: Positional arguments for ``function``.
            kwargs: Keyword arguments for ``function``.
            return_queue: Queue used to hand the results to the main process.
        """
        self._strategy._worker_setup(process_idx)
        results = function(*args, **kwargs)

        if trainer is not None:
            # Called on every worker: collecting the results builds the module's state dict on all processes.
            results = self._collect_rank_zero_results(trainer, results)

        # Exactly one worker per main process reports back, matching the single `return_queue.get()` in `launch`.
        if self._strategy.local_rank == 0:
            return_queue.put(move_data_to_device(results, "cpu"))

    def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer: "pl.Trainer") -> None:
        """Restores in the main process the trainer attributes that were produced inside the worker.

        Args:
            spawn_output: The bundle created by :meth:`_collect_rank_zero_results` in the worker.
            trainer: The main-process trainer that gets updated in place.
        """
        # transfer back the best path to the trainer
        if trainer.checkpoint_callback:
            trainer.checkpoint_callback.best_model_path = str(spawn_output.best_model_path)

        # TODO: pass also best score
        # load last weights
        if spawn_output.weights_path is not None:
            ckpt = self._strategy.checkpoint_io.load_checkpoint(spawn_output.weights_path)
            trainer.lightning_module.load_state_dict(ckpt)  # type: ignore[arg-type]
            # the temporary checkpoint only exists to carry the weights across the process boundary
            self._strategy.checkpoint_io.remove_checkpoint(spawn_output.weights_path)

        trainer.state = spawn_output.trainer_state

        # get the `callback_metrics` and set it to the trainer
        if is_overridden("get_from_queue", trainer.lightning_module):
            # only in case the user overrides the hook on the LightningModule; it consumes from the queue first,
            # mirroring the order in which `_collect_rank_zero_results` filled it.
            # TODO: Remove the if in v1.7
            trainer.lightning_module.get_from_queue(spawn_output.extra)
        self.get_from_queue(trainer, spawn_output.extra)

    def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]:
        """Bundles everything the main process needs to restore from the worker after it finished.

        Must be called on all worker processes. Only the worker with local rank 0 returns a bundle, because that is
        the worker whose output :meth:`_wrapping_function` puts on the return queue.

        Args:
            trainer: The worker-side trainer to read the state from.
            results: The return value of the function that was run in the worker.

        Returns:
            A :class:`_SpawnOutput` on the worker with local rank 0, ``None`` on every other worker.
        """
        rank_zero_debug("Finalizing the DDP spawn environment.")
        checkpoint_callback = trainer.checkpoint_callback
        best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None

        # requires to compute the state_dict on all processes in case Metrics are present
        state_dict = trainer.lightning_module.state_dict()

        # This check has to agree with the `local_rank == 0` check in `_wrapping_function`. Testing the global
        # rank here would make the local-rank-0 worker of every process group that does not contain global rank 0
        # send `None`, which `launch` would then dereference in `_recover_results_in_main_process`.
        if self._strategy.local_rank != 0:
            return None

        # save the last weights
        weights_path = None
        if trainer.state.fn == TrainerFn.FITTING:
            # NOTE(review): one file per reporting worker is written to `default_root_dir`; if several nodes
            # share that directory the writes may collide - confirm for multi-node setups.
            weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt")
            self._strategy.checkpoint_io.save_checkpoint(state_dict, weights_path)

        # adds the `callback_metrics` to the queue
        extra = _FakeQueue()
        if is_overridden("add_to_queue", trainer.lightning_module):
            # TODO: Remove the if in v1.7
            trainer.lightning_module.add_to_queue(extra)
        self.add_to_queue(trainer, extra)

        return _SpawnOutput(best_model_path, weights_path, trainer.state, results, extra)

    def add_to_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None:
        """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory
        sharing, we cast the data to numpy.

        Args:
            trainer: reference to the Trainer.
            queue: the instance of the queue to append the data.
        """
        callback_metrics: dict = apply_to_collection(
            trainer.callback_metrics, torch.Tensor, lambda x: x.cpu().numpy()
        )  # send as numpy to avoid issues with memory sharing
        queue.put(callback_metrics)

    def get_from_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None:
        """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency,
        we cast back the data to ``torch.Tensor``.

        Args:
            trainer: reference to the Trainer.
            queue: the instance of the queue from where to get the data.
        """
        # NOTE: `add_to_queue` needs to be called before
        callback_metrics: dict = queue.get()
        trainer.callback_metrics.update(apply_to_collection(callback_metrics, np.ndarray, lambda x: torch.tensor(x)))


class _FakeQueue(UserList):
    """Simulates a :class:`torch.multiprocessing.queues.SimpleQueue` interface using the Python list."""

    def get(self) -> Any:
        """Removes and returns the oldest item (first in, first out)."""
        return self.pop(0)

    def put(self, item: Any) -> None:
        """Adds ``item`` at the end of the queue."""
        self.append(item)

    def empty(self) -> bool:
        """Returns ``True`` if the queue holds no items."""
        return len(self) == 0


class _SpawnOutput(NamedTuple):
    """Bundle of values sent from a worker process back to the main process."""

    # `best_model_path` of the worker's checkpoint callback, or None if there is no checkpoint callback
    best_model_path: Optional[_PATH]
    # location of the temporary checkpoint with the final weights, or None if the trainer was not fitting
    weights_path: Optional[_PATH]
    # the worker trainer's `state`
    trainer_state: TrainerState
    # return value of the function that was run in the worker
    trainer_results: Any
    # queue carrying the callback metrics (and anything a user-overridden `add_to_queue` added)
    extra: _FakeQueue