Use `PrecisionType` enum instead of checking raw values (#10704)
* use precision type * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
b28ab34ff5
commit
89d0064b33
|
@ -37,7 +37,7 @@ from pytorch_lightning.trainer.states import TrainerFn
|
|||
from pytorch_lightning.utilities import GradClipAlgorithmType
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
||||
from pytorch_lightning.utilities.distributed import log, rank_zero_info
|
||||
from pytorch_lightning.utilities.enums import _StrategyType, AMPType
|
||||
from pytorch_lightning.utilities.enums import _StrategyType, AMPType, PrecisionType
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
|
||||
from pytorch_lightning.utilities.model_helpers import is_overridden
|
||||
|
@ -445,7 +445,11 @@ class DeepSpeedPlugin(DDPPlugin):
|
|||
|
||||
if self.zero_stage_3 and self.partition_module:
|
||||
# Ensure the entire model has been moved to the appropriate device
|
||||
dtype = torch.float16 if self.precision_plugin.precision in (16, "mixed") else torch.float32
|
||||
dtype = (
|
||||
torch.float16
|
||||
if self.precision_plugin.precision in (PrecisionType.HALF, PrecisionType.MIXED)
|
||||
else torch.float32
|
||||
)
|
||||
deepspeed.zero.Init(
|
||||
module=model, remote_device=self.remote_device, pin_memory=True, config=self.config, dtype=dtype
|
||||
)
|
||||
|
@ -502,7 +506,11 @@ class DeepSpeedPlugin(DDPPlugin):
|
|||
def model_sharded_context(self) -> Generator[None, None, None]:
|
||||
if self.zero_stage_3:
|
||||
assert self._config_initialized
|
||||
dtype = torch.float16 if self.precision_plugin.precision in (16, "mixed") else torch.float32
|
||||
dtype = (
|
||||
torch.float16
|
||||
if self.precision_plugin.precision in (PrecisionType.HALF, PrecisionType.MIXED)
|
||||
else torch.float32
|
||||
)
|
||||
model_parallel_context = deepspeed.zero.Init(
|
||||
remote_device=self.remote_device, pin_memory=True, config=self.config, dtype=dtype
|
||||
)
|
||||
|
@ -629,7 +637,7 @@ class DeepSpeedPlugin(DDPPlugin):
|
|||
return batch_size
|
||||
|
||||
def _format_precision_config(self) -> None:
|
||||
if self.precision_plugin.precision in (16, "mixed"):
|
||||
if self.precision_plugin.precision in (PrecisionType.HALF, PrecisionType.MIXED):
|
||||
if "fp16" not in self.config and self.precision_plugin.amp_type == AMPType.NATIVE:
|
||||
# FP16 is a DeepSpeed standalone AMP implementation
|
||||
rank_zero_info("Enabling DeepSpeed FP16.")
|
||||
|
|
|
@ -21,7 +21,7 @@ from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
|
|||
from pytorch_lightning.plugins.precision import PrecisionPlugin
|
||||
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
|
||||
from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE
|
||||
from pytorch_lightning.utilities.enums import _StrategyType
|
||||
from pytorch_lightning.utilities.enums import _StrategyType, PrecisionType
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
if _FAIRSCALE_FULLY_SHARDED_AVAILABLE:
|
||||
|
@ -139,7 +139,7 @@ class DDPFullyShardedPlugin(DDPPlugin):
|
|||
cpu_offload=self.cpu_offload,
|
||||
move_grads_to_cpu=self.move_grads_to_cpu,
|
||||
flatten_parameters=self.flatten_parameters,
|
||||
mixed_precision=precision == "mixed",
|
||||
mixed_precision=(precision == PrecisionType.MIXED),
|
||||
reshard_after_forward=self.reshard_after_forward,
|
||||
fp32_reduce_scatter=self.fp32_reduce_scatter,
|
||||
compute_dtype=self.compute_dtype,
|
||||
|
|
|
@ -29,6 +29,7 @@ from pytorch_lightning.utilities import _IPU_AVAILABLE, _POPTORCH_AVAILABLE
|
|||
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
||||
from pytorch_lightning.utilities.cloud_io import get_filesystem
|
||||
from pytorch_lightning.utilities.data import _get_dataloader_init_kwargs
|
||||
from pytorch_lightning.utilities.enums import PrecisionType
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
if _POPTORCH_AVAILABLE:
|
||||
|
@ -41,7 +42,7 @@ class LightningIPUModule(_LightningModuleWrapperBase):
|
|||
self.precision = precision
|
||||
|
||||
def forward(self, *inputs: Any, **kwargs: Any) -> Any:
|
||||
if self.precision in ("mixed", 16):
|
||||
if self.precision in (PrecisionType.MIXED, PrecisionType.HALF):
|
||||
inputs = self._move_float_tensors_to_half(inputs)
|
||||
|
||||
return super().forward(*inputs, **kwargs)
|
||||
|
|
|
@ -23,7 +23,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer
|
|||
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
|
||||
from pytorch_lightning.trainer.states import TrainerFn
|
||||
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, rank_zero_only
|
||||
from pytorch_lightning.utilities.enums import _StrategyType
|
||||
from pytorch_lightning.utilities.enums import _StrategyType, PrecisionType
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
if _FAIRSCALE_AVAILABLE:
|
||||
|
@ -71,7 +71,7 @@ class DDPShardedPlugin(DDPPlugin):
|
|||
optim_class = type(optimizer)
|
||||
zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
|
||||
if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE:
|
||||
is_fp16 = self.precision_plugin.precision in ("mixed", 16)
|
||||
is_fp16 = self.precision_plugin.precision in (PrecisionType.MIXED, PrecisionType.HALF)
|
||||
# For multi-node training, compressing the model shards in fp16 before broadcasting
|
||||
# improves performance. When using PyTorch AMP, it will not degrade
|
||||
# the model performance.
|
||||
|
|
Loading…
Reference in New Issue