diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index 86d380ac24..3a704fc5f1 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -37,7 +37,7 @@ from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import GradClipAlgorithmType
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.distributed import log, rank_zero_info
-from pytorch_lightning.utilities.enums import _StrategyType, AMPType
+from pytorch_lightning.utilities.enums import _StrategyType, AMPType, PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
 from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -445,7 +445,11 @@ class DeepSpeedPlugin(DDPPlugin):
 
         if self.zero_stage_3 and self.partition_module:
             # Ensure the entire model has been moved to the appropriate device
-            dtype = torch.float16 if self.precision_plugin.precision in (16, "mixed") else torch.float32
+            dtype = (
+                torch.float16
+                if self.precision_plugin.precision in (PrecisionType.HALF, PrecisionType.MIXED)
+                else torch.float32
+            )
             deepspeed.zero.Init(
                 module=model, remote_device=self.remote_device, pin_memory=True, config=self.config, dtype=dtype
             )
@@ -502,7 +506,11 @@ class DeepSpeedPlugin(DDPPlugin):
     def model_sharded_context(self) -> Generator[None, None, None]:
         if self.zero_stage_3:
             assert self._config_initialized
-            dtype = torch.float16 if self.precision_plugin.precision in (16, "mixed") else torch.float32
+            dtype = (
+                torch.float16
+                if self.precision_plugin.precision in (PrecisionType.HALF, PrecisionType.MIXED)
+                else torch.float32
+            )
             model_parallel_context = deepspeed.zero.Init(
                 remote_device=self.remote_device, pin_memory=True, config=self.config, dtype=dtype
             )
@@ -629,7 +637,7 @@ class DeepSpeedPlugin(DDPPlugin):
         return batch_size
 
     def _format_precision_config(self) -> None:
-        if self.precision_plugin.precision in (16, "mixed"):
+        if self.precision_plugin.precision in (PrecisionType.HALF, PrecisionType.MIXED):
             if "fp16" not in self.config and self.precision_plugin.amp_type == AMPType.NATIVE:
                 # FP16 is a DeepSpeed standalone AMP implementation
                 rank_zero_info("Enabling DeepSpeed FP16.")
diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py
index 73ea87b058..38fa2942a7 100644
--- a/pytorch_lightning/plugins/training_type/fully_sharded.py
+++ b/pytorch_lightning/plugins/training_type/fully_sharded.py
@@ -21,7 +21,7 @@ from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
 from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE
-from pytorch_lightning.utilities.enums import _StrategyType
+from pytorch_lightning.utilities.enums import _StrategyType, PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 if _FAIRSCALE_FULLY_SHARDED_AVAILABLE:
@@ -139,7 +139,7 @@ class DDPFullyShardedPlugin(DDPPlugin):
             cpu_offload=self.cpu_offload,
             move_grads_to_cpu=self.move_grads_to_cpu,
             flatten_parameters=self.flatten_parameters,
-            mixed_precision=precision == "mixed",
+            mixed_precision=(precision == PrecisionType.MIXED),
             reshard_after_forward=self.reshard_after_forward,
             fp32_reduce_scatter=self.fp32_reduce_scatter,
             compute_dtype=self.compute_dtype,
diff --git a/pytorch_lightning/plugins/training_type/ipu.py b/pytorch_lightning/plugins/training_type/ipu.py
index ef9b3d1f02..8f8f082280 100644
--- a/pytorch_lightning/plugins/training_type/ipu.py
+++ b/pytorch_lightning/plugins/training_type/ipu.py
@@ -29,6 +29,7 @@ from pytorch_lightning.utilities import _IPU_AVAILABLE, _POPTORCH_AVAILABLE
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.data import _get_dataloader_init_kwargs
+from pytorch_lightning.utilities.enums import PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 if _POPTORCH_AVAILABLE:
@@ -41,7 +42,7 @@ class LightningIPUModule(_LightningModuleWrapperBase):
         self.precision = precision
 
     def forward(self, *inputs: Any, **kwargs: Any) -> Any:
-        if self.precision in ("mixed", 16):
+        if self.precision in (PrecisionType.MIXED, PrecisionType.HALF):
             inputs = self._move_float_tensors_to_half(inputs)
         return super().forward(*inputs, **kwargs)
 
diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py
index c9627324eb..e7f57e9c92 100644
--- a/pytorch_lightning/plugins/training_type/sharded.py
+++ b/pytorch_lightning/plugins/training_type/sharded.py
@@ -23,7 +23,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, rank_zero_only
-from pytorch_lightning.utilities.enums import _StrategyType
+from pytorch_lightning.utilities.enums import _StrategyType, PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 if _FAIRSCALE_AVAILABLE:
@@ -71,7 +71,7 @@ class DDPShardedPlugin(DDPPlugin):
                 optim_class = type(optimizer)
                 zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
                 if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE:
-                    is_fp16 = self.precision_plugin.precision in ("mixed", 16)
+                    is_fp16 = self.precision_plugin.precision in (PrecisionType.MIXED, PrecisionType.HALF)
                     # For multi-node training, compressing the model shards in fp16 before broadcasting
                     # improves performance. When using PyTorch AMP, it will not degrade
                     # the model performance.