Update `all_gather` docs (#19469)

This commit is contained in:
awaelchli 2024-02-14 19:37:50 +01:00 committed by GitHub
parent 1d04c10e2d
commit 59e45d6f6d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 9 additions and 2 deletions

View File

@ -236,6 +236,11 @@ Full example:
result = fabric.all_gather(data)
print("Result of all-gather:", result) # tensor([ 0, 10, 20, 30])
.. warning::
For the special case where ``world_size`` is 1, no additional dimension is added to the tensor(s). This inconsistency
is kept for backward compatibility and you may need to handle this special case in your code to make it agnostic.
----

View File

@ -583,7 +583,8 @@ class Fabric:
Return:
A tensor of shape (world_size, batch, ...), or if the input was a collection
the output will also be a collection with tensors of this shape.
the output will also be a collection with tensors of this shape. For the special case where
world_size is 1, no additional dimension is added to the tensor(s).
"""
self._validate_launched()

View File

@ -668,7 +668,8 @@ class LightningModule(
Return:
A tensor of shape (world_size, batch, ...), or if the input was a collection
the output will also be a collection with tensors of this shape.
the output will also be a collection with tensors of this shape. For the special case where
world_size is 1, no additional dimension is added to the tensor(s).
"""
group = group if group is not None else torch.distributed.group.WORLD