Fallback to `ACCELERATOR_TYPE` for TPU flops (#19314)

Carlos Mocholí 2024-01-24 16:21:56 +01:00 committed by GitHub
parent 7cc79fe7ba
commit b446b08be5
2 changed files with 8 additions and 2 deletions


@@ -597,7 +597,9 @@ def get_available_flops(device: torch.device, dtype: Union[torch.dtype, str]) ->
         else:
             from torch_xla.experimental import tpu
 
-        device_name = tpu.get_tpu_env()["TYPE"]
+        tpu_env = tpu.get_tpu_env()
+        # not all TPU generations define the "TYPE" envar. example: TYPE="V4", ACCELERATOR_TYPE="v4-8"
+        device_name = tpu_env.get("TYPE") or tpu_env["ACCELERATOR_TYPE"].split("-")[0]
         chip = device_name.lower()
         assert isinstance(device_name, str)
         if chip not in _TPU_FLOPS:
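
The new lookup prefers the "TYPE" key and only falls back to "ACCELERATOR_TYPE", whose prefix before the first dash names the chip generation. A minimal standalone sketch of that fallback, where `tpu_chip` is a hypothetical helper and the table is a two-entry stub using the bf16 values from the tests below, not the repo's full `_TPU_FLOPS` mapping:

# Sketch only: `tpu_chip` is a hypothetical helper; the table is a stub.
_TPU_FLOPS = {"v3": 123e12, "v4": 275e12}  # bf16 peak FLOPs used in the tests

def tpu_chip(tpu_env: dict) -> str:
    # Prefer TYPE (e.g. "V4"); otherwise take the generation prefix of
    # ACCELERATOR_TYPE (e.g. "v4-8" -> "v4").
    device_name = tpu_env.get("TYPE") or tpu_env["ACCELERATOR_TYPE"].split("-")[0]
    return device_name.lower()

assert tpu_chip({"TYPE": "V4"}) == "v4"
assert tpu_chip({"ACCELERATOR_TYPE": "v3-8"}) == "v3"
assert _TPU_FLOPS[tpu_chip({"ACCELERATOR_TYPE": "v3-8"})] == 123e12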


@@ -49,8 +49,8 @@ def test_get_available_flops(xla_available):
     from torch_xla.experimental import tpu
 
     assert isinstance(tpu, Mock)
     tpu.get_tpu_env.return_value = {"TYPE": "V4"}
     flops = get_available_flops(torch.device("xla"), torch.bfloat16)
     assert flops == 275e12
@@ -58,6 +58,10 @@ def test_get_available_flops(xla_available):
     with pytest.warns(match="not found for TPU 'V1'"):
         assert get_available_flops(torch.device("xla"), torch.bfloat16) is None
 
+    tpu.get_tpu_env.return_value = {"ACCELERATOR_TYPE": "v3-8"}
+    flops = get_available_flops(torch.device("xla"), torch.bfloat16)
+    assert flops == 123e12
+
     tpu.reset_mock()
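
One edge case the new tests don't cover: the fallback indexes "ACCELERATOR_TYPE" directly, so an env dict defining neither key raises KeyError instead of reaching the "not found for TPU" warning path. A quick illustrative check (an assumption about the edge case, not part of the commit):

tpu_env = {}  # hypothetical env missing both keys
try:
    tpu_env.get("TYPE") or tpu_env["ACCELERATOR_TYPE"].split("-")[0]
except KeyError as err:
    print(f"unresolvable TPU env: {err}")  # KeyError: 'ACCELERATOR_TYPE'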