lightning/pytorch_lightning/overrides/torch_distributed.py

102 lines
3.6 KiB
Python

import logging
import pickle
import torch
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
log = logging.getLogger(__name__)
if torch.distributed.is_available():
from torch.distributed import Backend, broadcast, get_backend, get_rank, GroupMember
# The code underneath is taken from PyTorch `torch/distributed/distributed_c10d.py`
# and enables object broadcasting on PyTorch versions older than 1.8.
# https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L160
def _rank_not_in_group(group):
    """Report whether ``group`` is the "not a member" marker.

    Args:
        group: a process-group handle, or ``None``.

    Returns:
        ``False`` when ``group`` is ``None``; otherwise the result of comparing
        ``group`` with ``GroupMember.NON_GROUP_MEMBER``.
    """
    # Equality (not identity) is used on purpose, exactly as in the upstream helper.
    return group is not None and group == GroupMember.NON_GROUP_MEMBER
# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L1164
def _object_to_tensor(obj):
    """Pickle ``obj`` and wrap the resulting bytes in tensors.

    Args:
        obj: any object accepted by ``pickle.dumps``.

    Returns:
        A ``(payload, length)`` pair: ``payload`` is a ``torch.ByteTensor``
        built from the pickled bytes and ``length`` is a one-element
        ``torch.LongTensor`` holding ``payload.numel()``.
    """
    pickled = pickle.dumps(obj)
    storage = torch.ByteStorage.from_buffer(pickled)  # type: ignore[attr-defined]
    payload = torch.ByteTensor(storage)
    return payload, torch.LongTensor([payload.numel()])
# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py
def _tensor_to_object(tensor, tensor_size):
    """Rebuild a Python object from its pickled bytes stored in ``tensor``.

    Args:
        tensor: tensor whose raw bytes (via ``tensor.numpy().tobytes()``)
            start with a pickle payload.
        tensor_size: number of leading bytes that make up the payload;
            anything after that is ignored.

    Returns:
        The object produced by ``pickle.loads`` on the first ``tensor_size`` bytes.
    """
    # NOTE: ``pickle.loads`` can execute arbitrary code; the bytes must come
    # from a trusted peer.
    raw = tensor.numpy().tobytes()
    return pickle.loads(raw[:tensor_size])
# Taken from https://github.com/pytorch/pytorch/blob/1.7/torch/distributed/distributed_c10d.py#L1327
def _broadcast_object_list(object_list, src=0, group=None):
    """Broadcast the picklable objects in ``object_list`` from rank ``src``.

    The objects are pickled on ``src``, their sizes are broadcast first and the
    concatenated payload second, and every other rank unpickles the payload
    back into its own ``object_list``.

    Args:
        object_list: list of picklable objects. On ``src`` the elements are
            serialized and sent. On every other rank each slot is overwritten
            in place with the received object. The receive buffer for the
            sizes is allocated from ``len(object_list)``, so every rank is
            expected to pass a list of the same length as ``src``.
        src: rank whose objects are sent. It is compared against ``get_rank()``.
        group: process group forwarded to ``get_backend`` and ``broadcast``.
            If this process is not part of ``group`` the call returns
            immediately without communicating.

    Returns:
        ``None``. Results are written into ``object_list``.

    Note:
        With the NCCL backend the tensors are moved to
        ``torch.cuda.current_device()`` before being broadcast.
    """
    if _rank_not_in_group(group):
        return

    is_src = get_rank() == src

    # Serialize object_list elements to tensors on src rank; the other ranks
    # only allocate a buffer that will receive one size per object.
    if is_src:
        tensor_list, size_list = zip(*(_object_to_tensor(obj) for obj in object_list))
        object_sizes_tensor = torch.cat(size_list)
    else:
        object_sizes_tensor = torch.LongTensor(len(object_list))

    is_nccl_backend = get_backend(group) == Backend.NCCL
    current_device = torch.device("cpu")
    if is_nccl_backend:
        # Use the current CUDA device rather than the rank: rank == device
        # index is not necessarily true.
        current_device = torch.device("cuda", torch.cuda.current_device())
        object_sizes_tensor = object_sizes_tensor.to(current_device)

    # Broadcast object sizes
    broadcast(object_sizes_tensor, src=src, group=group)

    # Concatenate and broadcast serialized object tensors
    if is_src:
        object_tensor = torch.cat(tensor_list)
    else:
        object_tensor = torch.ByteTensor(torch.sum(object_sizes_tensor).item())

    if is_nccl_backend:
        object_tensor = object_tensor.to(current_device)

    broadcast(object_tensor, src=src, group=group)

    if is_src:
        # The source already holds the original objects.
        return

    # Deserialize objects using their stored sizes. The sizes are converted to
    # plain ints once, so slicing does not read the size tensor element by
    # element and ``offset`` stays a Python int.
    offset = 0
    for i, obj_size in enumerate(object_sizes_tensor.tolist()):
        obj_view = object_tensor[offset : offset + obj_size]
        obj_view = obj_view.type(torch.ByteTensor)  # type: ignore[call-overload]
        offset += obj_size
        object_list[i] = _tensor_to_object(obj_view, obj_size)
# Pick the `broadcast_object_list` implementation exposed by this module.
if torch.distributed.is_available():
    if _TORCH_GREATER_EQUAL_1_8:
        # Use the implementation shipped with PyTorch itself.
        from torch.distributed.distributed_c10d import broadcast_object_list
    else:
        # Fall back to the copy defined in this module.
        broadcast_object_list = _broadcast_object_list
else:
    # avoid failures on early PyTorch versions for Windows where
    # not all functions used in `broadcast_object_list` are available.
    def _broadcast_noop(obj, *_, **__):
        """Return ``obj`` unchanged; all other arguments are ignored."""
        return obj

    broadcast_object_list = _broadcast_noop