diff --git a/CHANGELOG.md b/CHANGELOG.md index 9ada4815bc..ae0515cf22 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Raise `MisconfigurationException` when `enable_progress_bar=False` and a progress bar instance has been passed in the callback list ([#10520](https://github.com/PyTorchLightning/pytorch-lightning/issues/10520)) +- Moved ownership of the `PrecisionPlugin` into `TrainingTypePlugin` and updated all references ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570)) + + - @@ -50,7 +53,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `DistributedType` in favor of `_StrategyType` ([#10505](https://github.com/PyTorchLightning/pytorch-lightning/pull/10505)) -- +- Deprecated the `precision_plugin` constructor argument from `Accelerator` ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570)) - @@ -139,6 +142,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed deprecated `reload_dataloaders_every_epoch` from `Trainer` in favour of `reload_dataloaders_every_n_epochs` ([#10481](https://github.com/PyTorchLightning/pytorch-lightning/pull/10481)) +- Removed the `precision_plugin` attribute from `Accelerator` in favor of its equivalent attribute `precision_plugin` in the `TrainingTypePlugin` ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570)) ### Fixed diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 14b6a47c72..eb3886b209 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -25,6 +25,7 @@ import pytorch_lightning as pl from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin from pytorch_lightning.plugins.training_type import DataParallelPlugin, TrainingTypePlugin from pytorch_lightning.trainer.states import TrainerFn +from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.enums import AMPType, LightningEnum from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -44,15 +45,23 @@ class Accelerator: One to handle differences from the training routine and one to handle different precisions. """ - def __init__(self, precision_plugin: PrecisionPlugin, training_type_plugin: TrainingTypePlugin) -> None: + def __init__(self, precision_plugin: Optional[PrecisionPlugin], training_type_plugin: TrainingTypePlugin) -> None: """ Args: precision_plugin: the plugin to handle precision-specific parts + + .. deprecated:: + The ``precision_plugin`` parameter has been deprecated and will be removed soon. + Pass the precision plugin as a parameter to the ``TrainingTypePlugin`` instead. + training_type_plugin: the plugin to handle different training routines """ - self.precision_plugin = precision_plugin + self.training_type_plugin = training_type_plugin + if precision_plugin is not None: + self.training_type_plugin._precision_plugin = precision_plugin + self.optimizers: List = [] self.lr_schedulers: List = [] self.optimizer_frequencies: List = [] @@ -84,7 +93,7 @@ class Accelerator: if self.training_type_plugin.setup_optimizers_in_pre_dispatch: self.setup_optimizers(trainer) - self.precision_plugin.pre_dispatch() + self.training_type_plugin.precision_plugin.pre_dispatch() def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None: """Moves the state of the optimizers to the GPU if needed.""" @@ -96,12 +105,12 @@ class Accelerator: def dispatch(self, trainer: "pl.Trainer") -> None: """Hook to do something before the training/evaluation/prediction starts.""" self.training_type_plugin.dispatch(trainer) - self.precision_plugin.dispatch(trainer) + self.training_type_plugin.precision_plugin.dispatch(trainer) def post_dispatch(self, trainer: "pl.Trainer") -> None: """Hook to do something after the training/evaluation/prediction starts.""" self.training_type_plugin.post_dispatch(trainer) - self.precision_plugin.post_dispatch() + self.training_type_plugin.precision_plugin.post_dispatch() @property def model(self) -> Module: @@ -159,7 +168,7 @@ class Accelerator: See :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` for more details """ - with self.precision_plugin.train_step_context(): + with self.training_type_plugin.precision_plugin.train_step_context(): return self.training_type_plugin.training_step(*step_kwargs.values()) def validation_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]: @@ -167,7 +176,7 @@ class Accelerator: See :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` for more details """ - with self.precision_plugin.val_step_context(): + with self.training_type_plugin.precision_plugin.val_step_context(): return self.training_type_plugin.validation_step(*step_kwargs.values()) def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]: @@ -175,7 +184,7 @@ class Accelerator: See :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step` for more details """ - with self.precision_plugin.test_step_context(): + with self.training_type_plugin.precision_plugin.test_step_context(): return self.training_type_plugin.test_step(*step_kwargs.values()) def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT: @@ -183,7 +192,7 @@ class Accelerator: See :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` for more details """ - with self.precision_plugin.predict_step_context(): + with self.training_type_plugin.precision_plugin.predict_step_context(): return self.training_type_plugin.predict_step(*step_kwargs.values()) def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor: @@ -193,11 +202,11 @@ class Accelerator: closure_loss: a tensor holding the loss value to backpropagate """ self.training_type_plugin.pre_backward(closure_loss) - closure_loss = self.precision_plugin.pre_backward(self.lightning_module, closure_loss) + closure_loss = self.training_type_plugin.precision_plugin.pre_backward(self.lightning_module, closure_loss) - self.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs) + self.training_type_plugin.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs) - closure_loss = self.precision_plugin.post_backward(self.lightning_module, closure_loss) + closure_loss = self.training_type_plugin.precision_plugin.post_backward(self.lightning_module, closure_loss) self.training_type_plugin.post_backward(closure_loss) return closure_loss @@ -208,7 +217,7 @@ class Accelerator: opt_idx: int, closure: Callable[[], Any], model: Optional[Union["pl.LightningModule", Module]] = None, - **kwargs: Any + **kwargs: Any, ) -> None: """performs the actual optimizer step. @@ -220,7 +229,7 @@ class Accelerator: **kwargs: Any extra arguments to ``optimizer.step`` """ model = model or self.lightning_module - self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs) + self.training_type_plugin.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs) def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None: """Zeros all model parameter's gradients.""" @@ -248,26 +257,38 @@ class Accelerator: def setup_precision_plugin(self) -> None: """Attaches the precision plugin to the accelerator.""" - model, optimizers, schedulers = self.precision_plugin.connect(self.model, self.optimizers, self.lr_schedulers) + model, optimizers, schedulers = self.training_type_plugin.precision_plugin.connect( + self.model, self.optimizers, self.lr_schedulers + ) self.model = model self.optimizers = optimizers self.lr_schedulers = schedulers @property def amp_backend(self) -> Optional[LightningEnum]: - if isinstance(self.precision_plugin, ApexMixedPrecisionPlugin): + if isinstance(self.training_type_plugin.precision_plugin, ApexMixedPrecisionPlugin): return AMPType.APEX - if isinstance(self.precision_plugin, NativeMixedPrecisionPlugin): + if isinstance(self.training_type_plugin.precision_plugin, NativeMixedPrecisionPlugin): return AMPType.NATIVE return None @property def precision(self) -> Union[str, int]: - return self.precision_plugin.precision + """The type of precision being used with this accelerator. + + .. deprecated:: + This property been deprecated and will be removed soon. + Use ``training_type_plugin.precision_plugin.precision`` instead. + """ + rank_zero_deprecation( + f"`{self.__class__.__name__}.precision` has been deprecated and will be removed soon" + f" Use `training_type_plugin.precision_plugin.precision` instead." + ) + return self.training_type_plugin.precision_plugin.precision @property def scaler(self) -> Optional["GradScaler"]: - return getattr(self.precision_plugin, "scaler", None) + return getattr(self.training_type_plugin.precision_plugin, "scaler", None) def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]: """Returns state of an optimizer. diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py index 6e824a25f6..673e8419ca 100644 --- a/pytorch_lightning/accelerators/tpu.py +++ b/pytorch_lightning/accelerators/tpu.py @@ -36,10 +36,11 @@ class TPUAccelerator(Accelerator): ValueError: If the precision or training type plugin are unsupported. """ - if not isinstance(self.precision_plugin, TPUPrecisionPlugin): + if not isinstance(self.training_type_plugin.precision_plugin, TPUPrecisionPlugin): # this configuration should have been avoided in the accelerator connector raise ValueError( - f"The `TPUAccelerator` can only be used with a `TPUPrecisionPlugin`, found: {self.precision_plugin}." + f"The `TPUAccelerator` can only be used with a `TPUPrecisionPlugin`," + f" found: {self.training_type_plugin.precision_plugin}." ) if not isinstance(self.training_type_plugin, (SingleTPUPlugin, TPUSpawnPlugin)): raise ValueError( diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 2a2ed9586b..bb07c76315 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -108,7 +108,7 @@ class LightningLite(ABC): ) self._accelerator = self._accelerator_connector.accelerator self._strategy = self._accelerator.training_type_plugin - self._precision_plugin = self._accelerator.precision_plugin + self._precision_plugin = self._strategy.precision_plugin self._models_setup: int = 0 # wrap the run method so we can inject setup logic or spawn processes for the user diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 0285859a67..6d1b168d5a 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -36,6 +36,7 @@ from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.distributed import prepare_for_backward from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import ( @@ -86,6 +87,7 @@ class DDPPlugin(ParallelPlugin): parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[PrecisionPlugin] = None, ddp_comm_state: Optional[object] = None, ddp_comm_hook: Optional[callable] = None, ddp_comm_wrapper: Optional[callable] = None, @@ -96,6 +98,7 @@ class DDPPlugin(ParallelPlugin): parallel_devices=parallel_devices, cluster_environment=cluster_environment, checkpoint_io=checkpoint_io, + precision_plugin=precision_plugin, ) self.interactive_ddp_procs = [] self._num_nodes = 1 diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index a77027adb6..da724944ad 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -29,6 +29,7 @@ from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.distributed import prepare_for_backward from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, rank_zero_warn @@ -65,6 +66,7 @@ class DDPSpawnPlugin(ParallelPlugin): parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[PrecisionPlugin] = None, ddp_comm_state: Optional[object] = None, ddp_comm_hook: Optional[callable] = None, ddp_comm_wrapper: Optional[callable] = None, @@ -74,6 +76,7 @@ class DDPSpawnPlugin(ParallelPlugin): parallel_devices=parallel_devices, cluster_environment=cluster_environment, checkpoint_io=checkpoint_io, + precision_plugin=precision_plugin, ) self._num_nodes = 1 self.sync_batchnorm = False diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index eb087ad199..01959bdcee 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -30,6 +30,7 @@ import pytorch_lightning as pl from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config from pytorch_lightning.trainer.states import TrainerFn @@ -129,6 +130,7 @@ class DeepSpeedPlugin(DDPPlugin): synchronize_checkpoint_boundary: bool = False, load_full_weights: bool = False, partition_module: bool = True, + precision_plugin: Optional[PrecisionPlugin] = None, ) -> None: """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large billion parameter models. `For more information: https://pytorch- @@ -273,6 +275,7 @@ class DeepSpeedPlugin(DDPPlugin): super().__init__( parallel_devices=parallel_devices, cluster_environment=cluster_environment, + precision_plugin=precision_plugin, ) self.config = self._load_config(config) @@ -331,7 +334,7 @@ class DeepSpeedPlugin(DDPPlugin): @property def precision(self) -> Union[str, int]: - return self._precision or self.lightning_module.trainer.precision + return self._precision or self.precision_plugin.precision @property def amp_level(self) -> Optional[str]: @@ -456,8 +459,7 @@ class DeepSpeedPlugin(DDPPlugin): "DeepSpeed currently does not support different `accumulate_grad_batches` at different epochs." ) - precision = self.lightning_module.trainer.accelerator.precision - model = LightningDeepSpeedModule(pl_module=self.model, precision=precision) + model = LightningDeepSpeedModule(pl_module=self.model, precision=self.precision) if self.zero_stage_3 and self.partition_module: # Ensure the entire model has been moved to the appropriate device diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index 83328e8c47..3f1b9a3acf 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -18,6 +18,7 @@ from torch.nn import DataParallel, Module from pytorch_lightning.overrides.data_parallel import LightningParallelModule from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.enums import _StrategyType @@ -35,8 +36,14 @@ class DataParallelPlugin(ParallelPlugin): self, parallel_devices: Optional[List[torch.device]] = None, checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[PrecisionPlugin] = None, ): - super().__init__(parallel_devices=parallel_devices, cluster_environment=None, checkpoint_io=checkpoint_io) + super().__init__( + parallel_devices=parallel_devices, + cluster_environment=None, + checkpoint_io=checkpoint_io, + precision_plugin=precision_plugin, + ) @property def global_rank(self) -> int: diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index c9601a905d..73ea87b058 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -18,6 +18,7 @@ import torch from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE from pytorch_lightning.utilities.enums import _StrategyType @@ -46,6 +47,7 @@ class DDPFullyShardedPlugin(DDPPlugin): parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[PrecisionPlugin] = None, ): """Plugin for Fully Sharded Data Parallel provided by FairScale. @@ -97,6 +99,7 @@ class DDPFullyShardedPlugin(DDPPlugin): parallel_devices=parallel_devices, cluster_environment=cluster_environment, checkpoint_io=checkpoint_io, + precision_plugin=precision_plugin, ) self.cpu_offload = cpu_offload self.move_grads_to_cpu = move_grads_to_cpu @@ -124,7 +127,7 @@ class DDPFullyShardedPlugin(DDPPlugin): @contextlib.contextmanager def model_sharded_context(self) -> Generator: - precision = self.lightning_module.trainer.precision + precision = self.precision_plugin.precision def wrap_policy(*args, **kwargs): return default_auto_wrap_policy(*args, **kwargs, min_num_params=self.min_num_params) diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 51558189a3..961d2764b8 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -21,6 +21,7 @@ from torch.optim.lr_scheduler import _LRScheduler from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities import _HOROVOD_AVAILABLE from pytorch_lightning.utilities.distributed import distributed_available @@ -41,8 +42,14 @@ class HorovodPlugin(ParallelPlugin): self, parallel_devices: Optional[List[torch.device]] = None, checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[PrecisionPlugin] = None, ): - super().__init__(parallel_devices=parallel_devices, cluster_environment=None, checkpoint_io=checkpoint_io) + super().__init__( + parallel_devices=parallel_devices, + cluster_environment=None, + checkpoint_io=checkpoint_io, + precision_plugin=precision_plugin, + ) rank_zero_only.rank = self.global_rank @property diff --git a/pytorch_lightning/plugins/training_type/ipu.py b/pytorch_lightning/plugins/training_type/ipu.py index 898e62791d..c24008ac3e 100644 --- a/pytorch_lightning/plugins/training_type/ipu.py +++ b/pytorch_lightning/plugins/training_type/ipu.py @@ -22,6 +22,7 @@ import pytorch_lightning as pl from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.utilities import _IPU_AVAILABLE, _POPTORCH_AVAILABLE @@ -64,6 +65,7 @@ class IPUPlugin(ParallelPlugin): parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[PrecisionPlugin] = None, training_opts: Optional["poptorch.Options"] = None, inference_opts: Optional["poptorch.Options"] = None, ) -> None: @@ -84,6 +86,7 @@ class IPUPlugin(ParallelPlugin): parallel_devices=parallel_devices, cluster_environment=cluster_environment, checkpoint_io=checkpoint_io, + precision_plugin=precision_plugin, ) if not _IPU_AVAILABLE: raise MisconfigurationException( @@ -116,8 +119,7 @@ class IPUPlugin(ParallelPlugin): self.lightning_module.trainer._update_dataloader = self._convert_to_poptorch_loader def pre_dispatch(self) -> None: - precision = self.lightning_module.trainer.precision - model = LightningIPUModule(self.lightning_module, precision) + model = LightningIPUModule(self.lightning_module, self.precision_plugin.precision) self.model = model # reset the backup diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py index 4f4b2c5b8e..07ede1ae4f 100644 --- a/pytorch_lightning/plugins/training_type/parallel.py +++ b/pytorch_lightning/plugins/training_type/parallel.py @@ -23,6 +23,7 @@ import pytorch_lightning as pl from pytorch_lightning.overrides.base import unwrap_lightning_module from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin from pytorch_lightning.utilities import _XLA_AVAILABLE from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, ReduceOp @@ -36,8 +37,9 @@ class ParallelPlugin(TrainingTypePlugin, ABC): parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[PrecisionPlugin] = None, ): - super().__init__(checkpoint_io) + super().__init__(checkpoint_io=checkpoint_io, precision_plugin=precision_plugin) self.parallel_devices = parallel_devices self.cluster_environment = cluster_environment diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py index d7563437bd..eb4cb48534 100644 --- a/pytorch_lightning/plugins/training_type/sharded.py +++ b/pytorch_lightning/plugins/training_type/sharded.py @@ -75,7 +75,7 @@ class DDPShardedPlugin(DDPPlugin): optim_class = type(optimizer) zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults) if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE: - precision = self._precision or self.lightning_module.trainer.precision + precision = self._precision or self.precision_plugin.precision is_fp16 = precision in ("mixed", 16) # For multi-node training, compressing the model shards in fp16 before broadcasting # improves performance. When using PyTorch AMP, it will not degrade diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index 12e627edbe..12c06b9dde 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -118,9 +118,8 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin): def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None: # Ensure that the scaler points to the correct process group # which is re-initialized in a new process - precision_plugin = trainer.accelerator.precision_plugin - if isinstance(precision_plugin, ShardedNativeMixedPrecisionPlugin): - precision_plugin.scaler = ShardedGradScaler() + if isinstance(self.precision_plugin, ShardedNativeMixedPrecisionPlugin): + self.precision_plugin.scaler = ShardedGradScaler() return super().new_process(trainer, mp_queue) @classmethod diff --git a/pytorch_lightning/plugins/training_type/single_device.py b/pytorch_lightning/plugins/training_type/single_device.py index 1737bf3b41..12a0f625b6 100644 --- a/pytorch_lightning/plugins/training_type/single_device.py +++ b/pytorch_lightning/plugins/training_type/single_device.py @@ -16,6 +16,7 @@ from typing import Any, Optional, Union import torch from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin from pytorch_lightning.utilities import _XLA_AVAILABLE @@ -27,8 +28,9 @@ class SingleDevicePlugin(TrainingTypePlugin): self, device: torch.device, checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[PrecisionPlugin] = None, ): - super().__init__(checkpoint_io) + super().__init__(checkpoint_io=checkpoint_io, precision_plugin=precision_plugin) self.device: torch.device = device self.global_rank = 0 self.local_rank = 0 diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py index 9fed200039..e6f6a5f4b2 100644 --- a/pytorch_lightning/plugins/training_type/single_tpu.py +++ b/pytorch_lightning/plugins/training_type/single_tpu.py @@ -16,6 +16,7 @@ from typing import Any, Dict, Optional from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -33,12 +34,13 @@ class SingleTPUPlugin(SingleDevicePlugin): self, device: int, checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[PrecisionPlugin] = None, debug: bool = False, ): device = xm.xla_device(device) checkpoint_io = checkpoint_io or XLACheckpointIO() - super().__init__(device=device, checkpoint_io=checkpoint_io) + super().__init__(device=device, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin) self.debug = debug self.tpu_local_core_rank = 0 diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 7aa4a67721..3ab9a8171a 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -27,6 +27,7 @@ import pytorch_lightning as pl from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.trainer.states import TrainerFn @@ -56,11 +57,14 @@ class TPUSpawnPlugin(DDPSpawnPlugin): self, parallel_devices: Optional[List[int]] = None, checkpoint_io: Optional[CheckpointIO] = None, + precision_plugin: Optional[PrecisionPlugin] = None, debug: bool = False, **_: Any ) -> None: checkpoint_io = checkpoint_io or XLACheckpointIO() - super().__init__(parallel_devices=parallel_devices, checkpoint_io=checkpoint_io) + super().__init__( + parallel_devices=parallel_devices, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin + ) self.debug = debug self.tpu_local_core_rank = 0 self.tpu_global_core_rank = 0 @@ -167,7 +171,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin): set_shared_parameters(self.model.module, shared_params) trainer.accelerator.setup_optimizers(trainer) - trainer.precision_plugin.connect(self._model, None, None) + self.precision_plugin.connect(self._model, None, None) self.barrier("pre-run-stage") diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index c23edf5941..7010c0e878 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -25,6 +25,7 @@ import pytorch_lightning as pl from pytorch_lightning.overrides.base import unwrap_lightning_module from pytorch_lightning.plugins import TorchCheckpointIO from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.utilities.distributed import ReduceOp from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PATH, _PREDICT_OUTPUT @@ -33,16 +34,23 @@ class TrainingTypePlugin(ABC): """Base class for all training type plugins that change the behaviour of the training, validation and test- loop.""" - def __init__(self, checkpoint_io: Optional[CheckpointIO] = None) -> None: + def __init__( + self, checkpoint_io: Optional[CheckpointIO] = None, precision_plugin: Optional[PrecisionPlugin] = None + ) -> None: self._model: Optional[Module] = None self._results: Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]] = None checkpoint_io = checkpoint_io if checkpoint_io is not None else TorchCheckpointIO() self._checkpoint_io = checkpoint_io + self._precision_plugin = precision_plugin if precision_plugin is not None else PrecisionPlugin() @property def checkpoint_io(self) -> CheckpointIO: return self._checkpoint_io + @property + def precision_plugin(self) -> PrecisionPlugin: + return self._precision_plugin + @checkpoint_io.setter def checkpoint_io(self, plugin: CheckpointIO) -> None: self._checkpoint_io = plugin diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 5532385ca1..e5df9c3b84 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -405,6 +405,9 @@ class AcceleratorConnector: # attach checkpoint plugin to the training type plugin if self._checkpoint_io is not None: self._training_type_plugin.checkpoint_io = self._checkpoint_io + precision_plugin = self.precision_plugin + if precision_plugin is not None: + self._training_type_plugin._precision_plugin = precision_plugin self._training_type_plugin_resolved = True return self._training_type_plugin @@ -531,11 +534,11 @@ class AcceleratorConnector: @property def _is_sharded_training_type(self) -> bool: - return isinstance(self.training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin)) + return isinstance(self._training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin)) @property def _is_fully_sharded_training_type(self) -> bool: - return isinstance(self.training_type_plugin, DDPFullyShardedPlugin) + return isinstance(self._training_type_plugin, DDPFullyShardedPlugin) @property def is_distributed(self) -> bool: @@ -793,12 +796,10 @@ class AcceleratorConnector: acc_cls = IPUAccelerator else: acc_cls = CPUAccelerator - # as precision_plugin is dependent on training_type_plugin, make sure - # that we first select training_type_plugin, then precision_plugin - accelerator = acc_cls(training_type_plugin=self.training_type_plugin, precision_plugin=self.precision_plugin) + + accelerator = acc_cls(precision_plugin=None, training_type_plugin=self.training_type_plugin) # transfer ownership of the plugins to the accelerator self._training_type_plugin = proxy(self.training_type_plugin) - self._precision_plugin = proxy(self.precision_plugin) return accelerator diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 26fbcb4362..2f6e987635 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1568,7 +1568,7 @@ class Trainer( @property def precision_plugin(self) -> PrecisionPlugin: - return self.accelerator.precision_plugin + return self.training_type_plugin.precision_plugin @property def global_rank(self) -> int: @@ -1672,7 +1672,7 @@ class Trainer( @property def precision(self) -> Union[str, int]: - return self.accelerator.precision + return self.training_type_plugin.precision_plugin.precision @property def scaler(self): diff --git a/tests/accelerators/test_ipu.py b/tests/accelerators/test_ipu.py index dfaa1c8042..be2e597c9a 100644 --- a/tests/accelerators/test_ipu.py +++ b/tests/accelerators/test_ipu.py @@ -193,8 +193,8 @@ def test_mixed_precision(tmpdir): model = IPUModel() trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, ipus=1, precision=16, callbacks=TestCallback()) - assert isinstance(trainer.accelerator.precision_plugin, IPUPrecisionPlugin) - assert trainer.accelerator.precision_plugin.precision == 16 + assert isinstance(trainer.training_type_plugin.precision_plugin, IPUPrecisionPlugin) + assert trainer.training_type_plugin.precision_plugin.precision == 16 with pytest.raises(SystemExit): trainer.fit(model) @@ -213,8 +213,8 @@ def test_pure_half_precision(tmpdir): trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, ipus=1, precision=16, callbacks=TestCallback()) assert isinstance(trainer.accelerator.training_type_plugin, IPUPlugin) - assert isinstance(trainer.accelerator.precision_plugin, IPUPrecisionPlugin) - assert trainer.accelerator.precision_plugin.precision == 16 + assert isinstance(trainer.training_type_plugin.precision_plugin, IPUPrecisionPlugin) + assert trainer.training_type_plugin.precision_plugin.precision == 16 with pytest.raises(SystemExit): trainer.fit(model) diff --git a/tests/accelerators/test_tpu.py b/tests/accelerators/test_tpu.py index 78e4c505bb..fc1ce413cd 100644 --- a/tests/accelerators/test_tpu.py +++ b/tests/accelerators/test_tpu.py @@ -23,7 +23,7 @@ from torch.utils.data import DataLoader from pytorch_lightning import Trainer from pytorch_lightning.accelerators.cpu import CPUAccelerator from pytorch_lightning.accelerators.tpu import TPUAccelerator -from pytorch_lightning.plugins import TPUPrecisionPlugin, TPUSpawnPlugin, XLACheckpointIO +from pytorch_lightning.plugins import DDPPlugin, TPUPrecisionPlugin, TPUSpawnPlugin, XLACheckpointIO from pytorch_lightning.utilities import find_shared_parameters from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel, RandomDataset @@ -292,11 +292,23 @@ def test_tpu_invalid_raises(): with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `TPUPrecisionPlugin"): accelerator.setup(object()) - accelerator = TPUAccelerator(TPUPrecisionPlugin(), object()) + accelerator = TPUAccelerator(TPUPrecisionPlugin(), DDPPlugin()) with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugi"): accelerator.setup(object()) +def test_tpu_invalid_raises_set_precision_with_strategy(): + accelerator = TPUAccelerator(object(), TPUSpawnPlugin(precision_plugin=object())) + with pytest.raises(ValueError, match="`TPUAccelerator` can only be used with a `TPUPrecisionPlugin`"): + accelerator.setup(object()) + + accelerator = TPUAccelerator(None, DDPPlugin(precision_plugin=TPUPrecisionPlugin())) + with pytest.raises( + ValueError, match="TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugin" + ): + accelerator.setup(object()) + + @RunIf(tpu=True) def test_xla_checkpoint_plugin_being_default(): trainer = Trainer(tpu_cores=8) diff --git a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py index 1468c7f4a4..c0fab29717 100644 --- a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py +++ b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py @@ -34,8 +34,8 @@ def test_invalid_on_cpu(tmpdir): def test_fsdp_with_sharded_amp(device_count_mock, mock_cuda_available, tmpdir): """Test to ensure that plugin native amp plugin is correctly chosen when using sharded.""" trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy="fsdp", gpus=1, precision=16) - assert isinstance(trainer.accelerator.training_type_plugin, DDPFullyShardedPlugin) - assert isinstance(trainer.accelerator.precision_plugin, FullyShardedNativeMixedPrecisionPlugin) + assert isinstance(trainer.training_type_plugin, DDPFullyShardedPlugin) + assert isinstance(trainer.training_type_plugin.precision_plugin, FullyShardedNativeMixedPrecisionPlugin) class TestFSDPModel(BoringModel): diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 2d39a3de6b..480b050c39 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -170,8 +170,8 @@ def test_deepspeed_precision_choice(amp_backend, precision, tmpdir): ) assert isinstance(trainer.accelerator.training_type_plugin, DeepSpeedPlugin) - assert isinstance(trainer.accelerator.precision_plugin, DeepSpeedPrecisionPlugin) - assert trainer.accelerator.precision_plugin.precision == precision + assert isinstance(trainer.training_type_plugin.precision_plugin, DeepSpeedPrecisionPlugin) + assert trainer.training_type_plugin.precision_plugin.precision == precision @RunIf(deepspeed=True)