Strategy Registry
=================

Lightning includes a registry that holds information about training strategies and lets you register new custom strategies.

Each strategy is identified by a string, such as ``"ddp"`` or ``"deepspeed_stage_2_offload"``.
The registry also returns the optional description and the initialization parameters that were defined when the strategy was registered.


.. code-block:: python

    from lightning.pytorch import Trainer

    # Training with the DDP Strategy
    trainer = Trainer(strategy="ddp", accelerator="gpu", devices=4)

    # Training with DeepSpeed ZeRO Stage 3 and CPU Offload
    trainer = Trainer(strategy="deepspeed_stage_3_offload", accelerator="gpu", devices=3)

    # Training with the XLA Strategy with `debug` as True
    trainer = Trainer(strategy="xla_debug", accelerator="tpu", devices=8)

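To make the idea of a name-to-strategy registry concrete, here is a minimal pure-Python sketch of the pattern. It is not Lightning's implementation: ``ToyStrategy``, ``ToyRegistry``, and the method ``available`` are hypothetical names invented for illustration. It only shows how a string key can map to a class, a description, and the keyword arguments captured at registration time.

```python
from typing import Any, List, Optional


class ToyStrategy:
    """Hypothetical stand-in for a strategy class; not part of Lightning."""

    def __init__(self, checkpoint_io: Optional[Any] = None, debug: bool = False) -> None:
        self.checkpoint_io = checkpoint_io
        self.debug = debug


class ToyRegistry(dict):
    """Maps a string name to a class, a description, and stored init parameters."""

    def register(self, name: str, strategy: type, description: str = "", **init_params: Any) -> None:
        if name in self:
            raise ValueError(f"'{name}' is already registered")
        self[name] = {"strategy": strategy, "description": description, "init_params": init_params}

    def get(self, name: str) -> Any:
        # Instantiate the stored class with the parameters captured at registration.
        entry = self[name]
        return entry["strategy"](**entry["init_params"])

    def available(self) -> List[str]:
        # Sorted list of all registered names.
        return sorted(self)


registry = ToyRegistry()
registry.register("toy", ToyStrategy, description="Toy strategy with defaults")
registry.register("toy_debug", ToyStrategy, description="Toy strategy with debug enabled", debug=True)

strategy = registry.get("toy_debug")
print(registry.available())
print(registry["toy_debug"]["description"], strategy.debug)
```

In this sketch, a variant such as ``"toy_debug"`` is simply the same class registered under a second name with different stored parameters, which is one plausible way to read names like ``"xla_debug"`` above.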

You can also register your own training strategy and then pass its name to the ``strategy`` argument.


.. code-block:: python

    from pathlib import Path
    from typing import Any, Dict, Optional, Union

    from lightning.pytorch import Trainer
    from lightning.pytorch.plugins import CheckpointIO
    from lightning.pytorch.strategies import DDPStrategy, StrategyRegistry


    class CustomCheckpointIO(CheckpointIO):
        def save_checkpoint(
            self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[Any] = None
        ) -> None:
            ...

        def load_checkpoint(self, path: Union[str, Path], map_location: Optional[Any] = None) -> Dict[str, Any]:
            ...

        # CheckpointIO also declares remove_checkpoint as abstract
        def remove_checkpoint(self, path: Union[str, Path]) -> None:
            ...


    custom_checkpoint_io = CustomCheckpointIO()

    # Register the DDP Strategy with your custom CheckpointIO plugin
    StrategyRegistry.register(
        "ddp_custom_checkpoint_io",
        DDPStrategy,
        description="DDP Strategy with custom checkpoint io plugin",
        checkpoint_io=custom_checkpoint_io,
    )

    trainer = Trainer(strategy="ddp_custom_checkpoint_io", accelerator="gpu", devices=2)
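
The method bodies in the example above are left as ``...``. As a rough illustration of what such a plugin might do, the sketch below stores checkpoints with ``pickle``. ``PickleCheckpointIO`` is a hypothetical name, and it deliberately does not subclass ``CheckpointIO`` so that it runs without Lightning installed. ``remove_checkpoint`` is included on the assumption that a checkpoint plugin also needs a way to delete files. A real plugin would subclass ``CheckpointIO`` and would more likely delegate to ``torch.save`` and ``torch.load``.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any, Dict, Union


class PickleCheckpointIO:
    """Hypothetical stand-in; a real plugin would subclass Lightning's CheckpointIO."""

    def save_checkpoint(self, checkpoint: Dict[str, Any], path: Union[str, Path]) -> None:
        path = Path(path)
        # Create missing parent directories before writing.
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open("wb") as f:
            pickle.dump(checkpoint, f)

    def load_checkpoint(self, path: Union[str, Path]) -> Dict[str, Any]:
        with Path(path).open("rb") as f:
            return pickle.load(f)

    def remove_checkpoint(self, path: Union[str, Path]) -> None:
        Path(path).unlink()


checkpoint_io = PickleCheckpointIO()
with tempfile.TemporaryDirectory() as tmp:
    ckpt_path = Path(tmp) / "run" / "example.ckpt"
    checkpoint_io.save_checkpoint({"epoch": 3, "state_dict": {"w": [0.1, 0.2]}}, ckpt_path)
    restored = checkpoint_io.load_checkpoint(ckpt_path)
    checkpoint_io.remove_checkpoint(ckpt_path)
    still_exists = ckpt_path.exists()
```

An instance of a class like this (once it subclasses ``CheckpointIO``) is what would be passed as ``checkpoint_io=`` when registering the strategy.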