183 lines
6.4 KiB
Python
183 lines
6.4 KiB
Python
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
|
|
from contextlib import contextmanager
|
|
from typing import Any, Dict, List, Union
|
|
|
|
import torch.distributed as torch_distrib
|
|
from torch.optim import Optimizer
|
|
|
|
from pytorch_lightning import _logger as log
|
|
from pytorch_lightning.core.lightning import LightningModule
|
|
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
|
|
from pytorch_lightning.plugins.plugin import LightningPlugin
|
|
|
|
|
|
class DDPPlugin(LightningPlugin):
    """
    Plugin to link a custom ddp implementation to any arbitrary accelerator.

    This plugin forwards all constructor arguments to `LightningDistributedDataParallel`,
    which in turn forwards all args to `DistributedDataParallel`.

    Example::

        class MyDDP(DDPPlugin):

            def configure_ddp(self, model, device_ids):
                model = MyDDPWrapper(model, device_ids)
                return model

        my_ddp = MyDDP()
        trainer = Trainer(accelerator='ddp_x', plugins=[my_ddp])
    """

    def __init__(self, **kwargs):
        """
        Store the keyword arguments that will later be forwarded to the DDP wrapper.

        Args:
            **kwargs: keyword arguments passed on to `LightningDistributedDataParallel`
                by :meth:`configure_ddp`.
        """
        self._ddp_kwargs: Dict[str, Any] = kwargs

    def configure_ddp(
        self, model: LightningModule, device_ids: List[int]
    ) -> LightningDistributedDataParallel:
        """
        Pass through all customizations from constructor to `LightningDistributedDataParallel`.
        Override to define a custom DDP implementation.

        .. note:: Only requirement is that your DDP implementation subclasses LightningDistributedDataParallel


        The default implementation is::

            def configure_ddp(self, model, device_ids):
                model = LightningDistributedDataParallel(
                    model, device_ids=device_ids, find_unused_parameters=True
                )
                return model

        Args:
            model: the lightningModule
            device_ids: the list of devices available

        Returns:
            the model wrapped in LightningDistributedDataParallel

        """
        # If the user did not set `find_unused_parameters`, default it to `True`.
        # The default is written back into the stored kwargs, as before.
        self._ddp_kwargs.setdefault("find_unused_parameters", True)
        model = LightningDistributedDataParallel(
            model,
            device_ids=device_ids,
            **self._ddp_kwargs,
        )
        return model

    def init_ddp_connection(
        self,
        trainer,
        cluster_environment,
        global_rank: int,
        world_size: int,
        is_slurm_managing_tasks: bool = True,
    ) -> None:
        """
        Export the rendezvous environment variables and initialize the process group.

        ``MASTER_ADDR``, ``MASTER_PORT`` and ``WORLD_SIZE`` are written to ``os.environ``
        from ``cluster_environment``. The process group is only created when
        ``torch.distributed`` has not been initialized yet.

        Args:
            trainer: object whose ``on_gpu`` attribute selects the backend
                (``"nccl"`` when truthy, ``"gloo"`` otherwise).
            cluster_environment: object providing ``master_address()``, ``master_port()``
                and ``world_size()``.
            global_rank: rank of this process, passed to ``init_process_group``.
            world_size: number of processes, passed to ``init_process_group``.
            is_slurm_managing_tasks: currently unused.
        """
        # TODO: argument `is_slurm_managing_tasks` is accepted but not used
        os.environ["MASTER_ADDR"] = str(cluster_environment.master_address())
        os.environ["MASTER_PORT"] = str(cluster_environment.master_port())
        # NOTE: the env var comes from the cluster environment, while the process
        # group below is sized by the `world_size` argument.
        os.environ["WORLD_SIZE"] = str(cluster_environment.world_size())
        torch_backend = "nccl" if trainer.on_gpu else "gloo"

        if not torch_distrib.is_initialized():
            log.info(
                f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}"
            )
            torch_distrib.init_process_group(
                torch_backend, rank=global_rank, world_size=world_size
            )

    def on_before_forward(self, model: LightningModule, *args):
        """
        Override to handle custom input to device logic. For DDP, no logic is required as this is handled internally
        within the DDP wrapper.

        Example::

            def on_before_forward(self, model, *args):
                batch, batch_idx = args
                return batch.to(model.device)

        Args:
            args: Inputs to the model.
            model: Model to train.
        Returns: args moved to correct device if needed.
        """
        return args

    def optimizer_state(self, optimizer: Optimizer) -> dict:
        """
        Return the state of the given optimizer.

        Args:
            optimizer: the optimizer to read the state from.

        Returns: the result of ``optimizer.state_dict()``.
        """
        return optimizer.state_dict()

    def on_after_setup_optimizers(self, trainer):
        """
        Called after optimizers have been set-up. This is useful for doing any configuration options in RPC, or
        state sharding.
        """

    def get_model_from_plugin(
        self,
        model: Union[LightningDistributedDataParallel, LightningModule]
    ) -> LightningModule:
        """
        Override to modify returning base :class:`LightningModule`
        when accessing variable and functions outside of the parallel wrapper.

        Example::
            ref_model = ddp_plugin.get_model_from_plugin(model)
            ref_model.training_step(...)

        Args:
            model: Model with parallel wrapper.

        Returns: Reference :class:`LightningModule` within parallel wrapper.

        """
        if isinstance(model, LightningDistributedDataParallel):
            return model.module
        # Not wrapped: the object is returned unchanged.
        return model

    @contextmanager
    def block_backward_sync(self, model: LightningDistributedDataParallel):
        """
        Blocks ddp sync gradients behaviour on backwards pass.
        This is useful for skipping sync when accumulating gradients, reducing communication overhead

        Args:
            model: the DDP-wrapped model whose ``no_sync()`` context is entered.

        Yields: ``None``; the ``model.no_sync()`` context is active for the duration of the ``with`` body.
        """
        # `no_sync()` only takes effect while its context is entered. Yielding the
        # un-entered context manager (the previous behaviour) left syncing enabled.
        with model.no_sync():
            yield None

    def on_before_manual_backward(self, model: LightningDistributedDataParallel, output: Any):
        """
        Hook run before a manual backward pass.

        Calls ``model.reducer_prepare_for_backwards(output)``.

        Args:
            model: the DDP-wrapped model.
            output: value forwarded to ``reducer_prepare_for_backwards``.
        """
        model.reducer_prepare_for_backwards(output)

    def on_after_manual_backward(self, model: LightningDistributedDataParallel):
        """
        Hook run after a manual backward pass.

        Calls ``model.reducer_reset_hooks()``.

        Args:
            model: the DDP-wrapped model.
        """
        model.reducer_reset_hooks()

    def distributed_sampler_kwargs(self, distributed_sampler_kwargs):
        """
        Override to customize the keyword arguments used for the distributed sampler.

        Args:
            distributed_sampler_kwargs: the proposed keyword arguments.

        Returns: the same object, unchanged, in this default implementation.
        """
        return distributed_sampler_kwargs

    @property
    def data_parallel_group(self):
        """
        Return the group that this process exists in. By default, this is the world group.
        Useful for when additional parallel groups have been created, to select certain processes.
        Returns: The ProcessGroup this process exists in.
        """
        return torch_distrib.group.WORLD