From b446b08be5898981de11a1a8db9e7c4ad26c7258 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 24 Jan 2024 16:21:56 +0100 Subject: [PATCH] Fallback to `ACCELERATOR_TYPE` for TPU flops (#19314) --- src/lightning/fabric/utilities/throughput.py | 4 +++- tests/tests_fabric/utilities/test_throughput.py | 6 +++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/lightning/fabric/utilities/throughput.py b/src/lightning/fabric/utilities/throughput.py index 6b971a58ad..d96673df64 100644 --- a/src/lightning/fabric/utilities/throughput.py +++ b/src/lightning/fabric/utilities/throughput.py @@ -597,7 +597,9 @@ def get_available_flops(device: torch.device, dtype: Union[torch.dtype, str]) -> else: from torch_xla.experimental import tpu - device_name = tpu.get_tpu_env()["TYPE"] + tpu_env = tpu.get_tpu_env() + # not all TPU generations define the "TYPE" envar. example: TYPE="V4", ACCELERATOR_TYPE="v4-8" + device_name = tpu_env.get("TYPE") or tpu_env["ACCELERATOR_TYPE"].split("-")[0] chip = device_name.lower() assert isinstance(device_name, str) if chip not in _TPU_FLOPS: diff --git a/tests/tests_fabric/utilities/test_throughput.py b/tests/tests_fabric/utilities/test_throughput.py index 2a66ce25e9..f2c3de30a3 100644 --- a/tests/tests_fabric/utilities/test_throughput.py +++ b/tests/tests_fabric/utilities/test_throughput.py @@ -49,8 +49,8 @@ def test_get_available_flops(xla_available): from torch_xla.experimental import tpu assert isinstance(tpu, Mock) - tpu.get_tpu_env.return_value = {"TYPE": "V4"} + tpu.get_tpu_env.return_value = {"TYPE": "V4"} flops = get_available_flops(torch.device("xla"), torch.bfloat16) assert flops == 275e12 @@ -58,6 +58,10 @@ def test_get_available_flops(xla_available): with pytest.warns(match="not found for TPU 'V1'"): assert get_available_flops(torch.device("xla"), torch.bfloat16) is None + tpu.get_tpu_env.return_value = {"ACCELERATOR_TYPE": "v3-8"} + flops = get_available_flops(torch.device("xla"), torch.bfloat16) + assert flops == 123e12 + tpu.reset_mock()