update amp

This commit is contained in:
Zhiyuan Li 2024-12-20 03:20:18 +00:00
parent 5cdd9e72b1
commit 06a3303742
2 changed files with 2 additions and 2 deletions

View File

@ -56,7 +56,7 @@ class MixedPrecision(Precision):
if _TORCH_GREATER_EQUAL_2_4 if _TORCH_GREATER_EQUAL_2_4
else getattr( else getattr(
torch, torch,
"cuda" if not isinstance(device, str) or device.split(":")[0] == "cpu" else device.split(":")[0], "cuda" if device.split(":")[0] == "cpu" else device.split(":")[0],
).amp.GradScaler() ).amp.GradScaler()
) )
if scaler is not None and self.precision == "bf16-mixed": if scaler is not None and self.precision == "bf16-mixed":

View File

@ -56,7 +56,7 @@ class MixedPrecision(Precision):
if _TORCH_GREATER_EQUAL_2_4 if _TORCH_GREATER_EQUAL_2_4
else getattr( else getattr(
torch, torch,
"cuda" if not isinstance(device, str) or device.split(":")[0] == "cpu" else device.split(":")[0], "cuda" if device.split(":")[0] == "cpu" else device.split(":")[0],
).amp.GradScaler() ).amp.GradScaler()
) )
if scaler is not None and self.precision == "bf16-mixed": if scaler is not None and self.precision == "bf16-mixed":