Use `PrecisionType` enum instead of checking raw values (#10704)

* use precision type

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Adrian Wälchli 2021-11-23 18:23:36 +01:00 committed by GitHub
parent b28ab34ff5
commit 89d0064b33
4 changed files with 18 additions and 9 deletions
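
The refactor relies on the enum members comparing equal to the raw values they replace, since the plugins' `precision` attribute can still arrive as the integer 16 or the string "mixed". A minimal sanity check of that equivalence, using the import this commit adds (behaviour of the enum at this commit; worth re-verifying against your installed version):

    from pytorch_lightning.utilities.enums import PrecisionType

    # The old raw-value checks and the new enum checks accept the same inputs ...
    assert 16 == PrecisionType.HALF and "mixed" == PrecisionType.MIXED
    assert 16 in (PrecisionType.HALF, PrecisionType.MIXED)
    assert "mixed" in (PrecisionType.HALF, PrecisionType.MIXED)
    assert 32 not in (PrecisionType.HALF, PrecisionType.MIXED)

    # ... and the enum members additionally match the string form "16",
    # which the old tuple (16, "mixed") did not.
    assert "16" not in (16, "mixed")
    assert "16" in (PrecisionType.HALF, PrecisionType.MIXED)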


@@ -37,7 +37,7 @@ from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import GradClipAlgorithmType
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.distributed import log, rank_zero_info
-from pytorch_lightning.utilities.enums import _StrategyType, AMPType
+from pytorch_lightning.utilities.enums import _StrategyType, AMPType, PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
 from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -445,7 +445,11 @@ class DeepSpeedPlugin(DDPPlugin):
 
         if self.zero_stage_3 and self.partition_module:
             # Ensure the entire model has been moved to the appropriate device
-            dtype = torch.float16 if self.precision_plugin.precision in (16, "mixed") else torch.float32
+            dtype = (
+                torch.float16
+                if self.precision_plugin.precision in (PrecisionType.HALF, PrecisionType.MIXED)
+                else torch.float32
+            )
             deepspeed.zero.Init(
                 module=model, remote_device=self.remote_device, pin_memory=True, config=self.config, dtype=dtype
             )
@@ -502,7 +506,11 @@ class DeepSpeedPlugin(DDPPlugin):
     def model_sharded_context(self) -> Generator[None, None, None]:
         if self.zero_stage_3:
             assert self._config_initialized
-            dtype = torch.float16 if self.precision_plugin.precision in (16, "mixed") else torch.float32
+            dtype = (
+                torch.float16
+                if self.precision_plugin.precision in (PrecisionType.HALF, PrecisionType.MIXED)
+                else torch.float32
+            )
             model_parallel_context = deepspeed.zero.Init(
                 remote_device=self.remote_device, pin_memory=True, config=self.config, dtype=dtype
             )
@@ -629,7 +637,7 @@ class DeepSpeedPlugin(DDPPlugin):
         return batch_size
 
     def _format_precision_config(self) -> None:
-        if self.precision_plugin.precision in (16, "mixed"):
+        if self.precision_plugin.precision in (PrecisionType.HALF, PrecisionType.MIXED):
             if "fp16" not in self.config and self.precision_plugin.amp_type == AMPType.NATIVE:
                 # FP16 is a DeepSpeed standalone AMP implementation
                 rank_zero_info("Enabling DeepSpeed FP16.")


@@ -21,7 +21,7 @@ from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
 from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE
-from pytorch_lightning.utilities.enums import _StrategyType
+from pytorch_lightning.utilities.enums import _StrategyType, PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 if _FAIRSCALE_FULLY_SHARDED_AVAILABLE:
@@ -139,7 +139,7 @@ class DDPFullyShardedPlugin(DDPPlugin):
             cpu_offload=self.cpu_offload,
             move_grads_to_cpu=self.move_grads_to_cpu,
             flatten_parameters=self.flatten_parameters,
-            mixed_precision=precision == "mixed",
+            mixed_precision=(precision == PrecisionType.MIXED),
             reshard_after_forward=self.reshard_after_forward,
             fp32_reduce_scatter=self.fp32_reduce_scatter,
             compute_dtype=self.compute_dtype,
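
Unlike the membership checks elsewhere in this commit, this hunk is a plain equality: the fairscale FullyShardedDataParallel wrapper configured here expects mixed_precision as a plain bool, and comparing against the enum member still produces one for both string and integer precision values. Illustrative values only:

    from pytorch_lightning.utilities.enums import PrecisionType

    # Comparing against the enum member yields a plain bool either way.
    for precision, expected in (("mixed", True), (16, False), (32, False)):
        assert (precision == PrecisionType.MIXED) == expected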


@@ -29,6 +29,7 @@ from pytorch_lightning.utilities import _IPU_AVAILABLE, _POPTORCH_AVAILABLE
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.data import _get_dataloader_init_kwargs
+from pytorch_lightning.utilities.enums import PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 if _POPTORCH_AVAILABLE:
@@ -41,7 +42,7 @@ class LightningIPUModule(_LightningModuleWrapperBase):
         self.precision = precision
 
     def forward(self, *inputs: Any, **kwargs: Any) -> Any:
-        if self.precision in ("mixed", 16):
+        if self.precision in (PrecisionType.MIXED, PrecisionType.HALF):
             inputs = self._move_float_tensors_to_half(inputs)
         return super().forward(*inputs, **kwargs)
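
When the updated check is true, _move_float_tensors_to_half casts the floating-point tensors in the inputs to half precision. A sketch of that kind of cast, using the apply_to_collection utility already imported in this file (this is not the module's actual helper):

    import torch

    from pytorch_lightning.utilities.apply_func import apply_to_collection

    batch = {"images": torch.ones(2, 3), "lengths": torch.tensor([3, 2])}
    halved = apply_to_collection(
        batch, torch.Tensor, lambda t: t.half() if t.is_floating_point() else t
    )
    assert halved["images"].dtype == torch.float16  # float inputs are cast
    assert halved["lengths"].dtype == torch.int64   # integer tensors pass through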


@@ -23,7 +23,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, rank_zero_only
-from pytorch_lightning.utilities.enums import _StrategyType
+from pytorch_lightning.utilities.enums import _StrategyType, PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 if _FAIRSCALE_AVAILABLE:
@@ -71,7 +71,7 @@ class DDPShardedPlugin(DDPPlugin):
                 optim_class = type(optimizer)
                 zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
                 if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE:
-                    is_fp16 = self.precision_plugin.precision in ("mixed", 16)
+                    is_fp16 = self.precision_plugin.precision in (PrecisionType.MIXED, PrecisionType.HALF)
                     # For multi-node training, compressing the model shards in fp16 before broadcasting
                     # improves performance. When using PyTorch AMP, it will not degrade
                     # the model performance.
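
Per the comment in the hunk, this flag is meant to gate fp16 compression of OSS shard broadcasts for multi-node runs. A hypothetical helper (not Lightning API) sketching that intent with the refactored check:

    from pytorch_lightning.utilities.enums import PrecisionType


    def should_broadcast_fp16(precision, num_nodes: int) -> bool:
        # Hypothetical: compress shard broadcasts to fp16 only when training
        # in half/mixed precision across more than one node.
        is_fp16 = precision in (PrecisionType.MIXED, PrecisionType.HALF)
        return is_fp16 and num_nodes > 1


    assert should_broadcast_fp16("mixed", num_nodes=2)
    assert not should_broadcast_fp16(32, num_nodes=2)
    assert not should_broadcast_fp16(16, num_nodes=1)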