Use `PrecisionType` enum instead of checking raw values (#10704)
* use precision type

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent b28ab34ff5
commit 89d0064b33
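The hunks below swap literal checks against 16 and "mixed" for members of the `PrecisionType` enum imported from `pytorch_lightning.utilities.enums`. The sketch that follows is only an illustration of the pattern the change relies on, not Lightning's actual enum implementation: a hypothetical str-backed `Precision` enum whose members compare equal to the raw values users pass, so the membership tests keep accepting 16 and "mixed".

from enum import Enum


class Precision(str, Enum):
    # Hypothetical stand-in for PrecisionType; the member values mirror the raw settings.
    HALF = "16"
    FLOAT = "32"
    MIXED = "mixed"

    def __eq__(self, other: object) -> bool:
        # Compare against the other member's value, or the stringified raw value,
        # so both 16 and "16" match HALF.
        other = other.value if isinstance(other, Enum) else str(other)
        return self.value.lower() == other.lower()

    def __hash__(self) -> int:
        # Defining __eq__ disables the inherited hash, so restore one explicitly.
        return hash(self.value)


# Raw settings keep satisfying the enum-based membership checks used in this diff.
assert 16 in (Precision.HALF, Precision.MIXED)
assert "mixed" in (Precision.HALF, Precision.MIXED)
assert 32 not in (Precision.HALF, Precision.MIXED)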
@@ -37,7 +37,7 @@ from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import GradClipAlgorithmType
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.distributed import log, rank_zero_info
-from pytorch_lightning.utilities.enums import _StrategyType, AMPType
+from pytorch_lightning.utilities.enums import _StrategyType, AMPType, PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
 from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -445,7 +445,11 @@ class DeepSpeedPlugin(DDPPlugin):
 
         if self.zero_stage_3 and self.partition_module:
             # Ensure the entire model has been moved to the appropriate device
-            dtype = torch.float16 if self.precision_plugin.precision in (16, "mixed") else torch.float32
+            dtype = (
+                torch.float16
+                if self.precision_plugin.precision in (PrecisionType.HALF, PrecisionType.MIXED)
+                else torch.float32
+            )
             deepspeed.zero.Init(
                 module=model, remote_device=self.remote_device, pin_memory=True, config=self.config, dtype=dtype
             )
@@ -502,7 +506,11 @@ class DeepSpeedPlugin(DDPPlugin):
     def model_sharded_context(self) -> Generator[None, None, None]:
         if self.zero_stage_3:
             assert self._config_initialized
-            dtype = torch.float16 if self.precision_plugin.precision in (16, "mixed") else torch.float32
+            dtype = (
+                torch.float16
+                if self.precision_plugin.precision in (PrecisionType.HALF, PrecisionType.MIXED)
+                else torch.float32
+            )
             model_parallel_context = deepspeed.zero.Init(
                 remote_device=self.remote_device, pin_memory=True, config=self.config, dtype=dtype
             )
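Both DeepSpeed hunks above choose the parameter dtype handed to `deepspeed.zero.Init` from the configured precision. A minimal sketch of that selection logic, pulled out into a hypothetical `zero_init_dtype` helper (the helper name and the standalone form are illustrative; the diff keeps the expression inline):

from typing import Union

import torch

from pytorch_lightning.utilities.enums import PrecisionType


def zero_init_dtype(precision: Union[str, int]) -> torch.dtype:
    # 16-bit and "mixed" settings shard the parameters in half precision;
    # everything else falls back to float32, mirroring the expression above.
    if precision in (PrecisionType.HALF, PrecisionType.MIXED):
        return torch.float16
    return torch.float32


assert zero_init_dtype(16) is torch.float16
assert zero_init_dtype("mixed") is torch.float16
assert zero_init_dtype(32) is torch.float32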
@@ -629,7 +637,7 @@ class DeepSpeedPlugin(DDPPlugin):
         return batch_size
 
     def _format_precision_config(self) -> None:
-        if self.precision_plugin.precision in (16, "mixed"):
+        if self.precision_plugin.precision in (PrecisionType.HALF, PrecisionType.MIXED):
             if "fp16" not in self.config and self.precision_plugin.amp_type == AMPType.NATIVE:
                 # FP16 is a DeepSpeed standalone AMP implementation
                 rank_zero_info("Enabling DeepSpeed FP16.")
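The `_format_precision_config` hunk above only enables DeepSpeed's own FP16 support when native AMP is selected and the user has not already supplied an "fp16" section. As a hedged illustration, a config that would skip that branch could look like this (keys follow DeepSpeed's public fp16 schema; the exact defaults Lightning writes are not shown in this diff):

# A user-supplied DeepSpeed config that already contains an "fp16" section.
# Because "fp16" is present, the branch above does not enable DeepSpeed FP16
# on the user's behalf and the section is used as given.
user_deepspeed_config = {
    "fp16": {
        "enabled": True,            # standard DeepSpeed fp16 switch
        "initial_scale_power": 16,  # illustrative loss-scale setting
    },
}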
@@ -21,7 +21,7 @@ from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
 from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE
-from pytorch_lightning.utilities.enums import _StrategyType
+from pytorch_lightning.utilities.enums import _StrategyType, PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 if _FAIRSCALE_FULLY_SHARDED_AVAILABLE:
@@ -139,7 +139,7 @@ class DDPFullyShardedPlugin(DDPPlugin):
            cpu_offload=self.cpu_offload,
            move_grads_to_cpu=self.move_grads_to_cpu,
            flatten_parameters=self.flatten_parameters,
-           mixed_precision=precision == "mixed",
+           mixed_precision=(precision == PrecisionType.MIXED),
            reshard_after_forward=self.reshard_after_forward,
            fp32_reduce_scatter=self.fp32_reduce_scatter,
            compute_dtype=self.compute_dtype,
@@ -29,6 +29,7 @@ from pytorch_lightning.utilities import _IPU_AVAILABLE, _POPTORCH_AVAILABLE
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.data import _get_dataloader_init_kwargs
+from pytorch_lightning.utilities.enums import PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 if _POPTORCH_AVAILABLE:
@@ -41,7 +42,7 @@ class LightningIPUModule(_LightningModuleWrapperBase):
         self.precision = precision
 
     def forward(self, *inputs: Any, **kwargs: Any) -> Any:
-        if self.precision in ("mixed", 16):
+        if self.precision in (PrecisionType.MIXED, PrecisionType.HALF):
             inputs = self._move_float_tensors_to_half(inputs)
 
         return super().forward(*inputs, **kwargs)
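In `LightningIPUModule.forward` above, half or mixed precision routes the inputs through `self._move_float_tensors_to_half`, which is not part of this diff. A rough, self-contained sketch of what such a conversion can look like, built on the `apply_to_collection` utility the file already imports (the exact implementation may differ):

from typing import Any

import torch

from pytorch_lightning.utilities.apply_func import apply_to_collection


def move_float_tensors_to_half(batch: Any) -> Any:
    """Cast every float32 tensor in a (possibly nested) collection to float16."""
    return apply_to_collection(batch, torch.Tensor, lambda t: t.half() if t.dtype == torch.float32 else t)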
@@ -23,7 +23,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, rank_zero_only
-from pytorch_lightning.utilities.enums import _StrategyType
+from pytorch_lightning.utilities.enums import _StrategyType, PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 if _FAIRSCALE_AVAILABLE:
@@ -71,7 +71,7 @@ class DDPShardedPlugin(DDPPlugin):
                optim_class = type(optimizer)
                zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
                if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE:
-                   is_fp16 = self.precision_plugin.precision in ("mixed", 16)
+                   is_fp16 = self.precision_plugin.precision in (PrecisionType.MIXED, PrecisionType.HALF)
                    # For multi-node training, compressing the model shards in fp16 before broadcasting
                    # improves performance. When using PyTorch AMP, it will not degrade
                    # the model performance.
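As a closing sanity check (not part of the PR, and assuming `PrecisionType` members compare equal to their stringified raw values), the settings exercised above satisfy the new enum-based checks exactly as they satisfied the old literal tuples:

from pytorch_lightning.utilities.enums import PrecisionType

# The raw settings handled above keep satisfying the new enum-based checks,
# just as they satisfied the old literal tuples.
for raw in (16, "mixed"):
    assert raw in (PrecisionType.HALF, PrecisionType.MIXED)  # new check
    assert raw in ("mixed", 16)                              # old check
assert 32 not in (PrecisionType.HALF, PrecisionType.MIXED)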