# lightning/pytorch_lightning/strategies/ddp_spawn.py
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from collections import UserList
from multiprocessing.queues import SimpleQueue
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union

import numpy as np
import torch
import torch.distributed
import torch.multiprocessing as mp
from torch.nn import Module
from torch.nn.parallel.distributed import DistributedDataParallel

import pytorch_lightning as pl
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.trainer.states import TrainerFn, TrainerState
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.distributed import _revert_sync_batchnorm, distributed_available
from pytorch_lightning.utilities.distributed import group as _group
from pytorch_lightning.utilities.distributed import (
    init_dist_connection,
    rank_zero_debug,
    rank_zero_only,
    ReduceOp,
    sync_ddp_if_available,
)
from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT

if _TORCH_GREATER_EQUAL_1_8:
    from pytorch_lightning.utilities.distributed import register_ddp_comm_hook

log = logging.getLogger(__name__)


class DDPSpawnStrategy(ParallelStrategy):
"""Spawns processes using the :func:`torch.multiprocessing.spawn` method and joins processes after training
finishes."""
distributed_backend = _StrategyType.DDP_SPAWN
def __init__(
self,
accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
parallel_devices: Optional[List[torch.device]] = None,
cluster_environment: Optional[ClusterEnvironment] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
ddp_comm_state: Optional[object] = None,
        ddp_comm_hook: Optional[Callable] = None,
        ddp_comm_wrapper: Optional[Callable] = None,
**kwargs: Any,
):
super().__init__(
accelerator=accelerator,
parallel_devices=parallel_devices,
cluster_environment=cluster_environment,
checkpoint_io=checkpoint_io,
precision_plugin=precision_plugin,
)
self._num_nodes = 1
self.sync_batchnorm = False
self._ddp_kwargs = kwargs
self.num_processes = len(parallel_devices) if parallel_devices is not None else 0
self._ddp_comm_state = ddp_comm_state
self._ddp_comm_hook = ddp_comm_hook
self._ddp_comm_wrapper = ddp_comm_wrapper
self._local_rank = 0
self.set_world_ranks()
@property
def num_nodes(self) -> int:
return self._num_nodes
@num_nodes.setter
def num_nodes(self, num_nodes: int) -> None:
        # note: world ranks depend on `num_nodes`, so they need to be recomputed whenever it is reset
self._num_nodes = num_nodes
self.set_world_ranks()
@property
def local_rank(self) -> int:
return self._local_rank
@property
def root_device(self):
return self.parallel_devices[self.local_rank]
@property
def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
return distributed_sampler_kwargs
@property
def _is_single_process_single_device(self):
return True
def setup(self, trainer: "pl.Trainer") -> None:
os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)
super().setup(trainer)
# move the model to the correct device
self.model_to_device()
if self.sync_batchnorm:
self.model = self.configure_sync_batchnorm(self.model)
# skip wrapping the model if we are not fitting as no gradients need to be exchanged
trainer_fn = self.lightning_module.trainer.state.fn
if trainer_fn == TrainerFn.FITTING:
self.configure_ddp()
def _setup_model(self, model: Module) -> DistributedDataParallel:
"""Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
return DistributedDataParallel(module=model, device_ids=self.determine_ddp_device_ids(), **self._ddp_kwargs)
def set_world_ranks(self, process_idx: int = 0) -> None:
self._local_rank = process_idx
if self.cluster_environment is None:
return
self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
rank_zero_only.rank = self.cluster_environment.global_rank()
def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]:
return {"nprocs": self.num_processes}
def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]:
"""Spawn processes that run the given function.
Args:
function: The function to spawn processes from.
*args: Optional positional arguments that will be passed to the function in addition to the process index.
These arguments must be pickleable.
**kwargs: Optional named arguments that will be passed to the function in addition to the process index.
These arguments must be pickleable.
Return:
The output of the function of process 0.
"""
os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)
context = mp.get_context("spawn")
return_queue = context.SimpleQueue()
mp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), nprocs=self.num_processes)
return return_queue.get()
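    # Illustrative sketch of the contract above (hypothetical caller, not part of this class):
    #   def _double(x):                        # must be pickleable; runs once in every spawned process
    #       return x * 2
    #   result = strategy.spawn(_double, 21)   # launches `num_processes` workers; returns 42,
    #                                          # i.e. the output produced by the local_rank == 0 worker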
def _wrapped_function(
self, process_idx: int, function: Callable, args: Any, kwargs: Any, return_queue: SimpleQueue
) -> None:
self._worker_setup(process_idx)
result = function(*args, **kwargs)
if self.local_rank == 0:
return_queue.put(move_data_to_device(result, "cpu"))
def _worker_setup(self, process_idx: int):
reset_seed()
self.set_world_ranks(process_idx)
rank_zero_only.rank = self.global_rank
init_dist_connection(
self.cluster_environment, self.torch_distributed_backend, self.global_rank, self.world_size
)
def pre_configure_ddp(self):
        # if unset, default `find_unused_parameters` to `True`
        # Many models require setting this parameter to `True`, as there are corner cases
        # when not all parameter backward hooks are fired by the autograd engine even if `requires_grad=True`.
        # This flag does come with a performance hit, so it is suggested to disable it where possible.
self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get("find_unused_parameters", True)
if not self.lightning_module.automatic_optimization and not self._ddp_kwargs.get(
"find_unused_parameters", False
):
# TODO: PyTorch 1.7.0 DDP introduces `self.reducer._rebuild_buckets()` breaking manual_optimization
rank_zero_warn(
"From PyTorch 1.7.0, Lightning `manual_optimization` needs to set `find_unused_parameters=True` to"
" properly work with DDP. Using `find_unused_parameters=True`."
)
self._ddp_kwargs["find_unused_parameters"] = True
def _register_ddp_hooks(self) -> None:
# currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
# https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
if _TORCH_GREATER_EQUAL_1_8 and self.on_gpu and self._is_single_process_single_device:
register_ddp_comm_hook(
model=self.model,
ddp_comm_state=self._ddp_comm_state,
ddp_comm_hook=self._ddp_comm_hook,
ddp_comm_wrapper=self._ddp_comm_wrapper,
)
def configure_ddp(self) -> None:
self.pre_configure_ddp()
self.model = self._setup_model(LightningDistributedModule(self.model))
self._register_ddp_hooks()
def determine_ddp_device_ids(self):
if self.root_device.type == "cpu":
return None
return [self.root_device.index]
def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]:
rank_zero_debug("Finalizing the DDP spawn environment.")
checkpoint_callback = trainer.checkpoint_callback
best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None
        # the `state_dict` must be computed on all processes in case Metrics are present
state_dict = self.lightning_module.state_dict()
if self.global_rank != 0:
return
# save the last weights
weights_path = None
if trainer.state.fn == TrainerFn.FITTING:
weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt")
self.checkpoint_io.save_checkpoint(state_dict, weights_path)
# adds the `callback_metrics` to the queue
extra = _FakeQueue()
if is_overridden("add_to_queue", self.lightning_module):
# TODO: Remove the if in v1.7
self.lightning_module.add_to_queue(extra)
self.add_to_queue(trainer, extra)
return _SpawnOutput(best_model_path, weights_path, trainer.state, results, extra)
def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer: "pl.Trainer") -> None:
# transfer back the best path to the trainer
if trainer.checkpoint_callback:
trainer.checkpoint_callback.best_model_path = spawn_output.best_model_path
# TODO: pass also best score
# load last weights
if spawn_output.weights_path is not None:
ckpt = self.checkpoint_io.load_checkpoint(
spawn_output.weights_path, map_location=(lambda storage, loc: storage)
)
self.lightning_module.load_state_dict(ckpt)
self.checkpoint_io.remove_checkpoint(spawn_output.weights_path)
trainer.state = spawn_output.trainer_state
# get the `callback_metrics` and set it to the trainer
if is_overridden("get_from_queue", self.lightning_module):
            # call the deprecated `LightningModule` hook only when the user has overridden it
# TODO: Remove the if in v1.7
self.lightning_module.get_from_queue(spawn_output.extra)
self.get_from_queue(trainer, spawn_output.extra)
def barrier(self, *args, **kwargs) -> None:
if not distributed_available():
return
if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
torch.distributed.barrier(device_ids=self.determine_ddp_device_ids())
else:
torch.distributed.barrier()
def broadcast(self, obj: object, src: int = 0) -> object:
if not distributed_available():
return obj
obj = [obj]
if self.global_rank != src:
obj = [None]
torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
return obj[0]
def model_to_device(self):
if self.root_device.type == "cuda":
# set the device on the spawned subprocesses
torch.cuda.set_device(self.root_device)
self.model.to(self.root_device)
def pre_backward(self, closure_loss: torch.Tensor) -> None:
"""Run before precision plugin executes backward."""
if not self.lightning_module.automatic_optimization:
prepare_for_backward(self.model, closure_loss)
def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean") -> torch.Tensor:
"""Reduces a tensor from several distributed processes to one aggregated tensor.
Args:
tensor: the tensor to sync and reduce
group: the process group to gather results from. Defaults to all processes (world)
reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
Can also be a string 'sum' to calculate the sum during reduction.
Return:
reduced value, except when the input was not a tensor the output remains is unchanged
"""
if isinstance(tensor, torch.Tensor):
tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
return tensor
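    # Semantics sketch with illustrative values: given 2 processes holding tensor(2.) and tensor(4.),
    #   strategy.reduce(tensor)                   # -> tensor(3.) on every process (mean)
    #   strategy.reduce(tensor, reduce_op="sum")  # -> tensor(6.) on every process
    # Non-tensor inputs are returned unchanged.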
def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
with self.precision_plugin.train_step_context():
return self.model(*args, **kwargs)
def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
with self.precision_plugin.val_step_context():
if isinstance(self.model, DistributedDataParallel):
# used when calling `trainer.fit`
return self.model(*args, **kwargs)
else:
# used when calling `trainer.validate`
return self.lightning_module.validation_step(*args, **kwargs)
def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
with self.precision_plugin.test_step_context():
return self.lightning_module.test_step(*args, **kwargs)
def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
with self.precision_plugin.predict_step_context():
return self.lightning_module.predict_step(*args, **kwargs)
def post_training_step(self):
if not self.lightning_module.automatic_optimization:
self.model.require_backward_grad_sync = True
def add_to_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None:
"""Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory
sharing, we cast the data to numpy.
Args:
trainer: reference to the Trainer.
queue: the instance of the queue to append the data.
"""
callback_metrics: dict = apply_to_collection(
trainer.callback_metrics, torch.Tensor, lambda x: x.cpu().numpy()
) # send as numpy to avoid issues with memory sharing
queue.put(callback_metrics)
def get_from_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None:
"""Retrieve the :attr:`trainer.callback_metrics` dictionary from the given queue. To preserve consistency,
we cast back the data to ``torch.Tensor``.
Args:
trainer: reference to the Trainer.
queue: the instance of the queue from where to get the data.
"""
# NOTE: `add_to_queue` needs to be called before
callback_metrics: dict = queue.get()
trainer.callback_metrics.update(apply_to_collection(callback_metrics, np.ndarray, lambda x: torch.tensor(x)))
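    # Round-trip sketch (illustrative): inside the spawned worker, `add_to_queue` stores
    # `trainer.callback_metrics` as numpy arrays; back in the main process, `get_from_queue`
    # restores them as tensors:
    #   extra = _FakeQueue()
    #   strategy.add_to_queue(trainer, extra)    # torch.Tensor -> np.ndarray
    #   strategy.get_from_queue(trainer, extra)  # np.ndarray -> torch.Tensor, merged into callback_metrics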
@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
strategy_registry.register(
"ddp_spawn_find_unused_parameters_false",
cls,
description="DDPSpawn Strategy with `find_unused_parameters` as False",
find_unused_parameters=False,
)
def teardown(self) -> None:
super().teardown()
if isinstance(self.model, DistributedDataParallel):
self.model = self.lightning_module
if self.sync_batchnorm:
self.model = _revert_sync_batchnorm(self.model)
if self.on_gpu:
# GPU teardown
self.lightning_module.cpu()
# clean up memory
torch.cuda.empty_cache()


class _FakeQueue(UserList):
"""Simulates a :class:`torch.multiprocessing.queue.SimpleQueue` interface using the Python list."""
def get(self) -> Any:
return self.pop(0)
def put(self, item: Any) -> None:
self.append(item)
def empty(self) -> bool:
return len(self) == 0
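
# Behaviour sketch (illustrative): `_FakeQueue` provides the minimal `SimpleQueue`-style API the
# spawn hooks rely on, backed by a plain Python list:
#   q = _FakeQueue()
#   q.put("metrics")
#   q.empty()  # -> False
#   q.get()    # -> "metrics" (FIFO via `pop(0)`)
#   q.empty()  # -> True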


class _SpawnOutput(NamedTuple):
best_model_path: Optional[_PATH]
weights_path: Optional[_PATH]
trainer_state: TrainerState
trainer_results: Any
extra: _FakeQueue