Use `PrecisionType` enum instead of checking raw values (#10704)

* use precision type

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Adrian Wälchli 2021-11-23 18:23:36 +01:00 committed by GitHub
parent b28ab34ff5
commit 89d0064b33
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 18 additions and 9 deletions

View File

@ -37,7 +37,7 @@ from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.distributed import log, rank_zero_info
from pytorch_lightning.utilities.enums import _StrategyType, AMPType
from pytorch_lightning.utilities.enums import _StrategyType, AMPType, PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
from pytorch_lightning.utilities.model_helpers import is_overridden
@ -445,7 +445,11 @@ class DeepSpeedPlugin(DDPPlugin):
if self.zero_stage_3 and self.partition_module:
# Ensure the entire model has been moved to the appropriate device
dtype = torch.float16 if self.precision_plugin.precision in (16, "mixed") else torch.float32
dtype = (
torch.float16
if self.precision_plugin.precision in (PrecisionType.HALF, PrecisionType.MIXED)
else torch.float32
)
deepspeed.zero.Init(
module=model, remote_device=self.remote_device, pin_memory=True, config=self.config, dtype=dtype
)
@ -502,7 +506,11 @@ class DeepSpeedPlugin(DDPPlugin):
def model_sharded_context(self) -> Generator[None, None, None]:
if self.zero_stage_3:
assert self._config_initialized
dtype = torch.float16 if self.precision_plugin.precision in (16, "mixed") else torch.float32
dtype = (
torch.float16
if self.precision_plugin.precision in (PrecisionType.HALF, PrecisionType.MIXED)
else torch.float32
)
model_parallel_context = deepspeed.zero.Init(
remote_device=self.remote_device, pin_memory=True, config=self.config, dtype=dtype
)
@ -629,7 +637,7 @@ class DeepSpeedPlugin(DDPPlugin):
return batch_size
def _format_precision_config(self) -> None:
if self.precision_plugin.precision in (16, "mixed"):
if self.precision_plugin.precision in (PrecisionType.HALF, PrecisionType.MIXED):
if "fp16" not in self.config and self.precision_plugin.amp_type == AMPType.NATIVE:
# FP16 is a DeepSpeed standalone AMP implementation
rank_zero_info("Enabling DeepSpeed FP16.")

View File

@ -21,7 +21,7 @@ from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE
from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.enums import _StrategyType, PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
if _FAIRSCALE_FULLY_SHARDED_AVAILABLE:
@ -139,7 +139,7 @@ class DDPFullyShardedPlugin(DDPPlugin):
cpu_offload=self.cpu_offload,
move_grads_to_cpu=self.move_grads_to_cpu,
flatten_parameters=self.flatten_parameters,
mixed_precision=precision == "mixed",
mixed_precision=(precision == PrecisionType.MIXED),
reshard_after_forward=self.reshard_after_forward,
fp32_reduce_scatter=self.fp32_reduce_scatter,
compute_dtype=self.compute_dtype,

View File

@ -29,6 +29,7 @@ from pytorch_lightning.utilities import _IPU_AVAILABLE, _POPTORCH_AVAILABLE
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.data import _get_dataloader_init_kwargs
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
if _POPTORCH_AVAILABLE:
@ -41,7 +42,7 @@ class LightningIPUModule(_LightningModuleWrapperBase):
self.precision = precision
def forward(self, *inputs: Any, **kwargs: Any) -> Any:
if self.precision in ("mixed", 16):
if self.precision in (PrecisionType.MIXED, PrecisionType.HALF):
inputs = self._move_float_tensors_to_half(inputs)
return super().forward(*inputs, **kwargs)

View File

@ -23,7 +23,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.enums import _StrategyType, PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
if _FAIRSCALE_AVAILABLE:
@ -71,7 +71,7 @@ class DDPShardedPlugin(DDPPlugin):
optim_class = type(optimizer)
zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE:
is_fp16 = self.precision_plugin.precision in ("mixed", 16)
is_fp16 = self.precision_plugin.precision in (PrecisionType.MIXED, PrecisionType.HALF)
# For multi-node training, compressing the model shards in fp16 before broadcasting
# improves performance. When using PyTorch AMP, it will not degrade
# the model performance.