update amp
This commit is contained in:
parent
5cdd9e72b1
commit
06a3303742
|
@ -56,7 +56,7 @@ class MixedPrecision(Precision):
|
|||
if _TORCH_GREATER_EQUAL_2_4
|
||||
else getattr(
|
||||
torch,
|
||||
"cuda" if not isinstance(device, str) or device.split(":")[0] == "cpu" else device.split(":")[0],
|
||||
"cuda" if device.split(":")[0] == "cpu" else device.split(":")[0],
|
||||
).amp.GradScaler()
|
||||
)
|
||||
if scaler is not None and self.precision == "bf16-mixed":
|
||||
|
|
|
@ -56,7 +56,7 @@ class MixedPrecision(Precision):
|
|||
if _TORCH_GREATER_EQUAL_2_4
|
||||
else getattr(
|
||||
torch,
|
||||
"cuda" if not isinstance(device, str) or device.split(":")[0] == "cpu" else device.split(":")[0],
|
||||
"cuda" if device.split(":")[0] == "cpu" else device.split(":")[0],
|
||||
).amp.GradScaler()
|
||||
)
|
||||
if scaler is not None and self.precision == "bf16-mixed":
|
||||
|
|
Loading…
Reference in New Issue