diff --git a/CHANGELOG.md b/CHANGELOG.md
index b3513794e3..27e5f4be2d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -95,6 +95,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed duplicate logs appearing in console when using the python logging module ([#5509](https://github.com/PyTorchLightning/pytorch-lightning/pull/5509), [#6275](https://github.com/PyTorchLightning/pytorch-lightning/pull/6275))
 
+- Fixed `SingleTPU` calling `all_gather` ([#6296](https://github.com/PyTorchLightning/pytorch-lightning/pull/6296))
+
+
 - Fixed error thrown when using valid distributed mode in multi node ([#6297](https://github.com/PyTorchLightning/pytorch-lightning/pull/6297))
 
 
diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py
index 8f98cb8ac5..c36f7287f3 100644
--- a/pytorch_lightning/accelerators/tpu.py
+++ b/pytorch_lightning/accelerators/tpu.py
@@ -44,4 +44,7 @@ class TPUAccelerator(Accelerator):
         Return:
             A tensor of shape (world_size, batch, ...)
         """
-        return xm.all_gather(tensor, group=group, sync_grads=sync_grads)
+        # todo: Add support for backward with all_gather
+        if torch.distributed.is_initialized():
+            return xm.all_gather(tensor, group=group, sync_grads=sync_grads)
+        return tensor
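
For context, here is a minimal sketch of how the patched method reads after this diff. The surrounding `Accelerator` machinery and module imports are elided in the real file, so the class stub and import lines below are assumptions for illustration; only the method body reproduces the diff.

```python
import torch
import torch_xla.core.xla_model as xm  # assumes torch_xla is installed


class TPUAccelerator:  # stub; the real class subclasses Accelerator
    def all_gather(self, tensor: torch.Tensor, group=None, sync_grads: bool = False) -> torch.Tensor:
        # On a single TPU core no default process group is initialized, so
        # reaching xm.all_gather would fail; a gather over a world of size 1
        # is just the tensor itself, so return it unchanged in that case.
        # todo: Add support for backward with all_gather
        if torch.distributed.is_initialized():
            return xm.all_gather(tensor, group=group, sync_grads=sync_grads)
        return tensor
```

The guard keeps the `SingleTPU` path from ever calling `xm.all_gather`, which is the failure mode #6296 fixes; distributed TPU runs, where the process group is initialized, take the original code path.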