import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
import torch.distributed as dist
from torch.nn.modules import Module


class DistributedDataParallel(Module):

    def __init__(self, module):
        super(DistributedDataParallel, self).__init__()
        self.warn_on_half = True  # True if dist._backend == dist.dist_backend.GLOO else False

        self.module = module

        # Synchronize model state across ranks: broadcast every tensor in the
        # state dict from rank 0 so all processes start from identical weights.
        for p in self.module.state_dict().values():
            if torch.is_tensor(p):
                dist.broadcast(p, 0)

        def allreduce_params():
            if self.needs_reduction:
                self.needs_reduction = False
                # Group gradients by tensor type so each group can be
                # flattened into one contiguous buffer for a single all-reduce.
                buckets = {}
                for param in self.module.parameters():
                    if param.requires_grad and param.grad is not None:
                        tp = type(param.data)
                        if tp not in buckets:
                            buckets[tp] = []
                        buckets[tp].append(param)

                if self.warn_on_half:
                    if torch.cuda.HalfTensor in buckets:
                        print("WARNING: gloo dist backend for half parameters may be extremely slow." +
                              " It is recommended to use the NCCL backend in this case.")
                        self.warn_on_half = False

                # All-reduce each bucket as one flattened tensor, average by
                # world size, then copy the synced values back into the grads.
                for tp in buckets:
                    bucket = buckets[tp]
                    grads = [param.grad.data for param in bucket]
                    coalesced = _flatten_dense_tensors(grads)
                    dist.all_reduce(coalesced)
                    coalesced /= dist.get_world_size()
                    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                        buf.copy_(synced)

        # Register a backward hook on every trainable parameter. Each hook
        # queues allreduce_params as an end-of-backward callback; the
        # needs_reduction flag ensures the reduction runs only once per pass.
        for param in list(self.module.parameters()):
            if param.requires_grad:
                def allreduce_hook(*unused):
                    param._execution_engine.queue_callback(allreduce_params)
                param.register_hook(allreduce_hook)

    def forward(self, *inputs, **kwargs):
        # Re-arm the reduction flag so the next backward pass triggers
        # exactly one gradient all-reduce.
        self.needs_reduction = True
        return self.module(*inputs, **kwargs)
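

# A minimal usage sketch, not part of the class above: it shows how one
# process in a multi-process job might wrap a toy model with this
# DistributedDataParallel. The NCCL backend, the "env://" init method (which
# expects RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT set by a launcher), and the
# toy nn.Linear model are assumptions for illustration only.
if __name__ == "__main__":
    import torch.nn as nn
    import torch.nn.functional as F

    # Assumes one process per GPU, launched with rank/world-size env vars set.
    dist.init_process_group(backend="nccl", init_method="env://")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    model = DistributedDataParallel(nn.Linear(10, 10).cuda())
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    inputs = torch.randn(32, 10).cuda()
    targets = torch.randn(32, 10).cuda()

    optimizer.zero_grad()
    loss = F.mse_loss(model(inputs), targets)
    loss.backward()   # gradients are averaged across ranks by the queued callback
    optimizer.step()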