From 1bab0a17a9139ba95841fee3c716346a2ca05ed7 Mon Sep 17 00:00:00 2001
From: Sean Naren
Date: Tue, 24 Aug 2021 20:18:12 +0100
Subject: [PATCH] Fix torch bfloat import version (#9089)

---
 pytorch_lightning/plugins/precision/native_amp.py | 4 ++--
 pytorch_lightning/utilities/__init__.py           | 1 +
 pytorch_lightning/utilities/imports.py            | 7 ++++++-
 tests/models/test_amp.py                          | 4 ++--
 4 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py
index ae9f261085..4c34bed972 100644
--- a/pytorch_lightning/plugins/precision/native_amp.py
+++ b/pytorch_lightning/plugins/precision/native_amp.py
@@ -19,7 +19,7 @@ from torch.optim import LBFGS, Optimizer
 
 import pytorch_lightning as pl
 from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
-from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, _TORCH_GREATER_EQUAL_1_10, AMPType
+from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, _TORCH_BFLOAT_AVAILABLE, AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 
@@ -46,7 +46,7 @@ class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
 
     def _select_precision_dtype(self, precision: Union[int, str] = 16) -> torch.dtype:
         if precision == "bf16":
-            if not _TORCH_GREATER_EQUAL_1_10:
+            if not _TORCH_BFLOAT_AVAILABLE:
                 raise MisconfigurationException(
                     "To use bfloat16 with native amp you must install torch greater or equal to 1.10."
                 )
diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py
index 747c0be617..56e3e03910 100644
--- a/pytorch_lightning/utilities/__init__.py
+++ b/pytorch_lightning/utilities/__init__.py
@@ -44,6 +44,7 @@ from pytorch_lightning.utilities.imports import (  # noqa: F401
     _OMEGACONF_AVAILABLE,
     _POPTORCH_AVAILABLE,
     _RICH_AVAILABLE,
+    _TORCH_BFLOAT_AVAILABLE,
     _TORCH_GREATER_EQUAL_1_7,
     _TORCH_GREATER_EQUAL_1_8,
     _TORCH_GREATER_EQUAL_1_9,
diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py
index fa6598f884..dd31d23f6c 100644
--- a/pytorch_lightning/utilities/imports.py
+++ b/pytorch_lightning/utilities/imports.py
@@ -87,8 +87,13 @@ _NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cu
 _OMEGACONF_AVAILABLE = _module_available("omegaconf")
 _POPTORCH_AVAILABLE = _module_available("poptorch")
 _RICH_AVAILABLE = _module_available("rich")
+_TORCH_BFLOAT_AVAILABLE = _compare_version(
+    "torch", operator.ge, "1.10.0.dev20210820"
+)  # todo: swap to 1.10.0 once released
 _TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != "none"])
-_TORCH_SHARDED_TENSOR_AVAILABLE = _compare_version("torch", operator.ge, "1.10.0.dev20210809")
+_TORCH_SHARDED_TENSOR_AVAILABLE = _compare_version(
+    "torch", operator.ge, "1.10.0.dev20210809"
+)  # todo: swap to 1.10.0 once released
 _TORCHTEXT_AVAILABLE = _module_available("torchtext")
 _TORCHVISION_AVAILABLE = _module_available("torchvision")
 _TORCHMETRICS_LOWER_THAN_0_3 = _compare_version("torchmetrics", operator.lt, "0.3.0")
diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py
index 79c0cf7c12..173d1f8a0d 100644
--- a/tests/models/test_amp.py
+++ b/tests/models/test_amp.py
@@ -22,7 +22,7 @@ from torch.utils.data import DataLoader
 import tests.helpers.utils as tutils
 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins.environments import SLURMEnvironment
-from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10
+from pytorch_lightning.utilities import _TORCH_BFLOAT_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel, RandomDataset
 from tests.helpers.runif import RunIf
@@ -71,7 +71,7 @@ class AMPTestModel(BoringModel):
         16,
         pytest.param(
             "bf16",
-            marks=pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_10, reason="torch.bfloat16 not available"),
+            marks=pytest.mark.skipif(not _TORCH_BFLOAT_AVAILABLE, reason="torch.bfloat16 not available"),
         ),
     ],
 )
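
Note on the version gate (not part of the patch): the sketch below illustrates the idea
behind the new _TORCH_BFLOAT_AVAILABLE flag, namely comparing the installed torch version
against the 1.10 dev build that introduced bfloat16 autocast support. It is a minimal,
hypothetical standalone example using packaging.version rather than Lightning's internal
_compare_version helper; the function name torch_at_least is made up for illustration,
and it assumes torch and packaging are installed.

    # Illustrative sketch only -- mimics the effect of
    # _compare_version("torch", operator.ge, "1.10.0.dev20210820").
    import operator

    import torch
    from packaging.version import Version

    def torch_at_least(minimum: str) -> bool:
        # Compare the installed torch version against a minimum version string.
        # Dev/nightly builds such as "1.10.0.dev20210820" compare lower than "1.10.0".
        return operator.ge(Version(torch.__version__), Version(minimum))

    # bf16 native amp needs a torch 1.10 build (nightlies from 2021-08-20 count).
    TORCH_BFLOAT_AVAILABLE = torch_at_least("1.10.0.dev20210820")

    if __name__ == "__main__":
        print(f"torch {torch.__version__}: bf16 autocast available = {TORCH_BFLOAT_AVAILABLE}")

Using a dated dev version string, as the patch does, lets nightly users opt in before the
1.10.0 release, at the cost of the "swap to 1.10.0 once released" follow-up noted in the todo.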