guard against None in pytorch get_xla_supported_devices (#9572)
Co-authored-by: Chris Chow <cchow@nianticlabs.com> Co-authored-by: thomas chaton <thomas@grid.ai>
This commit is contained in:
parent
b530b7afd2
commit
f14a47a0b2
|
@ -469,6 +469,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed `BasePredictionWriter` not returning the batch_indices in a non-distributed setting ([#9432](https://github.com/PyTorchLightning/pytorch-lightning/pull/9432))
|
||||
|
||||
|
||||
- Fixed an error when running on in XLA environments with no TPU attached ([#9572](https://github.com/PyTorchLightning/pytorch-lightning/pull/9572))
|
||||
|
||||
|
||||
- Fixed check on torchmetrics logged whose `compute()` output is a multielement tensor ([#9582](https://github.com/PyTorchLightning/pytorch-lightning/pull/9582))
|
||||
|
||||
|
||||
|
|
|
@ -70,9 +70,7 @@ class XLADeviceUtils:
|
|||
# we would have to use `torch_xla.distributed.xla_dist` for
|
||||
# multiple VMs and TPU_CONFIG won't be available, running
|
||||
# `xm.get_xla_supported_devices("TPU")` won't be possible.
|
||||
if xm.xrt_world_size() > 1:
|
||||
return True
|
||||
return len(xm.get_xla_supported_devices("TPU")) > 0
|
||||
return (xm.xrt_world_size() > 1) or bool(xm.get_xla_supported_devices("TPU"))
|
||||
|
||||
@staticmethod
|
||||
def xla_available() -> bool:
|
||||
|
|
Loading…
Reference in New Issue