update amp
This commit is contained in:
parent
5cdd9e72b1
commit
06a3303742
|
@ -56,7 +56,7 @@ class MixedPrecision(Precision):
|
||||||
if _TORCH_GREATER_EQUAL_2_4
|
if _TORCH_GREATER_EQUAL_2_4
|
||||||
else getattr(
|
else getattr(
|
||||||
torch,
|
torch,
|
||||||
"cuda" if not isinstance(device, str) or device.split(":")[0] == "cpu" else device.split(":")[0],
|
"cuda" if device.split(":")[0] == "cpu" else device.split(":")[0],
|
||||||
).amp.GradScaler()
|
).amp.GradScaler()
|
||||||
)
|
)
|
||||||
if scaler is not None and self.precision == "bf16-mixed":
|
if scaler is not None and self.precision == "bf16-mixed":
|
||||||
|
|
|
@ -56,7 +56,7 @@ class MixedPrecision(Precision):
|
||||||
if _TORCH_GREATER_EQUAL_2_4
|
if _TORCH_GREATER_EQUAL_2_4
|
||||||
else getattr(
|
else getattr(
|
||||||
torch,
|
torch,
|
||||||
"cuda" if not isinstance(device, str) or device.split(":")[0] == "cpu" else device.split(":")[0],
|
"cuda" if device.split(":")[0] == "cpu" else device.split(":")[0],
|
||||||
).amp.GradScaler()
|
).amp.GradScaler()
|
||||||
)
|
)
|
||||||
if scaler is not None and self.precision == "bf16-mixed":
|
if scaler is not None and self.precision == "bf16-mixed":
|
||||||
|
|
Loading…
Reference in New Issue