From b019c25dde6c2284d85bb96189dee386743bf8e5 Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Thu, 18 Feb 2021 11:12:52 +0000 Subject: [PATCH] Add descriptions to accelerator broadcast function/clean up all_gather (#6044) * Add descriptions to accelerator broadcast function/clean up all_gather * Remove todo --- pytorch_lightning/accelerators/accelerator.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 0600aca92e..2ba97d7d9a 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -383,12 +383,18 @@ class Accelerator(object): self.training_type_plugin.barrier(name=name) def broadcast(self, obj: object, src: int = 0) -> object: - """Broadcasts an object to all processes""" + """Broadcasts an object to all processes, such that the src object is broadcast to all other ranks if needed. + + Args: + obj: Object to broadcast to all process, usually a tensor or collection of tensors. + src: The source rank of which the object will be broadcast from + """ return self.training_type_plugin.broadcast(obj, src) def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False): """ - Function to gather a tensor from several distributed processes + Function to gather a tensor from several distributed processes. + Args: tensor: tensor of shape (batch, ...) group: the process group to gather results from. Defaults to all processes (world) @@ -409,8 +415,7 @@ class Accelerator(object): @property def results(self) -> Any: """ - The results of the last training/testing run will be cached here. + The results of the last training/testing run will be cached within the training type plugin. In distributed training, we make sure to transfer the results to the appropriate master process. """ - # TODO: improve these docs return self.training_type_plugin.results