Make gradients available for all_gather on TPU (#15003)

* Make gradients available for all_gather on TPU
* Modify switch and tests
* Apply suggestions from code review
* Modify tests
* Fix test
* Drop test

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
stekiri 2022-12-08 08:08:04 +01:00 committed by GitHub
parent d5b9c678ba
commit 0d822e4ba1
3 changed files with 34 additions and 8 deletions


@@ -156,20 +156,22 @@ class XLAStrategy(DDPSpawnStrategy):
         return obj
 
     def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
-        """
-        Function to gather a tensor from several distributed processes
+        """Function to gather a tensor from several distributed processes.
+
         Args:
             tensor: tensor of shape (batch, ...)
             group: not available with TPUs
-            sync_grads: not available with TPUs
+            sync_grads: flag that allows users to synchronize gradients for the all_gather operation
+
         Return:
             A tensor of shape (world_size, batch, ...)
         """
         if isinstance(tensor, Tensor) and tensor.dim() == 0:
             tensor = tensor.unsqueeze(0)
+        import torch_xla.core.functions as xf
         import torch_xla.core.xla_model as xm
 
-        return xm.all_gather(tensor)
+        return xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
 
     def save_checkpoint(
         self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
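
For context on the one-line dispatch above: torch_xla ships two gather entry points, and the sync_grads flag now selects between them. Below is a minimal standalone sketch, assuming torch_xla is installed on an 8-core TPU host; gather_fn and nprocs=8 are illustrative and not part of this commit.

import torch
import torch_xla.core.functions as xf
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def gather_fn(index):
    tensor = torch.ones(1, device=xm.xla_device(), requires_grad=True)

    # sync_grads=False path: plain collective; no gradient flows back to the
    # local tensor through this result, as the new test below asserts.
    gathered_detached = xm.all_gather(tensor)

    # sync_grads=True path: differentiable wrapper, so backpropagating a loss
    # computed on the gathered result reaches the local tensor.
    gathered = xf.all_gather(tensor)
    gathered.sum().backward()
    assert tensor.grad is not None


if __name__ == "__main__":
    xmp.spawn(gather_fn, nprocs=8)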


@@ -289,20 +289,22 @@ class TPUSpawnStrategy(DDPSpawnStrategy):
         self.checkpoint_io.remove_checkpoint(filepath)
 
     def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
-        """
-        Function to gather a tensor from several distributed processes
+        """Function to gather a tensor from several distributed processes.
+
         Args:
             tensor: tensor of shape (batch, ...)
             group: not available with TPUs
-            sync_grads: not available with TPUs
+            sync_grads: flag that allows users to synchronize gradients for the all_gather operation
+
         Return:
             A tensor of shape (world_size, batch, ...)
         """
         if isinstance(tensor, Tensor) and tensor.dim() == 0:
            tensor = tensor.unsqueeze(0)
+        import torch_xla.core.functions as xf
         import torch_xla.core.xla_model as xm
 
-        return xm.all_gather(tensor)
+        return xf.all_gather(tensor) if sync_grads else xm.all_gather(tensor)
 
     def teardown(self) -> None:
         super().teardown()
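
At the PyTorch Lightning level, this is the code path behind LightningModule.all_gather(..., sync_grads=True) when training on TPUs. A hedged usage sketch follows; the encoder, loss, and optimizer are placeholders, not part of this commit.

import torch
from pytorch_lightning import LightningModule


class GatherExample(LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Linear(32, 8)  # placeholder encoder

    def training_step(self, batch, batch_idx):
        embeddings = self.encoder(batch)
        # Shape (world_size, batch, 8). With sync_grads=True the gathered tensor
        # stays attached to the autograd graph, so the loss backpropagates into
        # the local embeddings on every process.
        gathered = self.all_gather(embeddings, sync_grads=True)
        return gathered.pow(2).mean()  # placeholder loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

With sync_grads=False (the default), the gather still happens but no gradients flow through the gathered copies, matching the previous TPU behaviour.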


@@ -17,6 +17,7 @@ from unittest import mock
 from unittest.mock import Mock
 
 import pytest
+import torch
 from tests_lite.helpers.dataloaders import CustomNotImplementedErrorDataloader
 from tests_lite.helpers.models import RandomDataset, RandomIterableDataset
 from tests_lite.helpers.runif import RunIf
@@ -113,3 +114,24 @@ def test_xla_validate_unsupported_iterable_dataloaders(_, dataloader, monkeypatch):
     with pytest.raises(TypeError, match="TPUs do not currently support"):
         XLAStrategy().process_dataloader(dataloader)
+
+
+def tpu_all_gather_fn(strategy):
+    for sync_grads in [True, False]:
+        tensor = torch.tensor(1.0, device=strategy.root_device, requires_grad=True)
+        result = strategy.all_gather(tensor, sync_grads=sync_grads)
+        summed = result.sum()
+        assert torch.equal(summed, torch.tensor(8.0))
+        summed.backward()
+        if sync_grads:
+            assert torch.equal(tensor.grad, torch.tensor(1.0))
+        else:
+            # As gradients are not synced, the original tensor will not have gradients.
+            assert tensor.grad is None
+
+
+@RunIf(tpu=True)
+@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
+def test_tpu_all_gather():
+    """Test the all_gather operation on TPU."""
+    xla_launch(tpu_all_gather_fn)
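
A note on the constants in the new test: it runs on 8 TPU processes, so gathering the scalar 1.0 from every process yields a length-8 tensor whose sum is 8.0, and the derivative of that sum with respect to the single local copy is 1.0. A single-process stand-in that replaces the collective with torch.stack, purely to illustrate the arithmetic:

import torch

local = torch.tensor(1.0, requires_grad=True)
others = [torch.tensor(1.0) for _ in range(7)]  # values from the other 7 processes
gathered = torch.stack([local, *others])        # stand-in for all_gather(..., sync_grads=True)
summed = gathered.sum()
assert torch.equal(summed, torch.tensor(8.0))
summed.backward()
assert torch.equal(local.grad, torch.tensor(1.0))  # matches the sync_grads=True branch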