Fallback to `ACCELERATOR_TYPE` for TPU flops (#19314)
This commit is contained in:
parent 7cc79fe7ba
commit b446b08be5
@@ -597,7 +597,9 @@ def get_available_flops(device: torch.device, dtype: Union[torch.dtype, str]) ->
         else:
             from torch_xla.experimental import tpu
 
-        device_name = tpu.get_tpu_env()["TYPE"]
+        tpu_env = tpu.get_tpu_env()
+        # not all TPU generations define the "TYPE" envar. example: TYPE="V4", ACCELERATOR_TYPE="v4-8"
+        device_name = tpu_env.get("TYPE") or tpu_env["ACCELERATOR_TYPE"].split("-")[0]
         chip = device_name.lower()
         assert isinstance(device_name, str)
         if chip not in _TPU_FLOPS:
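In short, the lookup now prefers the "TYPE" entry of the TPU environment and falls back to the generation prefix of "ACCELERATOR_TYPE" when "TYPE" is absent. A minimal sketch of that resolution, using a hypothetical resolve_tpu_chip helper and made-up env dicts (not part of the actual change):

# resolve_tpu_chip is a hypothetical helper that isolates the new lookup.
def resolve_tpu_chip(tpu_env: dict) -> str:
    # Prefer the explicit "TYPE" entry; otherwise derive the chip from the
    # generation prefix of "ACCELERATOR_TYPE" (e.g. "v4-8" -> "v4").
    device_name = tpu_env.get("TYPE") or tpu_env["ACCELERATOR_TYPE"].split("-")[0]
    return device_name.lower()

assert resolve_tpu_chip({"TYPE": "V4"}) == "v4"
assert resolve_tpu_chip({"ACCELERATOR_TYPE": "v4-8"}) == "v4"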
@@ -49,8 +49,8 @@ def test_get_available_flops(xla_available):
     from torch_xla.experimental import tpu
 
     assert isinstance(tpu, Mock)
-    tpu.get_tpu_env.return_value = {"TYPE": "V4"}
 
+    tpu.get_tpu_env.return_value = {"TYPE": "V4"}
     flops = get_available_flops(torch.device("xla"), torch.bfloat16)
     assert flops == 275e12
 
@@ -58,6 +58,10 @@ def test_get_available_flops(xla_available):
     with pytest.warns(match="not found for TPU 'V1'"):
         assert get_available_flops(torch.device("xla"), torch.bfloat16) is None
 
+    tpu.get_tpu_env.return_value = {"ACCELERATOR_TYPE": "v3-8"}
+    flops = get_available_flops(torch.device("xla"), torch.bfloat16)
+    assert flops == 123e12
+
     tpu.reset_mock()
 
 
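The test drives both code paths by swapping the dict that the mocked tpu.get_tpu_env returns. A self-contained sketch of that pattern, independent of the Lightning test setup (the real test obtains the Mock through its xla_available fixture):

from unittest.mock import Mock

# Stand-in for torch_xla.experimental.tpu, as in the test above.
tpu = Mock()

# Newer generations expose "TYPE" directly.
tpu.get_tpu_env.return_value = {"TYPE": "V4"}
assert tpu.get_tpu_env().get("TYPE") == "V4"

# Others only expose "ACCELERATOR_TYPE", e.g. "v3-8" -> chip "v3".
tpu.get_tpu_env.return_value = {"ACCELERATOR_TYPE": "v3-8"}
env = tpu.get_tpu_env()
assert env.get("TYPE") is None
assert env["ACCELERATOR_TYPE"].split("-")[0] == "v3"

tpu.reset_mock()  # clear recorded calls between test phases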