lightning/pytorch_lightning/plugins/ddp_plugin.py


from typing import Any, Dict, List

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel


class DDPPlugin(object):
    """
    Plugin to link a custom ddp implementation to any arbitrary accelerator.

    This plugin forwards all constructor arguments to `LightningDistributedDataParallel`,
    which in turn forwards all args to `DistributedDataParallel`.

    Example::

        class MyDDP(DDPPlugin):

            def configure_ddp(self, model, device_ids):
                model = MyDDPWrapper(model, device_ids)
                return model

        my_ddp = MyDDP()
        trainer = Trainer(accelerator='ddp_x', plugins=[my_ddp])
    """

    def __init__(self, **kwargs):
        self._ddp_kwargs: Dict[str, Any] = kwargs

    def configure_ddp(
        self, model: LightningModule, device_ids: List[int]
    ) -> LightningDistributedDataParallel:
        """
        Pass through all customizations from the constructor to `LightningDistributedDataParallel`.
        Override to define a custom DDP implementation.

        .. note:: The only requirement is that your DDP implementation subclasses
            `LightningDistributedDataParallel`.

        The default implementation is::

            def configure_ddp(self, model, device_ids):
                model = LightningDistributedDataParallel(
                    model, device_ids=device_ids, find_unused_parameters=True
                )
                return model

        Args:
            model: the LightningModule
            device_ids: the list of available device ids

        Returns:
            the model wrapped in `LightningDistributedDataParallel`
        """
        # if the user did not set `find_unused_parameters`, default it to `True`
        self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get(
            "find_unused_parameters", True
        )
        # wrap the LightningModule; all remaining kwargs are forwarded to `DistributedDataParallel`
        model = LightningDistributedDataParallel(
            model,
            device_ids=device_ids,
            **self._ddp_kwargs,
        )
        return model
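

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how the kwargs captured by `DDPPlugin.__init__` reach
# `DistributedDataParallel`: any keyword passed to the plugin is forwarded by
# `configure_ddp`. `LitModel` is a hypothetical LightningModule, and the Trainer
# arguments below are assumptions for the sketch, not part of this file.
#
#     from pytorch_lightning import Trainer
#
#     # disable the unused-parameter scan that the plugin would otherwise default to True
#     ddp = DDPPlugin(find_unused_parameters=False)
#     trainer = Trainer(accelerator='ddp', gpus=2, plugins=[ddp])
#     trainer.fit(LitModel())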