# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities that can be used with distributed training."""

import logging
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torch.nn import Module
from torch.nn.parallel.distributed import DistributedDataParallel

import pytorch_lightning as pl
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_9, _TPU_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_debug as new_rank_zero_debug
from pytorch_lightning.utilities.rank_zero import rank_zero_only  # noqa: F401
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
from pytorch_lightning.utilities.rank_zero import rank_zero_info as new_rank_zero_info

if _TPU_AVAILABLE:
    import torch_xla.core.xla_model as xm

if torch.distributed.is_available():
    from torch.distributed import group, ReduceOp

else:

    class ReduceOp:  # type: ignore # (see https://github.com/python/mypy/issues/1153)
        SUM = None

    class group:  # type: ignore
        WORLD = None


log = logging.getLogger(__name__)


def gather_all_tensors(result: torch.Tensor, group: Optional[Any] = None) -> List[torch.Tensor]:
    """Function to gather all tensors from several DDP processes into a list that is broadcast to all processes.

    Args:
        result: the value to sync
        group: the process group to gather results from. Defaults to all processes (world)

    Return:
        gathered_result: list with size equal to the process group where
            gathered_result[i] corresponds to result tensor from process i
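
    Example:
        A minimal sketch (not executed by doctest), assuming ``torch.distributed`` is already
        initialized with two processes and ``rank`` holds the local rank:

        >>> tensor = torch.tensor([rank])  # doctest: +SKIP
        >>> gather_all_tensors(tensor)  # doctest: +SKIP
        [tensor([0]), tensor([1])]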
    """
    if group is None:
        group = torch.distributed.group.WORLD

    # convert tensors to contiguous format
    result = result.contiguous()

    world_size = torch.distributed.get_world_size(group)

    gathered_result = [torch.zeros_like(result) for _ in range(world_size)]

    # sync and broadcast all
    torch.distributed.barrier(group=group)
    torch.distributed.all_gather(gathered_result, result, group)

    return gathered_result


def distributed_available() -> bool:
    return (torch.distributed.is_available() and torch.distributed.is_initialized()) or tpu_distributed()


def sync_ddp_if_available(
    result: torch.Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
) -> torch.Tensor:
    """Function to reduce a tensor across worker processes during distributed training.

    Args:
        result: the value to sync and reduce (typically tensor or number)
        group: the process group to gather results from. Defaults to all processes (world)
        reduce_op: the reduction operation. Defaults to sum.
            Can also be the string ``'avg'`` or ``'mean'`` to calculate the mean during reduction.

    Return:
        reduced value
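
    Example:
        A minimal sketch (not executed by doctest), assuming two initialized processes each
        holding a scalar loss tensor:

        >>> loss = torch.tensor(1.0)  # doctest: +SKIP
        >>> mean_loss = sync_ddp_if_available(loss, reduce_op="mean")  # doctest: +SKIP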
    """
    if distributed_available():
        return sync_ddp(result, group=group, reduce_op=reduce_op)
    return result


def sync_ddp(
    result: torch.Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
) -> torch.Tensor:
    """Function to reduce the tensors from several DDP processes into one reduced tensor available on all processes.

    Args:
        result: the value to sync and reduce (typically tensor or number)
        group: the process group to gather results from. Defaults to all processes (world)
        reduce_op: the reduction operation. Defaults to sum.
            Can also be the string ``'avg'`` or ``'mean'`` to calculate the mean during reduction.

    Return:
        reduced value
    """
    divide_by_world_size = False

    if group is None:
        group = torch.distributed.group.WORLD

    if isinstance(reduce_op, str):
        if reduce_op.lower() in ("avg", "mean"):
            op = ReduceOp.SUM
            divide_by_world_size = True
        else:
            op = getattr(ReduceOp, reduce_op.upper())
    else:
        op = reduce_op

    # sync all processes before reduction
    torch.distributed.barrier(group=group)
    torch.distributed.all_reduce(result, op=op, group=group, async_op=False)

    if divide_by_world_size:
        result = result / torch.distributed.get_world_size(group)

    return result


class AllGatherGrad(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: Any,
        tensor: torch.Tensor,
        group: Optional["torch.distributed.ProcessGroup"] = group.WORLD,
    ) -> torch.Tensor:
        ctx.group = group

        gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]

        torch.distributed.all_gather(gathered_tensor, tensor, group=group)
        gathered_tensor = torch.stack(gathered_tensor, dim=0)

        return gathered_tensor

    @staticmethod
    def backward(ctx: Any, *grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
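        # ``all_gather`` is linear: slice ``i`` of the stacked output depends only on process
        # ``i``'s input. The gradient of the local input is therefore the sum, over all
        # processes, of the incoming gradient slice at this process's rank. ``all_reduce``
        # performs that sum, and indexing by the local rank selects the matching slice.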
        grad_output = torch.cat(grad_output)

        torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group)

        return grad_output[torch.distributed.get_rank()], None


def all_gather_ddp_if_available(
    tensor: torch.Tensor, group: Optional["torch.distributed.ProcessGroup"] = None, sync_grads: bool = False
) -> torch.Tensor:
    """Function to gather a tensor from several distributed processes.

    Args:
        tensor: tensor of shape (batch, ...)
        group: the process group to gather results from. Defaults to all processes (world)
        sync_grads: flag that allows users to synchronize gradients for the all_gather op

    Return:
        A tensor of shape (world_size, batch, ...)
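
    Example:
        A minimal sketch (not executed by doctest), assuming two initialized processes and a
        per-process batch of shape ``(8, 16)``:

        >>> local = torch.randn(8, 16)  # doctest: +SKIP
        >>> gathered = all_gather_ddp_if_available(local, sync_grads=True)  # doctest: +SKIP
        >>> gathered.shape  # doctest: +SKIP
        torch.Size([2, 8, 16])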
    """
    group = group if group is not None else torch.distributed.group.WORLD
    if distributed_available():
        if sync_grads:
            return AllGatherGrad.apply(tensor, group)
        with torch.no_grad():
            return AllGatherGrad.apply(tensor, group)
    return tensor


def register_ddp_comm_hook(
    model: DistributedDataParallel,
    ddp_comm_state: Optional[object] = None,
    ddp_comm_hook: Optional[Callable] = None,
    ddp_comm_wrapper: Optional[Callable] = None,
) -> None:
    """Function to register communication hook for DDP model https://pytorch.org/docs/master/ddp_comm_hooks.html.

    Args:
        model:
            DDP model
        ddp_comm_state:
            state is passed to the hook and can be used to maintain
            and update any state information that users would like to
            maintain as part of the training process. Examples: error
            feedback in gradient compression, peers to communicate with
            next in GossipGrad etc.
        ddp_comm_hook:
            hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future

            This callable function is called once the bucket is ready. The
            hook can perform whatever processing is needed and return
            a Future indicating completion of any async work (ex: allreduce).
            If the hook doesn't perform any communication, it can also
            just return a completed Future. The Future should hold the
            new value of grad bucket's tensors. Once a bucket is ready,
            c10d reducer would call this hook and use the tensors returned
            by the Future and copy grads to individual parameters.
        ddp_comm_wrapper:
            communication hook wrapper to support a communication hook such
            as FP16 compression as wrapper, which could be combined with
            ddp_comm_hook

    .. warning::
        DDP communication hook needs pytorch version at least 1.8.0

    .. warning::
        DDP communication wrapper needs pytorch version at least 1.9.0
        Post-localSGD hook needs pytorch version at least 1.9.0

    Examples:

        >>> from torch.distributed.algorithms.ddp_comm_hooks import ( # doctest: +SKIP
        ...     default_hooks as default,
        ...     powerSGD_hook as powerSGD,
        ...     post_localSGD_hook as post_localSGD,
        ... )
        >>>
        >>> # fp16_compress_hook to compress gradients
        >>> ddp_model = ...
        >>> register_ddp_comm_hook( # doctest: +SKIP
        ...     model=ddp_model,
        ...     ddp_comm_hook=default.fp16_compress_hook,
        ... )
        >>>
        >>> # powerSGD_hook
        >>> ddp_model = ...
        >>> register_ddp_comm_hook( # doctest: +SKIP
        ...     model=ddp_model,
        ...     ddp_comm_state=powerSGD.PowerSGDState(
        ...         process_group=None,
        ...         matrix_approximation_rank=1,
        ...         start_powerSGD_iter=5000,
        ...     ),
        ...     ddp_comm_hook=powerSGD.powerSGD_hook,
        ... )
        >>>
        >>> # post_localSGD_hook
        >>> subgroup, _ = torch.distributed.new_subgroups() # doctest: +SKIP
        >>> ddp_model = ...
        >>> register_ddp_comm_hook( # doctest: +SKIP
        ...     model=ddp_model,
        ...     ddp_comm_state=post_localSGD.PostLocalSGDState(
        ...         process_group=None,
        ...         subgroup=subgroup,
        ...         start_localSGD_iter=1_000,
        ...     ),
        ...     ddp_comm_hook=post_localSGD.post_localSGD_hook,
        ... )
        >>>
        >>> # fp16_compress_wrapper combined with other communication hook
        >>> ddp_model = ...
        >>> register_ddp_comm_hook( # doctest: +SKIP
        ...     model=ddp_model,
        ...     ddp_comm_state=powerSGD.PowerSGDState(
        ...         process_group=None,
        ...         matrix_approximation_rank=1,
        ...         start_powerSGD_iter=5000,
        ...     ),
        ...     ddp_comm_hook=powerSGD.powerSGD_hook,
        ...     ddp_comm_wrapper=default.fp16_compress_wrapper,
        ... )
    """
    from pytorch_lightning.utilities import rank_zero_warn

    if not _TORCH_GREATER_EQUAL_1_8:
        rank_zero_warn("Not registering DDP comm hook. To use communication hooks, please use pytorch>=1.8.0.")
        return
    if ddp_comm_hook is None:
        return
    # inform mypy that ddp_comm_hook is callable
    ddp_comm_hook: Callable = ddp_comm_hook

    if ddp_comm_wrapper is not None:
        if not _TORCH_GREATER_EQUAL_1_9:
            rank_zero_warn("Not applying DDP comm wrapper. To use communication wrapper, please use pytorch>=1.9.0.")
        else:
            new_rank_zero_info(
                f"DDP comm wrapper is provided, apply {ddp_comm_wrapper.__qualname__}({ddp_comm_hook.__qualname__})."
            )
            ddp_comm_hook = ddp_comm_wrapper(ddp_comm_hook)
rank_zero_debug(f"Registering DDP comm hook: {ddp_comm_hook.__qualname__}.")
|
|
model.register_comm_hook(state=ddp_comm_state, hook=ddp_comm_hook)
|
|
|
|
|
|


def tpu_distributed() -> bool:
    return _TPU_AVAILABLE and xm.xrt_world_size() > 1


def init_dist_connection(
    cluster_environment: "pl.plugins.environments.ClusterEnvironment",
    torch_distributed_backend: str,
    global_rank: Optional[int] = None,
    world_size: Optional[int] = None,
    **kwargs: Any,
) -> None:
    """Utility function to initialize distributed connection by setting env variables and initializing the
    distributed process group.

    Args:
        cluster_environment: ``ClusterEnvironment`` instance
        torch_distributed_backend: the backend to use (e.g. ``nccl`` or ``gloo``)
        global_rank: rank of the current process
        world_size: number of processes in the group
        kwargs: kwargs for ``init_process_group``

    Raises:
        RuntimeError:
            If ``torch.distributed`` is not available
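
    Example:
        A minimal sketch (not executed by doctest); ``env`` is assumed to be an already
        configured ``ClusterEnvironment`` instance:

        >>> init_dist_connection(env, torch_distributed_backend="gloo")  # doctest: +SKIP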
    """
    if not torch.distributed.is_available():
        raise RuntimeError("torch.distributed is not available. Cannot initialize distributed process group")
    if torch.distributed.is_initialized():
        log.debug("torch.distributed is already initialized. Exiting early")
        return
    global_rank = global_rank if global_rank is not None else cluster_environment.global_rank()
    world_size = world_size if world_size is not None else cluster_environment.world_size()
    os.environ["MASTER_ADDR"] = cluster_environment.main_address
    os.environ["MASTER_PORT"] = str(cluster_environment.main_port)
    log.info(f"Initializing distributed: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
    torch.distributed.init_process_group(torch_distributed_backend, rank=global_rank, world_size=world_size, **kwargs)

    # on rank=0 let everyone know training is starting
    new_rank_zero_info(
        f"{'-' * 100}\n"
        f"distributed_backend={torch_distributed_backend}\n"
        f"All distributed processes registered. Starting with {world_size} processes\n"
        f"{'-' * 100}\n"
    )


def _broadcast_object_list(obj: Any, rank: int) -> Any:
    objects = [obj if torch.distributed.get_rank() == rank else None]
    torch.distributed.broadcast_object_list(objects, src=rank)
    return objects[0]


# TODO: Refactor with the Strategy Collectives once finalized.
def _collect_states_on_rank_zero(state: Dict[str, Any]) -> Dict[int, Any]:
    """This distributed utility collects dictionary state across all processes.

    Args:
        state: Dictionary containing the state of the current process

    Returns:
        states: A dictionary where the keys are the process ranks and the values their
            associated states. Since each state is broadcast, the mapping is available
            on every process.
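
    Example:
        A minimal sketch (not executed by doctest) with two processes, where each process
        passes ``{"rank": rank}`` and ``rank`` is the local rank:

        >>> _collect_states_on_rank_zero({"rank": rank})  # doctest: +SKIP
        {0: {'rank': 0}, 1: {'rank': 1}}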
    """
    if not distributed_available():
        return {0: state}
    return {rank: _broadcast_object_list(state, rank) for rank in range(torch.distributed.get_world_size())}


class _BatchNormXd(torch.nn.modules.batchnorm._BatchNorm):
    def _check_input_dim(self, input: torch.Tensor) -> None:
        # The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc
        # is this method that is overwritten by the subclass.
        # Here, we are bypassing some tensor sanity checks and trusting that the user
        # provides the right input dimensions at inference.
        return


def _revert_sync_batchnorm(module: Module) -> Module:
    # Code adapted from https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547
    # Original author: Kapil Yedidi (@kapily)
    converted_module = module
    if isinstance(module, torch.nn.modules.batchnorm.SyncBatchNorm):
        # Unfortunately, SyncBatchNorm does not store the original class - if it did
        # we could return the one that was originally created.
        converted_module = _BatchNormXd(
            module.num_features, module.eps, module.momentum, module.affine, module.track_running_stats
        )
        if module.affine:
            with torch.no_grad():
                converted_module.weight = module.weight
                converted_module.bias = module.bias
        converted_module.running_mean = module.running_mean
        converted_module.running_var = module.running_var
        converted_module.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, "qconfig"):
            converted_module.qconfig = module.qconfig
    for name, child in module.named_children():
        converted_module.add_module(name, _revert_sync_batchnorm(child))
    del module
    return converted_module
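

# Usage sketch for ``_revert_sync_batchnorm`` (illustrative only; ``model`` is assumed to be a module
# whose batch norm layers were converted with ``torch.nn.SyncBatchNorm.convert_sync_batchnorm``):
#
#     model = _revert_sync_batchnorm(model)
#     model.eval()  # the reverted module can now run on a single, non-distributed device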


def rank_zero_info(*args: Any, **kwargs: Any) -> Any:
    rank_zero_deprecation(
        "pytorch_lightning.utilities.distributed.rank_zero_info has been deprecated in v1.6"
        " and will be removed in v1.8."
        " Use the equivalent function from the pytorch_lightning.utilities.rank_zero module instead."
    )
    return new_rank_zero_info(*args, **kwargs)


def rank_zero_debug(*args: Any, **kwargs: Any) -> Any:
    rank_zero_deprecation(
        "pytorch_lightning.utilities.distributed.rank_zero_debug has been deprecated in v1.6"
        " and will be removed in v1.8."
        " Use the equivalent function from the pytorch_lightning.utilities.rank_zero module instead."
    )
    return new_rank_zero_debug(*args, **kwargs)