Make gradients available for all_gather on TPU (#15003)
* Make gradients available for all_gather on TPU
* Modify switch and tests
* Apply suggestions from code review
* Modify tests
* Fix test
* Drop test

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>

parent d5b9c678ba
commit 0d822e4ba1
@@ -156,20 +156,22 @@ class XLAStrategy(DDPSpawnStrategy):
         return obj

     def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
-        """
-        Function to gather a tensor from several distributed processes
+        """Function to gather a tensor from several distributed processes.
+
         Args:
             tensor: tensor of shape (batch, ...)
             group: not available with TPUs
-            sync_grads: not available with TPUs
+            sync_grads: flag that allows users to synchronize gradients for the all_gather operation
+
         Return:
             A tensor of shape (world_size, batch, ...)
         """
         if isinstance(tensor, Tensor) and tensor.dim() == 0:
             tensor = tensor.unsqueeze(0)

+        import torch_xla.core.functions as xf
         import torch_xla.core.xla_model as xm

-        return xm.all_gather(tensor)
+        return xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)

     def save_checkpoint(
         self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
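With this hunk, `XLAStrategy.all_gather` routes through `torch_xla.core.functions.all_gather` when `sync_grads=True`, which carries gradients back to the local input, instead of the plain `torch_xla.core.xla_model.all_gather` collective. For intuition, here is a minimal sketch of what a gradient-carrying all-gather looks like, written against `torch.distributed` purely for illustration; the class name and the backward strategy are assumptions, not the `torch_xla` implementation.

```python
import torch
import torch.distributed as dist


class _AllGatherWithGrad(torch.autograd.Function):
    # Hypothetical sketch of a differentiable all-gather, analogous in spirit
    # to torch_xla.core.functions.all_gather but built on torch.distributed.

    @staticmethod
    def forward(ctx, tensor):
        gathered = [torch.empty_like(tensor) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, tensor)  # plain collective, no autograd by itself
        return torch.stack(gathered)       # shape: (world_size, *tensor.shape)

    @staticmethod
    def backward(ctx, grad_output):
        # Every process holds gradients for the full gathered tensor: sum them
        # across processes, then keep the slice that belongs to this rank's input.
        grad_output = grad_output.contiguous()
        dist.all_reduce(grad_output)
        return grad_output[dist.get_rank()]
```

With `sync_grads=False` the plain collective is kept, so the gathered result does not backpropagate into the local tensor, which is what the test added at the end of this commit asserts.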
@@ -289,20 +289,22 @@ class TPUSpawnStrategy(DDPSpawnStrategy):
         self.checkpoint_io.remove_checkpoint(filepath)

     def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
-        """
-        Function to gather a tensor from several distributed processes
+        """Function to gather a tensor from several distributed processes.
+
         Args:
             tensor: tensor of shape (batch, ...)
             group: not available with TPUs
-            sync_grads: not available with TPUs
+            sync_grads: flag that allows users to synchronize gradients for the all_gather operation
+
         Return:
             A tensor of shape (world_size, batch, ...)
         """
         if isinstance(tensor, Tensor) and tensor.dim() == 0:
             tensor = tensor.unsqueeze(0)

+        import torch_xla.core.functions as xf
         import torch_xla.core.xla_model as xm

-        return xm.all_gather(tensor)
+        return xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)

     def teardown(self) -> None:
         super().teardown()
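`TPUSpawnStrategy.all_gather` is what `LightningModule.all_gather` ultimately dispatches to when training on TPU, so after this change user code can request gradient-carrying gathers there the same way it already could on GPU. A hypothetical sketch follows; the module, its encoder, and the loss are placeholders and not part of this commit.

```python
import torch
from torch import nn
from pytorch_lightning import LightningModule


class GatherExample(LightningModule):
    # Hypothetical module used only to illustrate the sync_grads flag.
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(32, 8)

    def training_step(self, batch, batch_idx):
        local_embedding = self.encoder(batch)
        # sync_grads=True now keeps the autograd graph intact on TPU as well,
        # so the gathered embeddings can take part in a loss spanning processes.
        all_embeddings = self.all_gather(local_embedding, sync_grads=True)
        return all_embeddings.pow(2).mean()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```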
@@ -17,6 +17,7 @@ from unittest import mock
 from unittest.mock import Mock

 import pytest
+import torch
 from tests_lite.helpers.dataloaders import CustomNotImplementedErrorDataloader
 from tests_lite.helpers.models import RandomDataset, RandomIterableDataset
 from tests_lite.helpers.runif import RunIf
@@ -113,3 +114,24 @@ def test_xla_validate_unsupported_iterable_dataloaders(_, dataloader, monkeypatch

     with pytest.raises(TypeError, match="TPUs do not currently support"):
         XLAStrategy().process_dataloader(dataloader)
+
+
+def tpu_all_gather_fn(strategy):
+    for sync_grads in [True, False]:
+        tensor = torch.tensor(1.0, device=strategy.root_device, requires_grad=True)
+        result = strategy.all_gather(tensor, sync_grads=sync_grads)
+        summed = result.sum()
+        assert torch.equal(summed, torch.tensor(8.0))
+        summed.backward()
+        if sync_grads:
+            assert torch.equal(tensor.grad, torch.tensor(1.0))
+        else:
+            # As gradients are not synced, the original tensor will not have gradients.
+            assert tensor.grad is None
+
+
+@RunIf(tpu=True)
+@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
+def test_tpu_all_gather():
+    """Test the all_gather operation on TPU."""
+    xla_launch(tpu_all_gather_fn)
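For reference, the expected values in `tpu_all_gather_fn` assume the default eight TPU cores: each process contributes a scalar `1.0`, so the gathered tensor has shape `(8,)` and sums to `8.0`, and because the sum is linear in the local tensor with coefficient one, the synchronized gradient is exactly `1.0`. A CPU-only sketch of that arithmetic, for illustration only and not part of the patch:

```python
import torch

world_size = 8  # assumption: the default 8 TPU cores behind the expected values
local = torch.tensor(1.0, requires_grad=True)
# Emulate the gathered (world_size,) tensor: this process's scalar plus seven peers' 1.0s.
gathered = torch.stack([local] + [torch.ones(())] * (world_size - 1))
total = gathered.sum()
assert torch.equal(total, torch.tensor(8.0))
total.backward()
# d(total)/d(local) == 1.0, matching the sync_grads=True branch of the test.
assert torch.equal(local.grad, torch.tensor(1.0))
```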