# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from contextlib import contextmanager
from typing import Any, Dict, List, Union

import torch.distributed as torch_distrib
from torch.optim import Optimizer

from pytorch_lightning import _logger as log
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.plugins.plugin import LightningPlugin


class DDPPlugin(LightningPlugin):
    """
    Plugin to link a custom DDP implementation to any arbitrary accelerator.

    This plugin forwards all constructor arguments to `LightningDistributedDataParallel`,
    which in turn forwards all args to `DistributedDataParallel`.

    Example::

        class MyDDP(DDPPlugin):

            def configure_ddp(self, model, device_ids):
                model = MyDDPWrapper(model, device_ids)
                return model

        my_ddp = MyDDP()
        trainer = Trainer(accelerator='ddp_x', plugins=[my_ddp])
    """

    def __init__(self, **kwargs):
        self._ddp_kwargs: Dict[str, Any] = kwargs

    def configure_ddp(
        self, model: LightningModule, device_ids: List[int]
    ) -> LightningDistributedDataParallel:
        """
        Pass through all customizations from the constructor to `LightningDistributedDataParallel`.
        Override to define a custom DDP implementation.

        .. note:: The only requirement is that your DDP implementation subclasses
            ``LightningDistributedDataParallel``.

        The default implementation is::

            def configure_ddp(self, model, device_ids):
                model = LightningDistributedDataParallel(
                    model, device_ids=device_ids, find_unused_parameters=True
                )
                return model

        Args:
            model: the LightningModule
            device_ids: the list of devices available

        Returns: the model wrapped in LightningDistributedDataParallel

        """
        # if unset, default `find_unused_parameters` to `True`
        self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get(
            "find_unused_parameters", True
        )
        model = LightningDistributedDataParallel(
            model,
            device_ids=device_ids,
            **self._ddp_kwargs,
        )
        return model

    def init_ddp_connection(
        self,
        trainer,
        cluster_environment,
        global_rank: int,
        world_size: int,
        is_slurm_managing_tasks: bool = True,
    ) -> None:
        # TODO: the argument `is_slurm_managing_tasks` is currently unused
        os.environ["MASTER_ADDR"] = str(cluster_environment.master_address())
        os.environ["MASTER_PORT"] = str(cluster_environment.master_port())
        os.environ["WORLD_SIZE"] = str(cluster_environment.world_size())
        torch_backend = "nccl" if trainer.on_gpu else "gloo"

        if not torch_distrib.is_initialized():
            log.info(
                f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}"
            )
            torch_distrib.init_process_group(
                torch_backend, rank=global_rank, world_size=world_size
            )

    def on_before_forward(self, model: LightningModule, *args):
        """
        Override to handle custom logic for moving inputs to the device.
        For DDP, no logic is required as this is handled internally within the DDP wrapper.

        Example::

            def on_before_forward(self, model, *args):
                batch, batch_idx = args
                return batch.to(model.device)

        Args:
            args: Inputs to the model.
            model: Model to train.

        Returns: args moved to correct device if needed.
        """
        return args

    def optimizer_state(self, optimizer: Optimizer) -> dict:
        return optimizer.state_dict()

    def on_after_setup_optimizers(self, trainer):
        """
        Called after the optimizers have been set up.
        This is useful for configuring RPC or state-sharding options.
        """

    def get_model_from_plugin(
        self,
        model: Union[LightningDistributedDataParallel, LightningModule],
    ) -> LightningModule:
        """
        Override to modify how the base :class:`LightningModule` is returned when accessing
        variables and functions outside of the parallel wrapper.

        Example::

            ref_model = ddp_plugin.get_model_from_plugin(model)
            ref_model.training_step(...)

        Args:
            model: Model with parallel wrapper.

        Returns: Reference :class:`LightningModule` within parallel wrapper.

        """
        if isinstance(model, LightningDistributedDataParallel):
            return model.module
        return model

    @contextmanager
    def block_backward_sync(self, model: LightningDistributedDataParallel):
        """
        Blocks DDP gradient synchronization during the backward pass.
        This is useful for skipping the sync when accumulating gradients,
        reducing communication overhead.

        Returns: context manager with sync behaviour off
        """
        yield model.no_sync()

    def on_before_manual_backward(self, model: LightningDistributedDataParallel, output: Any):
        model.reducer_prepare_for_backwards(output)

    def on_after_manual_backward(self, model: LightningDistributedDataParallel):
        model.reducer_reset_hooks()

    def distributed_sampler_kwargs(self, distributed_sampler_kwargs):
        return distributed_sampler_kwargs

    @property
    def data_parallel_group(self):
        """
        Return the process group that this process belongs to.
        By default, this is the world group. Useful for selecting certain
        processes when additional parallel groups have been created.

        Returns: The ProcessGroup this process exists in.
        """
        return torch_distrib.group.WORLD
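

# A minimal usage sketch (not part of the plugin itself): it shows how this plugin can be
# passed to the Trainer, with constructor keyword arguments forwarded through
# `LightningDistributedDataParallel` to `DistributedDataParallel`. `MyModel` is assumed
# to be a user-defined LightningModule and is only referenced in a comment.
if __name__ == "__main__":
    from pytorch_lightning import Trainer

    # disable `find_unused_parameters` when every parameter is known to receive a gradient,
    # avoiding the extra graph traversal DDP otherwise performs each backward pass
    ddp_plugin = DDPPlugin(find_unused_parameters=False)
    trainer = Trainer(gpus=2, accelerator="ddp", plugins=[ddp_plugin])
    # trainer.fit(MyModel())  # `MyModel` is a hypothetical user-defined LightningModule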