Fix torch bfloat import version (#9089)
parent f959b13ab9
commit 1bab0a17a9
@@ -19,7 +19,7 @@ from torch.optim import LBFGS, Optimizer
 import pytorch_lightning as pl
 from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
-from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, _TORCH_GREATER_EQUAL_1_10, AMPType
+from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, _TORCH_BFLOAT_AVAILABLE, AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -46,7 +46,7 @@ class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
     def _select_precision_dtype(self, precision: Union[int, str] = 16) -> torch.dtype:
         if precision == "bf16":
-            if not _TORCH_GREATER_EQUAL_1_10:
+            if not _TORCH_BFLOAT_AVAILABLE:
                 raise MisconfigurationException(
                     "To use bfloat16 with native amp you must install torch greater or equal to 1.10."
                 )
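For context, the guarded branch above maps the user-facing precision flag to a torch dtype. The stand-alone sketch below approximates that logic outside the plugin; the flag definition, the torch.float16 fallback, and the plain RuntimeError are assumptions for illustration, not the plugin's actual code.

from typing import Union

import torch
from packaging.version import Version

# Assumed equivalent of Lightning's flag: bf16 native amp needs torch >= 1.10
# (the commit checks against the 1.10 nightly, "1.10.0.dev20210820").
_TORCH_BFLOAT_AVAILABLE = Version(torch.__version__) >= Version("1.10.0.dev20210820")


def select_precision_dtype(precision: Union[int, str] = 16) -> torch.dtype:
    # Mirrors the guarded branch in the hunk; a plain RuntimeError stands in
    # for Lightning's MisconfigurationException.
    if precision == "bf16":
        if not _TORCH_BFLOAT_AVAILABLE:
            raise RuntimeError(
                "To use bfloat16 with native amp you must install torch greater or equal to 1.10."
            )
        return torch.bfloat16
    return torch.float16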
@@ -44,6 +44,7 @@ from pytorch_lightning.utilities.imports import ( # noqa: F401
     _OMEGACONF_AVAILABLE,
     _POPTORCH_AVAILABLE,
     _RICH_AVAILABLE,
+    _TORCH_BFLOAT_AVAILABLE,
     _TORCH_GREATER_EQUAL_1_7,
     _TORCH_GREATER_EQUAL_1_8,
     _TORCH_GREATER_EQUAL_1_9,
@@ -87,8 +87,13 @@ _NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cu
 _OMEGACONF_AVAILABLE = _module_available("omegaconf")
 _POPTORCH_AVAILABLE = _module_available("poptorch")
 _RICH_AVAILABLE = _module_available("rich")
+_TORCH_BFLOAT_AVAILABLE = _compare_version(
+    "torch", operator.ge, "1.10.0.dev20210820"
+) # todo: swap to 1.10.0 once released
 _TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != "none"])
-_TORCH_SHARDED_TENSOR_AVAILABLE = _compare_version("torch", operator.ge, "1.10.0.dev20210809")
+_TORCH_SHARDED_TENSOR_AVAILABLE = _compare_version(
+    "torch", operator.ge, "1.10.0.dev20210809"
+) # todo: swap to 1.10.0 once released
 _TORCHTEXT_AVAILABLE = _module_available("torchtext")
 _TORCHVISION_AVAILABLE = _module_available("torchvision")
 _TORCHMETRICS_LOWER_THAN_0_3 = _compare_version("torchmetrics", operator.lt, "0.3.0")
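The new flag reuses the module's existing _compare_version helper (defined earlier in the same file, not shown in this hunk). Under PEP 440 ordering, a dated .dev pre-release sorts below the final release, so comparing against "1.10.0.dev20210820" turns the flag on for 1.10 nightlies as well as the eventual 1.10.0 release, which is why the inline todo asks to swap to "1.10.0" once released. A small illustration of that ordering with the packaging library; this is an assumption about how the helper compares versions, not its actual implementation.

import operator

from packaging.version import Version

# Dated dev pre-releases sort below the final release ...
assert Version("1.10.0.dev20210820") < Version("1.10.0")
# ... so a stable 1.9 build fails the check, while 1.10 nightlies and 1.10.0 pass.
assert not operator.ge(Version("1.9.0"), Version("1.10.0.dev20210820"))
assert operator.ge(Version("1.10.0.dev20210825"), Version("1.10.0.dev20210820"))
assert operator.ge(Version("1.10.0"), Version("1.10.0.dev20210820"))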
@@ -22,7 +22,7 @@ from torch.utils.data import DataLoader
 import tests.helpers.utils as tutils
 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins.environments import SLURMEnvironment
-from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10
+from pytorch_lightning.utilities import _TORCH_BFLOAT_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel, RandomDataset
 from tests.helpers.runif import RunIf
@@ -71,7 +71,7 @@ class AMPTestModel(BoringModel):
         16,
         pytest.param(
             "bf16",
-            marks=pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_10, reason="torch.bfloat16 not available"),
+            marks=pytest.mark.skipif(not _TORCH_BFLOAT_AVAILABLE, reason="torch.bfloat16 not available"),
         ),
     ],
 )
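In the test file, the same flag drives the skip mark on the bf16 parametrization so the case is skipped on torch builds without bfloat16 autocast support. A minimal, self-contained sketch of that pattern; the test name and body here are hypothetical stand-ins, not the commit's AMP test.

import pytest
import torch
from packaging.version import Version

# Assumed equivalent of the flag imported by the test module.
_TORCH_BFLOAT_AVAILABLE = Version(torch.__version__) >= Version("1.10.0.dev20210820")


@pytest.mark.parametrize(
    "precision",
    [
        16,
        pytest.param(
            "bf16",
            marks=pytest.mark.skipif(not _TORCH_BFLOAT_AVAILABLE, reason="torch.bfloat16 not available"),
        ),
    ],
)
def test_amp_precision_flag(precision):
    # Placeholder check; the real test runs a Trainer with this precision setting.
    assert precision in (16, "bf16")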