1/n Move precision plugin into strategy - update reference (#10570)
* 1/n move precision plugin into strategy - update reference
* update precision plugin reference in tpu_spawn
* add missing reference in error message
* add back removed license line
* update references in tests
* update reference in trainer
* update return annotation for precision_plugin property on TTP
* simplify access to precision plugin reference in sharded plug
* add changelog
* remove precision property from ttp and add deprecation message
* fix make doc and update precision reference
* simplify a reference to precision: accidentally overridden Adrian's change, now add it back
* Update CHANGELOG.md: add Adrian's change back
* Update accelerator precision: add Adrian's change back
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Add none check for precision plugin just to be safe
* Update ipu.py
* update precision_plugin param deprecation message
* Update accelerator.py
* Remove deprecated warning: tests will fail after 9940

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
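In practice, the change means code that used to reach the precision plugin through the accelerator now reaches it through the training type plugin (the strategy). A minimal sketch of the before/after access pattern, assuming the PyTorch Lightning development branch at the time of this commit and a default CPU `Trainer`:

```python
from pytorch_lightning import Trainer

trainer = Trainer()

# Old path (deprecated by this change): the accelerator owned the plugin.
# precision_plugin = trainer.accelerator.precision_plugin

# New path: the training type plugin (strategy) owns the precision plugin.
precision_plugin = trainer.training_type_plugin.precision_plugin
print(precision_plugin.precision)  # 32 for the default PrecisionPlugin
```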
This commit is contained in:
parent 2c7c4aab80
commit 700521c7d3
@@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Raise `MisconfigurationException` when `enable_progress_bar=False` and a progress bar instance has been passed in the callback list ([#10520](https://github.com/PyTorchLightning/pytorch-lightning/issues/10520))
+- Moved ownership of the `PrecisionPlugin` into `TrainingTypePlugin` and updated all references ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570))
 -

@@ -50,7 +53,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `DistributedType` in favor of `_StrategyType` ([#10505](https://github.com/PyTorchLightning/pytorch-lightning/pull/10505))
--
+- Deprecated the `precision_plugin` constructor argument from `Accelerator` ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570))
 -

@@ -139,6 +142,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed deprecated `reload_dataloaders_every_epoch` from `Trainer` in favour of `reload_dataloaders_every_n_epochs` ([#10481](https://github.com/PyTorchLightning/pytorch-lightning/pull/10481))
+- Removed the `precision_plugin` attribute from `Accelerator` in favor of its equivalent attribute `precision_plugin` in the `TrainingTypePlugin` ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570))

 ### Fixed
@@ -25,6 +25,7 @@ import pytorch_lightning as pl
from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
from pytorch_lightning.plugins.training_type import DataParallelPlugin, TrainingTypePlugin
from pytorch_lightning.trainer.states import TrainerFn
+from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.enums import AMPType, LightningEnum
from pytorch_lightning.utilities.types import STEP_OUTPUT

@@ -44,15 +45,23 @@ class Accelerator:
    One to handle differences from the training routine and one to handle different precisions.
    """

-    def __init__(self, precision_plugin: PrecisionPlugin, training_type_plugin: TrainingTypePlugin) -> None:
+    def __init__(self, precision_plugin: Optional[PrecisionPlugin], training_type_plugin: TrainingTypePlugin) -> None:
        """
        Args:
            precision_plugin: the plugin to handle precision-specific parts
+
+                .. deprecated::
+                    The ``precision_plugin`` parameter has been deprecated and will be removed soon.
+                    Pass the precision plugin as a parameter to the ``TrainingTypePlugin`` instead.
+
            training_type_plugin: the plugin to handle different training routines
        """
-        self.precision_plugin = precision_plugin
        self.training_type_plugin = training_type_plugin

+        if precision_plugin is not None:
+            self.training_type_plugin._precision_plugin = precision_plugin
+
        self.optimizers: List = []
        self.lr_schedulers: List = []
        self.optimizer_frequencies: List = []

@@ -84,7 +93,7 @@ class Accelerator:
        if self.training_type_plugin.setup_optimizers_in_pre_dispatch:
            self.setup_optimizers(trainer)

-        self.precision_plugin.pre_dispatch()
+        self.training_type_plugin.precision_plugin.pre_dispatch()

    def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
        """Moves the state of the optimizers to the GPU if needed."""

@@ -96,12 +105,12 @@ class Accelerator:
    def dispatch(self, trainer: "pl.Trainer") -> None:
        """Hook to do something before the training/evaluation/prediction starts."""
        self.training_type_plugin.dispatch(trainer)
-        self.precision_plugin.dispatch(trainer)
+        self.training_type_plugin.precision_plugin.dispatch(trainer)

    def post_dispatch(self, trainer: "pl.Trainer") -> None:
        """Hook to do something after the training/evaluation/prediction starts."""
        self.training_type_plugin.post_dispatch(trainer)
-        self.precision_plugin.post_dispatch()
+        self.training_type_plugin.precision_plugin.post_dispatch()

    @property
    def model(self) -> Module:

@@ -159,7 +168,7 @@ class Accelerator:
        See :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` for more details
        """
-        with self.precision_plugin.train_step_context():
+        with self.training_type_plugin.precision_plugin.train_step_context():
            return self.training_type_plugin.training_step(*step_kwargs.values())

    def validation_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]:

@@ -167,7 +176,7 @@ class Accelerator:
        See :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` for more details
        """
-        with self.precision_plugin.val_step_context():
+        with self.training_type_plugin.precision_plugin.val_step_context():
            return self.training_type_plugin.validation_step(*step_kwargs.values())

    def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]:

@@ -175,7 +184,7 @@ class Accelerator:
        See :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step` for more details
        """
-        with self.precision_plugin.test_step_context():
+        with self.training_type_plugin.precision_plugin.test_step_context():
            return self.training_type_plugin.test_step(*step_kwargs.values())

    def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT:

@@ -183,7 +192,7 @@ class Accelerator:
        See :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` for more details
        """
-        with self.precision_plugin.predict_step_context():
+        with self.training_type_plugin.precision_plugin.predict_step_context():
            return self.training_type_plugin.predict_step(*step_kwargs.values())

    def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor:

@@ -193,11 +202,11 @@ class Accelerator:
            closure_loss: a tensor holding the loss value to backpropagate
        """
        self.training_type_plugin.pre_backward(closure_loss)
-        closure_loss = self.precision_plugin.pre_backward(self.lightning_module, closure_loss)
+        closure_loss = self.training_type_plugin.precision_plugin.pre_backward(self.lightning_module, closure_loss)

-        self.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs)
+        self.training_type_plugin.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs)

-        closure_loss = self.precision_plugin.post_backward(self.lightning_module, closure_loss)
+        closure_loss = self.training_type_plugin.precision_plugin.post_backward(self.lightning_module, closure_loss)
        self.training_type_plugin.post_backward(closure_loss)

        return closure_loss

@@ -208,7 +217,7 @@ class Accelerator:
        opt_idx: int,
        closure: Callable[[], Any],
        model: Optional[Union["pl.LightningModule", Module]] = None,
-        **kwargs: Any
+        **kwargs: Any,
    ) -> None:
        """performs the actual optimizer step.

@@ -220,7 +229,7 @@ class Accelerator:
            **kwargs: Any extra arguments to ``optimizer.step``
        """
        model = model or self.lightning_module
-        self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
+        self.training_type_plugin.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)

    def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
        """Zeros all model parameter's gradients."""

@@ -248,26 +257,38 @@ class Accelerator:
    def setup_precision_plugin(self) -> None:
        """Attaches the precision plugin to the accelerator."""
-        model, optimizers, schedulers = self.precision_plugin.connect(self.model, self.optimizers, self.lr_schedulers)
+        model, optimizers, schedulers = self.training_type_plugin.precision_plugin.connect(
+            self.model, self.optimizers, self.lr_schedulers
+        )
        self.model = model
        self.optimizers = optimizers
        self.lr_schedulers = schedulers

    @property
    def amp_backend(self) -> Optional[LightningEnum]:
-        if isinstance(self.precision_plugin, ApexMixedPrecisionPlugin):
+        if isinstance(self.training_type_plugin.precision_plugin, ApexMixedPrecisionPlugin):
            return AMPType.APEX
-        if isinstance(self.precision_plugin, NativeMixedPrecisionPlugin):
+        if isinstance(self.training_type_plugin.precision_plugin, NativeMixedPrecisionPlugin):
            return AMPType.NATIVE
        return None

    @property
    def precision(self) -> Union[str, int]:
-        return self.precision_plugin.precision
+        """The type of precision being used with this accelerator.
+
+        .. deprecated::
+            This property been deprecated and will be removed soon.
+            Use ``training_type_plugin.precision_plugin.precision`` instead.
+        """
+        rank_zero_deprecation(
+            f"`{self.__class__.__name__}.precision` has been deprecated and will be removed soon"
+            f" Use `training_type_plugin.precision_plugin.precision` instead."
+        )
+        return self.training_type_plugin.precision_plugin.precision

    @property
    def scaler(self) -> Optional["GradScaler"]:
-        return getattr(self.precision_plugin, "scaler", None)
+        return getattr(self.training_type_plugin.precision_plugin, "scaler", None)

    def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
        """Returns state of an optimizer.
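With the constructor change above, the `Accelerator` no longer has to be handed a precision plugin directly; the plugin travels with the training type plugin. A minimal sketch of the new wiring (manual construction like this is normally done by the accelerator connector rather than user code, and the `SingleDevicePlugin`/`CPUAccelerator` pairing here is only for illustration):

```python
import torch

from pytorch_lightning.accelerators import CPUAccelerator
from pytorch_lightning.plugins import PrecisionPlugin, SingleDevicePlugin

# The precision plugin is now given to the training type plugin...
strategy = SingleDevicePlugin(torch.device("cpu"), precision_plugin=PrecisionPlugin())

# ...while the accelerator's own `precision_plugin` argument is deprecated and may be None.
accelerator = CPUAccelerator(precision_plugin=None, training_type_plugin=strategy)

assert accelerator.training_type_plugin.precision_plugin is strategy.precision_plugin
```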
@@ -36,10 +36,11 @@ class TPUAccelerator(Accelerator):
        ValueError:
            If the precision or training type plugin are unsupported.
        """
-        if not isinstance(self.precision_plugin, TPUPrecisionPlugin):
+        if not isinstance(self.training_type_plugin.precision_plugin, TPUPrecisionPlugin):
            # this configuration should have been avoided in the accelerator connector
            raise ValueError(
-                f"The `TPUAccelerator` can only be used with a `TPUPrecisionPlugin`, found: {self.precision_plugin}."
+                f"The `TPUAccelerator` can only be used with a `TPUPrecisionPlugin`,"
+                f" found: {self.training_type_plugin.precision_plugin}."
            )
        if not isinstance(self.training_type_plugin, (SingleTPUPlugin, TPUSpawnPlugin)):
            raise ValueError(
@@ -108,7 +108,7 @@ class LightningLite(ABC):
        )
        self._accelerator = self._accelerator_connector.accelerator
        self._strategy = self._accelerator.training_type_plugin
-        self._precision_plugin = self._accelerator.precision_plugin
+        self._precision_plugin = self._strategy.precision_plugin
        self._models_setup: int = 0

        # wrap the run method so we can inject setup logic or spawn processes for the user
@@ -36,6 +36,7 @@ from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
+from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import (

@@ -86,6 +87,7 @@ class DDPPlugin(ParallelPlugin):
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
+        precision_plugin: Optional[PrecisionPlugin] = None,
        ddp_comm_state: Optional[object] = None,
        ddp_comm_hook: Optional[callable] = None,
        ddp_comm_wrapper: Optional[callable] = None,

@@ -96,6 +98,7 @@ class DDPPlugin(ParallelPlugin):
            parallel_devices=parallel_devices,
            cluster_environment=cluster_environment,
            checkpoint_io=checkpoint_io,
+            precision_plugin=precision_plugin,
        )
        self.interactive_ddp_procs = []
        self._num_nodes = 1
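The same mechanical change — accept an optional `precision_plugin` argument and forward it to the parent constructor — is repeated for the remaining training type plugins in the hunks below. A condensed sketch of that pattern with a hypothetical custom plugin (the class name is illustrative, not part of the codebase):

```python
from typing import Optional

import torch

from pytorch_lightning.plugins import PrecisionPlugin, SingleDevicePlugin


class MyCPUPlugin(SingleDevicePlugin):
    """Hypothetical subclass showing how `precision_plugin` is threaded through."""

    def __init__(self, precision_plugin: Optional[PrecisionPlugin] = None):
        # Forward the plugin to the parent; the base TrainingTypePlugin stores it
        # (see the training_type_plugin base-class hunk further below).
        super().__init__(device=torch.device("cpu"), precision_plugin=precision_plugin)


plugin = MyCPUPlugin(precision_plugin=PrecisionPlugin())
assert isinstance(plugin.precision_plugin, PrecisionPlugin)
```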
@@ -29,6 +29,7 @@ from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
+from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, rank_zero_warn

@@ -65,6 +66,7 @@ class DDPSpawnPlugin(ParallelPlugin):
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
+        precision_plugin: Optional[PrecisionPlugin] = None,
        ddp_comm_state: Optional[object] = None,
        ddp_comm_hook: Optional[callable] = None,
        ddp_comm_wrapper: Optional[callable] = None,

@@ -74,6 +76,7 @@ class DDPSpawnPlugin(ParallelPlugin):
            parallel_devices=parallel_devices,
            cluster_environment=cluster_environment,
            checkpoint_io=checkpoint_io,
+            precision_plugin=precision_plugin,
        )
        self._num_nodes = 1
        self.sync_batchnorm = False
@@ -30,6 +30,7 @@ import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
+from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
from pytorch_lightning.trainer.states import TrainerFn

@@ -129,6 +130,7 @@ class DeepSpeedPlugin(DDPPlugin):
        synchronize_checkpoint_boundary: bool = False,
        load_full_weights: bool = False,
        partition_module: bool = True,
+        precision_plugin: Optional[PrecisionPlugin] = None,
    ) -> None:
        """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
        billion parameter models. `For more information: https://pytorch-

@@ -273,6 +275,7 @@ class DeepSpeedPlugin(DDPPlugin):
        super().__init__(
            parallel_devices=parallel_devices,
            cluster_environment=cluster_environment,
+            precision_plugin=precision_plugin,
        )

        self.config = self._load_config(config)

@@ -331,7 +334,7 @@ class DeepSpeedPlugin(DDPPlugin):
    @property
    def precision(self) -> Union[str, int]:
-        return self._precision or self.lightning_module.trainer.precision
+        return self._precision or self.precision_plugin.precision

    @property
    def amp_level(self) -> Optional[str]:

@@ -456,8 +459,7 @@ class DeepSpeedPlugin(DDPPlugin):
                "DeepSpeed currently does not support different `accumulate_grad_batches` at different epochs."
            )

-        precision = self.lightning_module.trainer.accelerator.precision
-        model = LightningDeepSpeedModule(pl_module=self.model, precision=precision)
+        model = LightningDeepSpeedModule(pl_module=self.model, precision=self.precision)

        if self.zero_stage_3 and self.partition_module:
            # Ensure the entire model has been moved to the appropriate device
@@ -18,6 +18,7 @@ from torch.nn import DataParallel, Module

from pytorch_lightning.overrides.data_parallel import LightningParallelModule
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
+from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.enums import _StrategyType

@@ -35,8 +36,14 @@ class DataParallelPlugin(ParallelPlugin):
        self,
        parallel_devices: Optional[List[torch.device]] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
+        precision_plugin: Optional[PrecisionPlugin] = None,
    ):
-        super().__init__(parallel_devices=parallel_devices, cluster_environment=None, checkpoint_io=checkpoint_io)
+        super().__init__(
+            parallel_devices=parallel_devices,
+            cluster_environment=None,
+            checkpoint_io=checkpoint_io,
+            precision_plugin=precision_plugin,
+        )

    @property
    def global_rank(self) -> int:
@@ -18,6 +18,7 @@ import torch

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
+from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE
from pytorch_lightning.utilities.enums import _StrategyType

@@ -46,6 +47,7 @@ class DDPFullyShardedPlugin(DDPPlugin):
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
+        precision_plugin: Optional[PrecisionPlugin] = None,
    ):
        """Plugin for Fully Sharded Data Parallel provided by FairScale.

@@ -97,6 +99,7 @@ class DDPFullyShardedPlugin(DDPPlugin):
            parallel_devices=parallel_devices,
            cluster_environment=cluster_environment,
            checkpoint_io=checkpoint_io,
+            precision_plugin=precision_plugin,
        )
        self.cpu_offload = cpu_offload
        self.move_grads_to_cpu = move_grads_to_cpu

@@ -124,7 +127,7 @@ class DDPFullyShardedPlugin(DDPPlugin):
    @contextlib.contextmanager
    def model_sharded_context(self) -> Generator:
-        precision = self.lightning_module.trainer.precision
+        precision = self.precision_plugin.precision

        def wrap_policy(*args, **kwargs):
            return default_auto_wrap_policy(*args, **kwargs, min_num_params=self.min_num_params)
@@ -21,6 +21,7 @@ from torch.optim.lr_scheduler import _LRScheduler

from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
+from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
from pytorch_lightning.utilities.distributed import distributed_available

@@ -41,8 +42,14 @@ class HorovodPlugin(ParallelPlugin):
        self,
        parallel_devices: Optional[List[torch.device]] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
+        precision_plugin: Optional[PrecisionPlugin] = None,
    ):
-        super().__init__(parallel_devices=parallel_devices, cluster_environment=None, checkpoint_io=checkpoint_io)
+        super().__init__(
+            parallel_devices=parallel_devices,
+            cluster_environment=None,
+            checkpoint_io=checkpoint_io,
+            precision_plugin=precision_plugin,
+        )
        rank_zero_only.rank = self.global_rank

    @property
@@ -22,6 +22,7 @@ import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
+from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.utilities import _IPU_AVAILABLE, _POPTORCH_AVAILABLE

@@ -64,6 +65,7 @@ class IPUPlugin(ParallelPlugin):
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
+        precision_plugin: Optional[PrecisionPlugin] = None,
        training_opts: Optional["poptorch.Options"] = None,
        inference_opts: Optional["poptorch.Options"] = None,
    ) -> None:

@@ -84,6 +86,7 @@ class IPUPlugin(ParallelPlugin):
            parallel_devices=parallel_devices,
            cluster_environment=cluster_environment,
            checkpoint_io=checkpoint_io,
+            precision_plugin=precision_plugin,
        )
        if not _IPU_AVAILABLE:
            raise MisconfigurationException(

@@ -116,8 +119,7 @@ class IPUPlugin(ParallelPlugin):
        self.lightning_module.trainer._update_dataloader = self._convert_to_poptorch_loader

    def pre_dispatch(self) -> None:
-        precision = self.lightning_module.trainer.precision
-        model = LightningIPUModule(self.lightning_module, precision)
+        model = LightningIPUModule(self.lightning_module, self.precision_plugin.precision)
        self.model = model

        # reset the backup
@@ -23,6 +23,7 @@ import pytorch_lightning as pl
from pytorch_lightning.overrides.base import unwrap_lightning_module
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
+from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
from pytorch_lightning.utilities import _XLA_AVAILABLE
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, ReduceOp

@@ -36,8 +37,9 @@ class ParallelPlugin(TrainingTypePlugin, ABC):
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
+        precision_plugin: Optional[PrecisionPlugin] = None,
    ):
-        super().__init__(checkpoint_io)
+        super().__init__(checkpoint_io=checkpoint_io, precision_plugin=precision_plugin)
        self.parallel_devices = parallel_devices
        self.cluster_environment = cluster_environment
@@ -75,7 +75,7 @@ class DDPShardedPlugin(DDPPlugin):
            optim_class = type(optimizer)
            zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
            if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE:
-                precision = self._precision or self.lightning_module.trainer.precision
+                precision = self._precision or self.precision_plugin.precision
                is_fp16 = precision in ("mixed", 16)
                # For multi-node training, compressing the model shards in fp16 before broadcasting
                # improves performance. When using PyTorch AMP, it will not degrade
@@ -118,9 +118,8 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
    def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
        # Ensure that the scaler points to the correct process group
        # which is re-initialized in a new process
-        precision_plugin = trainer.accelerator.precision_plugin
-        if isinstance(precision_plugin, ShardedNativeMixedPrecisionPlugin):
-            precision_plugin.scaler = ShardedGradScaler()
+        if isinstance(self.precision_plugin, ShardedNativeMixedPrecisionPlugin):
+            self.precision_plugin.scaler = ShardedGradScaler()
        return super().new_process(trainer, mp_queue)

    @classmethod
@@ -16,6 +16,7 @@ from typing import Any, Optional, Union
import torch

from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
+from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
from pytorch_lightning.utilities import _XLA_AVAILABLE

@@ -27,8 +28,9 @@ class SingleDevicePlugin(TrainingTypePlugin):
        self,
        device: torch.device,
        checkpoint_io: Optional[CheckpointIO] = None,
+        precision_plugin: Optional[PrecisionPlugin] = None,
    ):
-        super().__init__(checkpoint_io)
+        super().__init__(checkpoint_io=checkpoint_io, precision_plugin=precision_plugin)
        self.device: torch.device = device
        self.global_rank = 0
        self.local_rank = 0
@@ -16,6 +16,7 @@ from typing import Any, Dict, Optional

from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
+from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters
from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -33,12 +34,13 @@ class SingleTPUPlugin(SingleDevicePlugin):
        self,
        device: int,
        checkpoint_io: Optional[CheckpointIO] = None,
+        precision_plugin: Optional[PrecisionPlugin] = None,
        debug: bool = False,
    ):

        device = xm.xla_device(device)
        checkpoint_io = checkpoint_io or XLACheckpointIO()
-        super().__init__(device=device, checkpoint_io=checkpoint_io)
+        super().__init__(device=device, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin)

        self.debug = debug
        self.tpu_local_core_rank = 0
@@ -27,6 +27,7 @@ import pytorch_lightning as pl
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.io.xla_plugin import XLACheckpointIO
+from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
from pytorch_lightning.trainer.states import TrainerFn

@@ -56,11 +57,14 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
        self,
        parallel_devices: Optional[List[int]] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
+        precision_plugin: Optional[PrecisionPlugin] = None,
        debug: bool = False,
        **_: Any
    ) -> None:
        checkpoint_io = checkpoint_io or XLACheckpointIO()
-        super().__init__(parallel_devices=parallel_devices, checkpoint_io=checkpoint_io)
+        super().__init__(
+            parallel_devices=parallel_devices, checkpoint_io=checkpoint_io, precision_plugin=precision_plugin
+        )
        self.debug = debug
        self.tpu_local_core_rank = 0
        self.tpu_global_core_rank = 0

@@ -167,7 +171,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
            set_shared_parameters(self.model.module, shared_params)

        trainer.accelerator.setup_optimizers(trainer)
-        trainer.precision_plugin.connect(self._model, None, None)
+        self.precision_plugin.connect(self._model, None, None)

        self.barrier("pre-run-stage")
@@ -25,6 +25,7 @@ import pytorch_lightning as pl
from pytorch_lightning.overrides.base import unwrap_lightning_module
from pytorch_lightning.plugins import TorchCheckpointIO
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
+from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PATH, _PREDICT_OUTPUT

@@ -33,16 +34,23 @@ class TrainingTypePlugin(ABC):
    """Base class for all training type plugins that change the behaviour of the training, validation and test-
    loop."""

-    def __init__(self, checkpoint_io: Optional[CheckpointIO] = None) -> None:
+    def __init__(
+        self, checkpoint_io: Optional[CheckpointIO] = None, precision_plugin: Optional[PrecisionPlugin] = None
+    ) -> None:
        self._model: Optional[Module] = None
        self._results: Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]] = None
        checkpoint_io = checkpoint_io if checkpoint_io is not None else TorchCheckpointIO()
        self._checkpoint_io = checkpoint_io
+        self._precision_plugin = precision_plugin if precision_plugin is not None else PrecisionPlugin()

    @property
    def checkpoint_io(self) -> CheckpointIO:
        return self._checkpoint_io

+    @property
+    def precision_plugin(self) -> PrecisionPlugin:
+        return self._precision_plugin
+
    @checkpoint_io.setter
    def checkpoint_io(self, plugin: CheckpointIO) -> None:
        self._checkpoint_io = plugin
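Because the base class now creates a default `PrecisionPlugin()` when none is given, every strategy exposes a usable `precision_plugin` property. A short sketch of that behaviour, under the assumption that `SingleDevicePlugin` forwards the argument as shown in the hunks above:

```python
import torch

from pytorch_lightning.plugins import PrecisionPlugin, SingleDevicePlugin

# No precision plugin passed: the base TrainingTypePlugin falls back to PrecisionPlugin().
strategy = SingleDevicePlugin(torch.device("cpu"))
assert isinstance(strategy.precision_plugin, PrecisionPlugin)

# An explicitly passed plugin is stored and returned unchanged.
custom = PrecisionPlugin()
strategy = SingleDevicePlugin(torch.device("cpu"), precision_plugin=custom)
assert strategy.precision_plugin is custom
```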
@@ -405,6 +405,9 @@ class AcceleratorConnector:
        # attach checkpoint plugin to the training type plugin
        if self._checkpoint_io is not None:
            self._training_type_plugin.checkpoint_io = self._checkpoint_io
+        precision_plugin = self.precision_plugin
+        if precision_plugin is not None:
+            self._training_type_plugin._precision_plugin = precision_plugin
        self._training_type_plugin_resolved = True

        return self._training_type_plugin

@@ -531,11 +534,11 @@ class AcceleratorConnector:
    @property
    def _is_sharded_training_type(self) -> bool:
-        return isinstance(self.training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin))
+        return isinstance(self._training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin))

    @property
    def _is_fully_sharded_training_type(self) -> bool:
-        return isinstance(self.training_type_plugin, DDPFullyShardedPlugin)
+        return isinstance(self._training_type_plugin, DDPFullyShardedPlugin)

    @property
    def is_distributed(self) -> bool:

@@ -793,12 +796,10 @@ class AcceleratorConnector:
            acc_cls = IPUAccelerator
        else:
            acc_cls = CPUAccelerator
-        # as precision_plugin is dependent on training_type_plugin, make sure
-        # that we first select training_type_plugin, then precision_plugin
-        accelerator = acc_cls(training_type_plugin=self.training_type_plugin, precision_plugin=self.precision_plugin)
+        accelerator = acc_cls(precision_plugin=None, training_type_plugin=self.training_type_plugin)

        # transfer ownership of the plugins to the accelerator
        self._training_type_plugin = proxy(self.training_type_plugin)
        self._precision_plugin = proxy(self.precision_plugin)

        return accelerator
@@ -1568,7 +1568,7 @@ class Trainer(
    @property
    def precision_plugin(self) -> PrecisionPlugin:
-        return self.accelerator.precision_plugin
+        return self.training_type_plugin.precision_plugin

    @property
    def global_rank(self) -> int:

@@ -1672,7 +1672,7 @@ class Trainer(
    @property
    def precision(self) -> Union[str, int]:
-        return self.accelerator.precision
+        return self.training_type_plugin.precision_plugin.precision

    @property
    def scaler(self):
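After these two hunks, `Trainer.precision_plugin` and `Trainer.precision` are thin forwarders onto the strategy. A minimal check of that forwarding, assuming a default CPU `Trainer`:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import PrecisionPlugin

trainer = Trainer()

# Both Trainer properties now read from the training type plugin.
assert trainer.precision_plugin is trainer.training_type_plugin.precision_plugin
assert isinstance(trainer.precision_plugin, PrecisionPlugin)
assert trainer.precision == trainer.training_type_plugin.precision_plugin.precision  # 32 by default
```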
@@ -193,8 +193,8 @@ def test_mixed_precision(tmpdir):
    model = IPUModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, ipus=1, precision=16, callbacks=TestCallback())
-    assert isinstance(trainer.accelerator.precision_plugin, IPUPrecisionPlugin)
-    assert trainer.accelerator.precision_plugin.precision == 16
+    assert isinstance(trainer.training_type_plugin.precision_plugin, IPUPrecisionPlugin)
+    assert trainer.training_type_plugin.precision_plugin.precision == 16
    with pytest.raises(SystemExit):
        trainer.fit(model)

@@ -213,8 +213,8 @@ def test_pure_half_precision(tmpdir):
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, ipus=1, precision=16, callbacks=TestCallback())

    assert isinstance(trainer.accelerator.training_type_plugin, IPUPlugin)
-    assert isinstance(trainer.accelerator.precision_plugin, IPUPrecisionPlugin)
-    assert trainer.accelerator.precision_plugin.precision == 16
+    assert isinstance(trainer.training_type_plugin.precision_plugin, IPUPrecisionPlugin)
+    assert trainer.training_type_plugin.precision_plugin.precision == 16

    with pytest.raises(SystemExit):
        trainer.fit(model)
@@ -23,7 +23,7 @@ from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators.cpu import CPUAccelerator
from pytorch_lightning.accelerators.tpu import TPUAccelerator
-from pytorch_lightning.plugins import TPUPrecisionPlugin, TPUSpawnPlugin, XLACheckpointIO
+from pytorch_lightning.plugins import DDPPlugin, TPUPrecisionPlugin, TPUSpawnPlugin, XLACheckpointIO
from pytorch_lightning.utilities import find_shared_parameters
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel, RandomDataset

@@ -292,11 +292,23 @@ def test_tpu_invalid_raises():
    with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `TPUPrecisionPlugin"):
        accelerator.setup(object())

-    accelerator = TPUAccelerator(TPUPrecisionPlugin(), object())
+    accelerator = TPUAccelerator(TPUPrecisionPlugin(), DDPPlugin())
    with pytest.raises(ValueError, match="TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugi"):
        accelerator.setup(object())


+def test_tpu_invalid_raises_set_precision_with_strategy():
+    accelerator = TPUAccelerator(object(), TPUSpawnPlugin(precision_plugin=object()))
+    with pytest.raises(ValueError, match="`TPUAccelerator` can only be used with a `TPUPrecisionPlugin`"):
+        accelerator.setup(object())
+
+    accelerator = TPUAccelerator(None, DDPPlugin(precision_plugin=TPUPrecisionPlugin()))
+    with pytest.raises(
+        ValueError, match="TPUAccelerator` can only be used with a `SingleTPUPlugin` or `TPUSpawnPlugin"
+    ):
+        accelerator.setup(object())
+
+
@RunIf(tpu=True)
def test_xla_checkpoint_plugin_being_default():
    trainer = Trainer(tpu_cores=8)
@@ -34,8 +34,8 @@ def test_invalid_on_cpu(tmpdir):
def test_fsdp_with_sharded_amp(device_count_mock, mock_cuda_available, tmpdir):
    """Test to ensure that plugin native amp plugin is correctly chosen when using sharded."""
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy="fsdp", gpus=1, precision=16)
-    assert isinstance(trainer.accelerator.training_type_plugin, DDPFullyShardedPlugin)
-    assert isinstance(trainer.accelerator.precision_plugin, FullyShardedNativeMixedPrecisionPlugin)
+    assert isinstance(trainer.training_type_plugin, DDPFullyShardedPlugin)
+    assert isinstance(trainer.training_type_plugin.precision_plugin, FullyShardedNativeMixedPrecisionPlugin)


class TestFSDPModel(BoringModel):
@@ -170,8 +170,8 @@ def test_deepspeed_precision_choice(amp_backend, precision, tmpdir):
    )

    assert isinstance(trainer.accelerator.training_type_plugin, DeepSpeedPlugin)
-    assert isinstance(trainer.accelerator.precision_plugin, DeepSpeedPrecisionPlugin)
-    assert trainer.accelerator.precision_plugin.precision == precision
+    assert isinstance(trainer.training_type_plugin.precision_plugin, DeepSpeedPrecisionPlugin)
+    assert trainer.training_type_plugin.precision_plugin.precision == precision


@RunIf(deepspeed=True)