Add descriptions to accelerator broadcast function/clean up all_gather (#6044)
* Add descriptions to accelerator broadcast function/clean up all_gather * Remove todo
This commit is contained in:
parent
049006a59c
commit
b019c25dde
|
@ -383,12 +383,18 @@ class Accelerator(object):
|
|||
self.training_type_plugin.barrier(name=name)
|
||||
|
||||
def broadcast(self, obj: object, src: int = 0) -> object:
|
||||
"""Broadcasts an object to all processes"""
|
||||
"""Broadcasts an object to all processes, such that the src object is broadcast to all other ranks if needed.
|
||||
|
||||
Args:
|
||||
obj: Object to broadcast to all process, usually a tensor or collection of tensors.
|
||||
src: The source rank of which the object will be broadcast from
|
||||
"""
|
||||
return self.training_type_plugin.broadcast(obj, src)
|
||||
|
||||
def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
|
||||
"""
|
||||
Function to gather a tensor from several distributed processes
|
||||
Function to gather a tensor from several distributed processes.
|
||||
|
||||
Args:
|
||||
tensor: tensor of shape (batch, ...)
|
||||
group: the process group to gather results from. Defaults to all processes (world)
|
||||
|
@ -409,8 +415,7 @@ class Accelerator(object):
|
|||
@property
|
||||
def results(self) -> Any:
|
||||
"""
|
||||
The results of the last training/testing run will be cached here.
|
||||
The results of the last training/testing run will be cached within the training type plugin.
|
||||
In distributed training, we make sure to transfer the results to the appropriate master process.
|
||||
"""
|
||||
# TODO: improve these docs
|
||||
return self.training_type_plugin.results
|
||||
|
|
Loading…
Reference in New Issue