# lightning/pytorch_lightning/strategies/launchers/spawn.py

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from collections import UserList
from multiprocessing.queues import SimpleQueue
from typing import Any, Callable, NamedTuple, Optional
import numpy as np
import torch
import torch.multiprocessing as mp
import pytorch_lightning as pl
from pytorch_lightning.strategies.launchers.base import _Launcher
from pytorch_lightning.strategies.strategy import Strategy
from pytorch_lightning.trainer.states import TrainerFn, TrainerState
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_debug
from pytorch_lightning.utilities.types import _PATH
class _SpawnLauncher(_Launcher):
    r"""Spawns processes that run a given function in parallel, and joins them all at the end.

    The main process in which this launcher is invoked creates N so-called worker processes (using
    :func:`torch.multiprocessing.spawn`) that run the given function.
    Worker processes have a rank that ranges from 0 to N - 1.

    Note:
        - This launcher requires all objects to be pickleable.
        - It is important that the entry point to the program/script is guarded by ``if __name__ == "__main__"``.

    Args:
        strategy: A reference to the strategy that is used together with this launcher.
    """

    def __init__(self, strategy: Strategy) -> None:
        self._strategy = strategy

    def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
        """Spawns processes that run the given function in parallel.

        The function is allowed to have a return value. However, when all processes join, only the return value
        of worker process 0 gets returned from this `launch` method in the main process.

        Arguments:
            function: The entry point for all spawned processes.
            *args: Optional positional arguments to be passed to the given function.
            **kwargs: Optional keyword arguments to be passed to the given function.
                If a keyword argument named `trainer` is present and is an instance of
                :class:`~pytorch_lightning.trainer.trainer.Trainer`, a selected set of attributes from the trainer get
                restored in the main process after processes join. The `trainer` keyword argument will NOT be passed
                into the function.

        Returns:
            Without a trainer: whatever the local-rank-0 worker put on the return queue (the function's return value).
            With a trainer: the function's return value from the global-rank-0 worker, or ``None`` when the
            local-rank-0 worker of this process group had no rank-zero results to send back.
        """
        trainer = kwargs.pop("trainer", None)
        os.environ["MASTER_PORT"] = str(self._strategy.cluster_environment.main_port)
        context = mp.get_context("spawn")
        return_queue = context.SimpleQueue()
        mp.spawn(
            self._wrapping_function,
            args=(trainer, function, args, kwargs, return_queue),
            nprocs=self._strategy.num_processes,
        )
        spawn_output = return_queue.get()
        if trainer is None:
            return spawn_output
        if spawn_output is None:
            # The worker with local rank 0 always puts its result on the queue, but `_collect_rank_zero_results`
            # returns None when its global rank is not 0 (e.g. on a node other than the first one). There is
            # nothing to restore in that case, and passing None on would raise an AttributeError.
            return None
        self._recover_results_in_main_process(spawn_output, trainer)
        return spawn_output.trainer_results

    def _wrapping_function(
        self,
        process_idx: int,
        trainer: Optional["pl.Trainer"],
        function: Callable,
        args: Any,
        kwargs: Any,
        return_queue: SimpleQueue,
    ) -> None:
        """Entry point executed inside every spawned worker.

        ``process_idx`` is the index passed in by :func:`torch.multiprocessing.spawn` and is forwarded to the
        strategy's ``_worker_setup``. Only the worker with local rank 0 puts its (CPU-moved) result on
        ``return_queue``.
        """
        self._strategy._worker_setup(process_idx)
        results = function(*args, **kwargs)
        if trainer is not None:
            results = self._collect_rank_zero_results(trainer, results)
        if self._strategy.local_rank == 0:
            # move to CPU so the result can be received by the main process without device handles
            return_queue.put(move_data_to_device(results, "cpu"))

    def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer: "pl.Trainer") -> None:
        """Applies the state collected by the rank-zero worker to the ``trainer`` of the main process."""
        # transfer back the best path to the trainer
        if trainer.checkpoint_callback:
            trainer.checkpoint_callback.best_model_path = str(spawn_output.best_model_path)
        # TODO: pass also best score

        # load last weights
        if spawn_output.weights_path is not None:
            ckpt = self._strategy.checkpoint_io.load_checkpoint(spawn_output.weights_path)
            trainer.lightning_module.load_state_dict(ckpt)  # type: ignore[arg-type]
            # the file only existed to hand the weights over from the worker; clean it up
            self._strategy.checkpoint_io.remove_checkpoint(spawn_output.weights_path)

        trainer.state = spawn_output.trainer_state

        # get the `callback_metrics` and set it to the trainer
        if is_overridden("get_from_queue", trainer.lightning_module):
            # Runs only when the user DID override the LightningModule hook. It must consume its own items first:
            # the worker's `_collect_rank_zero_results` added them before the callback metrics, and `extra` is FIFO.
            # TODO: Remove the if in v1.7
            trainer.lightning_module.get_from_queue(spawn_output.extra)
        self.get_from_queue(trainer, spawn_output.extra)

    def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]:
        """Packages what the main process needs to restore after the workers join.

        Must be called on every process because the ``state_dict`` is computed on all of them. Returns ``None``
        on processes whose global rank is not 0.
        """
        rank_zero_debug("Finalizing the DDP spawn environment.")
        checkpoint_callback = trainer.checkpoint_callback
        best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None

        # requires to compute the state_dict on all processes in case Metrics are present
        state_dict = trainer.lightning_module.state_dict()

        if self._strategy.global_rank != 0:
            return None

        # save the last weights (only when fitting, since only then do they change)
        weights_path = None
        if trainer.state.fn == TrainerFn.FITTING:
            weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt")
            self._strategy.checkpoint_io.save_checkpoint(state_dict, weights_path)

        # adds the `callback_metrics` to the queue
        extra = _FakeQueue()
        if is_overridden("add_to_queue", trainer.lightning_module):
            # TODO: Remove the if in v1.7
            trainer.lightning_module.add_to_queue(extra)
        self.add_to_queue(trainer, extra)

        return _SpawnOutput(best_model_path, weights_path, trainer.state, results, extra)

    def add_to_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None:
        """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory
        sharing, we cast the data to numpy.

        Args:
            trainer: reference to the Trainer.
            queue: the instance of the queue to append the data.
        """
        callback_metrics: dict = apply_to_collection(
            trainer.callback_metrics, torch.Tensor, lambda x: x.cpu().numpy()
        )  # send as numpy to avoid issues with memory sharing
        queue.put(callback_metrics)

    def get_from_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None:
        """Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency,
        we cast back the data to ``torch.Tensor``.

        Args:
            trainer: reference to the Trainer.
            queue: the instance of the queue from where to get the data.
        """
        # NOTE: `add_to_queue` needs to be called before
        callback_metrics: dict = queue.get()
        trainer.callback_metrics.update(apply_to_collection(callback_metrics, np.ndarray, torch.tensor))
class _FakeQueue(UserList):
"""Simulates a :class:`torch.multiprocessing.queue.SimpleQueue` interface using the Python list."""
def get(self) -> Any:
return self.pop(0)
def put(self, item: Any) -> None:
self.append(item)
def empty(self) -> bool:
return len(self) == 0
# Payload that the rank-zero worker sends back to the main process through the return queue:
#   best_model_path - `best_model_path` of the worker's checkpoint callback (None when there is no callback)
#   weights_path    - temporary checkpoint holding the final weights (None unless the trainer was fitting)
#   trainer_state   - the worker's `trainer.state`, restored onto the main process' trainer
#   trainer_results - the return value of the launched function
#   extra           - FIFO of additional items (e.g. `callback_metrics`) gathered via `add_to_queue`
_SpawnOutput = NamedTuple(
    "_SpawnOutput",
    [
        ("best_model_path", Optional[_PATH]),
        ("weights_path", Optional[_PATH]),
        ("trainer_state", TrainerState),
        ("trainer_results", Any),
        ("extra", _FakeQueue),
    ],
)