2020-10-27 12:27:59 +00:00
|
|
|
from typing import List, Dict, Any
|
|
|
|
|
2020-10-22 09:15:51 +00:00
|
|
|
from pytorch_lightning.core.lightning import LightningModule
|
2020-10-27 12:27:59 +00:00
|
|
|
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
|
2020-10-22 09:15:51 +00:00
|
|
|
|
|
|
|
|
|
|
|
class DDPPlugin:
    """
    Plugin to link a custom ddp implementation to any arbitrary accelerator.

    This plugin forwards all constructor arguments to `LightningDistributedDataParallel`,
    which in turn forwards all args to `DistributedDataParallel`.

    Example::

        class MyDDP(DDPPlugin):

            def configure_ddp(self, model, device_ids):
                model = MyDDPWrapper(model, device_ids)
                return model

        my_ddp = MyDDP()
        trainer = Trainer(accelerator='ddp_x', plugins=[my_ddp])
    """

    def __init__(self, **kwargs):
        """
        Store the keyword arguments that ``configure_ddp`` later forwards to
        `LightningDistributedDataParallel`.

        Args:
            **kwargs: arbitrary keyword arguments, kept as-is (no copy is made)
        """
        self._ddp_kwargs: Dict[str, Any] = kwargs

    def configure_ddp(
        self, model: LightningModule, device_ids: List[int]
    ) -> LightningDistributedDataParallel:
        """
        Pass through all customizations from constructor to `LightningDistributedDataParallel`.
        Override to define a custom DDP implementation.

        .. note:: Only requirement is that your DDP implementation subclasses LightningDistributedDataParallel

        The default implementation is::

            def configure_ddp(self, model, device_ids):
                self._ddp_kwargs.setdefault("find_unused_parameters", True)
                model = LightningDistributedDataParallel(
                    model, device_ids=device_ids, **self._ddp_kwargs
                )
                return model

        ``find_unused_parameters`` is set to ``True`` only when it was not given
        to the constructor; the chosen value is written back into the stored
        keyword arguments, so it is also visible on later calls.

        Args:
            model: the lightningModule
            device_ids: the list of devices available

        Returns:
            the model wrapped in LightningDistributedDataParallel

        """
        # if unset, default `find_unused_parameters` `True`
        self._ddp_kwargs.setdefault("find_unused_parameters", True)
        model = LightningDistributedDataParallel(
            model,
            device_ids=device_ids,
            **self._ddp_kwargs,
        )
        return model
|