guard against None in pytorch get_xla_supported_devices (#9572)

Co-authored-by: Chris Chow <cchow@nianticlabs.com>
Co-authored-by: thomas chaton <thomas@grid.ai>
Chris Chow 2021-10-12 05:01:32 -07:00 committed by GitHub
parent b530b7afd2
commit f14a47a0b2
2 changed files with 4 additions and 3 deletions


@@ -469,6 +469,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `BasePredictionWriter` not returning the batch_indices in a non-distributed setting ([#9432](https://github.com/PyTorchLightning/pytorch-lightning/pull/9432))
- Fixed an error when running in XLA environments with no TPU attached ([#9572](https://github.com/PyTorchLightning/pytorch-lightning/pull/9572))
- Fixed the check on logged torchmetrics whose `compute()` output is a multi-element tensor ([#9582](https://github.com/PyTorchLightning/pytorch-lightning/pull/9582))


@@ -70,9 +70,7 @@ class XLADeviceUtils:
         # we would have to use `torch_xla.distributed.xla_dist` for
         # multiple VMs and TPU_CONFIG won't be available, running
         # `xm.get_xla_supported_devices("TPU")` won't be possible.
-        if xm.xrt_world_size() > 1:
-            return True
-        return len(xm.get_xla_supported_devices("TPU")) > 0
+        return (xm.xrt_world_size() > 1) or bool(xm.get_xla_supported_devices("TPU"))
 
     @staticmethod
     def xla_available() -> bool:
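
For context, the failure this guard addresses can be reproduced without torch_xla. The sketch below is illustrative only: fake_get_xla_supported_devices is a hypothetical stand-in for xm.get_xla_supported_devices, assumed to return None on an XLA host with no TPU attached. It shows why len(None) breaks the old check while bool(None) makes the new one degrade gracefully.

# Minimal sketch (no torch_xla required). The stub below is hypothetical and
# only mimics the assumed None return of xm.get_xla_supported_devices("TPU")
# on an XLA environment without TPUs.
from typing import List, Optional

def fake_get_xla_supported_devices(devkind: str) -> Optional[List[str]]:
    return None  # assumed return value when no devices of this kind exist

devices = fake_get_xla_supported_devices("TPU")

# Old check: calling len() on None raises TypeError.
try:
    tpu_present = len(devices) > 0
except TypeError as err:
    print(f"old check fails: {err}")  # object of type 'NoneType' has no len()

# New check: bool(None) is False, so the absence of devices is reported cleanly.
tpu_present = bool(devices)
print(f"new check: tpu_present={tpu_present}")  # new check: tpu_present=False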