Fix torch bfloat import version (#9089)

This commit is contained in:
Sean Naren 2021-08-24 20:18:12 +01:00 committed by GitHub
parent f959b13ab9
commit 1bab0a17a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 11 additions and 5 deletions

View File

@ -19,7 +19,7 @@ from torch.optim import LBFGS, Optimizer
import pytorch_lightning as pl import pytorch_lightning as pl
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, _TORCH_GREATER_EQUAL_1_10, AMPType from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, _TORCH_BFLOAT_AVAILABLE, AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.exceptions import MisconfigurationException
@ -46,7 +46,7 @@ class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
def _select_precision_dtype(self, precision: Union[int, str] = 16) -> torch.dtype: def _select_precision_dtype(self, precision: Union[int, str] = 16) -> torch.dtype:
if precision == "bf16": if precision == "bf16":
if not _TORCH_GREATER_EQUAL_1_10: if not _TORCH_BFLOAT_AVAILABLE:
raise MisconfigurationException( raise MisconfigurationException(
"To use bfloat16 with native amp you must install torch greater or equal to 1.10." "To use bfloat16 with native amp you must install torch greater or equal to 1.10."
) )

View File

@ -44,6 +44,7 @@ from pytorch_lightning.utilities.imports import ( # noqa: F401
_OMEGACONF_AVAILABLE, _OMEGACONF_AVAILABLE,
_POPTORCH_AVAILABLE, _POPTORCH_AVAILABLE,
_RICH_AVAILABLE, _RICH_AVAILABLE,
_TORCH_BFLOAT_AVAILABLE,
_TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_7,
_TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_8,
_TORCH_GREATER_EQUAL_1_9, _TORCH_GREATER_EQUAL_1_9,

View File

@ -87,8 +87,13 @@ _NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cu
_OMEGACONF_AVAILABLE = _module_available("omegaconf") _OMEGACONF_AVAILABLE = _module_available("omegaconf")
_POPTORCH_AVAILABLE = _module_available("poptorch") _POPTORCH_AVAILABLE = _module_available("poptorch")
_RICH_AVAILABLE = _module_available("rich") _RICH_AVAILABLE = _module_available("rich")
_TORCH_BFLOAT_AVAILABLE = _compare_version(
"torch", operator.ge, "1.10.0.dev20210820"
) # todo: swap to 1.10.0 once released
_TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != "none"]) _TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != "none"])
_TORCH_SHARDED_TENSOR_AVAILABLE = _compare_version("torch", operator.ge, "1.10.0.dev20210809") _TORCH_SHARDED_TENSOR_AVAILABLE = _compare_version(
"torch", operator.ge, "1.10.0.dev20210809"
) # todo: swap to 1.10.0 once released
_TORCHTEXT_AVAILABLE = _module_available("torchtext") _TORCHTEXT_AVAILABLE = _module_available("torchtext")
_TORCHVISION_AVAILABLE = _module_available("torchvision") _TORCHVISION_AVAILABLE = _module_available("torchvision")
_TORCHMETRICS_LOWER_THAN_0_3 = _compare_version("torchmetrics", operator.lt, "0.3.0") _TORCHMETRICS_LOWER_THAN_0_3 = _compare_version("torchmetrics", operator.lt, "0.3.0")

View File

@ -22,7 +22,7 @@ from torch.utils.data import DataLoader
import tests.helpers.utils as tutils import tests.helpers.utils as tutils
from pytorch_lightning import Trainer from pytorch_lightning import Trainer
from pytorch_lightning.plugins.environments import SLURMEnvironment from pytorch_lightning.plugins.environments import SLURMEnvironment
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10 from pytorch_lightning.utilities import _TORCH_BFLOAT_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel, RandomDataset from tests.helpers import BoringModel, RandomDataset
from tests.helpers.runif import RunIf from tests.helpers.runif import RunIf
@ -71,7 +71,7 @@ class AMPTestModel(BoringModel):
16, 16,
pytest.param( pytest.param(
"bf16", "bf16",
marks=pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_10, reason="torch.bfloat16 not available"), marks=pytest.mark.skipif(not _TORCH_BFLOAT_AVAILABLE, reason="torch.bfloat16 not available"),
), ),
], ],
) )