Add descriptions to accelerator broadcast function/clean up all_gather (#6044)

* Add descriptions to accelerator broadcast function/clean up all_gather

* Remove todo
Sean Naren 2021-02-18 11:12:52 +00:00 committed by GitHub
parent 049006a59c
commit b019c25dde
1 changed file with 9 additions and 4 deletions


@@ -383,12 +383,18 @@ class Accelerator(object):
         self.training_type_plugin.barrier(name=name)
 
     def broadcast(self, obj: object, src: int = 0) -> object:
-        """Broadcasts an object to all processes"""
+        """Broadcasts an object to all processes, such that the src object is broadcast to all other ranks if needed.
+
+        Args:
+            obj: Object to broadcast to all processes, usually a tensor or a collection of tensors.
+            src: The source rank from which the object will be broadcast.
+        """
         return self.training_type_plugin.broadcast(obj, src)
 
     def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
         """
-        Function to gather a tensor from several distributed processes
+        Function to gather a tensor from several distributed processes.
+
         Args:
             tensor: tensor of shape (batch, ...)
             group: the process group to gather results from. Defaults to all processes (world)
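For readers skimming the diff, a minimal sketch of the documented semantics. It uses only the two methods changed above with the signatures shown; the `accelerator` variable stands in for a constructed Accelerator instance, and the helper names are hypothetical:

    import torch

    # Hypothetical helpers built on the Accelerator API documented above.

    def sync_vocab_size(accelerator, vocab) -> int:
        # Every rank calls broadcast; afterwards all ranks hold rank 0's value.
        size = torch.tensor(len(vocab))
        size = accelerator.broadcast(size, src=0)
        return int(size)

    def gather_outputs(accelerator, local_out: torch.Tensor) -> torch.Tensor:
        # local_out has shape (batch, ...); the gathered result stacks one
        # slice per process into shape (world_size, batch, ...). sync_grads
        # is the flag from the signature above; True is assumed here to
        # also synchronize gradients through the gather.
        return accelerator.all_gather(local_out, sync_grads=True)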
@@ -409,8 +415,7 @@ class Accelerator(object):
     @property
     def results(self) -> Any:
         """
-        The results of the last training/testing run will be cached here.
+        The results of the last training/testing run will be cached within the training type plugin.
         In distributed training, we make sure to transfer the results to the appropriate master process.
         """
-        # TODO: improve these docs
         return self.training_type_plugin.results
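And a hypothetical snippet of how the property above is read after a run; the surrounding setup is assumed, not shown in this diff:

    # Hypothetical: `accelerator` is an Accelerator whose run has finished.
    # The property simply forwards to the training type plugin's cache.
    cached = accelerator.results
    assert cached is accelerator.training_type_plugin.results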