diff --git a/CHANGELOG.md b/CHANGELOG.md index 6461b95a05..b3eb09547e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -469,6 +469,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `BasePredictionWriter` not returning the batch_indices in a non-distributed setting ([#9432](https://github.com/PyTorchLightning/pytorch-lightning/pull/9432)) +- Fixed an error when running on in XLA environments with no TPU attached ([#9572](https://github.com/PyTorchLightning/pytorch-lightning/pull/9572)) + + - Fixed check on torchmetrics logged whose `compute()` output is a multielement tensor ([#9582](https://github.com/PyTorchLightning/pytorch-lightning/pull/9582)) diff --git a/pytorch_lightning/utilities/xla_device.py b/pytorch_lightning/utilities/xla_device.py index b922a749e7..2feef71c56 100644 --- a/pytorch_lightning/utilities/xla_device.py +++ b/pytorch_lightning/utilities/xla_device.py @@ -70,9 +70,7 @@ class XLADeviceUtils: # we would have to use `torch_xla.distributed.xla_dist` for # multiple VMs and TPU_CONFIG won't be available, running # `xm.get_xla_supported_devices("TPU")` won't be possible. - if xm.xrt_world_size() > 1: - return True - return len(xm.get_xla_supported_devices("TPU")) > 0 + return (xm.xrt_world_size() > 1) or bool(xm.get_xla_supported_devices("TPU")) @staticmethod def xla_available() -> bool: