[bugfix] TPU + all_gather + SingleTPU shouldn't call xm.all_gather (#6296)

* resolve an issue with TPU

* update

* add changelog
thomas chaton 2021-03-03 13:54:20 +00:00 committed by GitHub
parent 4a8422c2dc
commit 484dce11ec
2 changed files with 7 additions and 1 deletion


@@ -95,6 +95,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed duplicate logs appearing in console when using the python logging module ([#5509](https://github.com/PyTorchLightning/pytorch-lightning/pull/5509), [#6275](https://github.com/PyTorchLightning/pytorch-lightning/pull/6275))
+- Fixed `SingleTPU` calling `all_gather` ([#6296](https://github.com/PyTorchLightning/pytorch-lightning/pull/6296))
 - Fixed error thrown when using valid distributed mode in multi node ([#6297](https://github.com/PyTorchLightning/pytorch-lightning/pull/6297))


@@ -44,4 +44,7 @@ class TPUAccelerator(Accelerator):
         Return:
             A tensor of shape (world_size, batch, ...)
         """
-        return xm.all_gather(tensor, group=group, sync_grads=sync_grads)
+        # todo: Add support for backward with all_gather
+        if torch.distributed.is_initialized():
+            return xm.all_gather(tensor, group=group, sync_grads=sync_grads)
+        return tensor
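
For context, a minimal standalone sketch of the guard pattern this commit introduces (the helper name below is hypothetical, not Lightning API): when no process group has been initialized, which is the `SingleTPU` case, `all_gather` should simply return the input tensor instead of invoking `xm.all_gather`. A plain `torch.distributed` collective stands in here for `xm.all_gather`, which only works inside an initialized XLA/TPU environment.

```python
import torch
import torch.distributed as dist


def all_gather_or_passthrough(tensor: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper mirroring the fixed behaviour: gather across
    processes when a process group exists, otherwise return the tensor
    unchanged (the single-TPU / single-process case)."""
    if dist.is_available() and dist.is_initialized():
        # Stand-in for xm.all_gather: collect one tensor per rank and stack
        # them into a (world_size, batch, ...) tensor.
        gathered = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, tensor)
        return torch.stack(gathered)
    # No process group initialized: behave like the patched SingleTPU path.
    return tensor


if __name__ == "__main__":
    # Without torch.distributed initialized, the input comes straight back.
    x = torch.randn(4, 3)
    assert all_gather_or_passthrough(x) is x
```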