diff --git a/src/lightning/fabric/plugins/precision/amp.py b/src/lightning/fabric/plugins/precision/amp.py
index a7bf1113d1..d89ff1597e 100644
--- a/src/lightning/fabric/plugins/precision/amp.py
+++ b/src/lightning/fabric/plugins/precision/amp.py
@@ -56,7 +56,7 @@ class MixedPrecision(Precision):
                 if _TORCH_GREATER_EQUAL_2_4
                 else getattr(
                     torch,
-                    "cuda" if not isinstance(device, str) or device.split(":")[0] == "cpu" else device.split(":")[0],
+                    "cuda" if device.split(":")[0] == "cpu" else device.split(":")[0],
                 ).amp.GradScaler()
             )
         if scaler is not None and self.precision == "bf16-mixed":
diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py
index d610057e2a..b6a6fd1277 100644
--- a/src/lightning/pytorch/plugins/precision/amp.py
+++ b/src/lightning/pytorch/plugins/precision/amp.py
@@ -56,7 +56,7 @@ class MixedPrecision(Precision):
                 if _TORCH_GREATER_EQUAL_2_4
                 else getattr(
                     torch,
-                    "cuda" if not isinstance(device, str) or device.split(":")[0] == "cpu" else device.split(":")[0],
+                    "cuda" if device.split(":")[0] == "cpu" else device.split(":")[0],
                 ).amp.GradScaler()
             )
         if scaler is not None and self.precision == "bf16-mixed":
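
A minimal sketch of the selection logic the patched line performs on the pre-2.4 fallback path, assuming `device` is always a string such as "cuda", "cuda:1", or "cpu" (the dropped `isinstance` guard is what this change relies on). The helper name `_select_scaler_namespace` is hypothetical and used only for illustration; the plugin inlines this expression.

```python
import torch


def _select_scaler_namespace(device: str) -> str:
    # Take the backend prefix from strings like "cuda:1" -> "cuda".
    prefix = device.split(":")[0]
    # Mirror the patched expression: "cpu" maps to the "cuda" namespace,
    # presumably because the legacy per-backend GradScaler lives under
    # torch.cuda.amp on the older torch versions this branch targets.
    return "cuda" if prefix == "cpu" else prefix


# e.g. "cuda:1" -> torch.cuda.amp.GradScaler(); "cpu" -> torch.cuda.amp.GradScaler()
scaler = getattr(torch, _select_scaler_namespace("cuda:1")).amp.GradScaler()
```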