# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from typing import Any

import torch
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase


class LightningDistributedModule(_LightningModuleWrapperBase):

    def __init__(self, pl_module: LightningModule):
        """
        Wraps the user's LightningModule and redirects the ``forward`` call to the appropriate
        method, either ``training_step``, ``validation_step``, ``test_step`` or ``predict``.
        This class is used in combination with :class:`~torch.nn.parallel.DistributedDataParallel`
        as shown in the example.

        Example:

            ddp_model = torch.nn.parallel.DistributedDataParallel(
                module=LightningDistributedModule(lightning_module),
                device_ids=[local_rank],
                ...
            )

        Args:
            pl_module: the model to wrap

        """
        super().__init__(pl_module)


def _find_tensors(obj):  # pragma: no-cover
    r"""
    Recursively find all tensors contained in the specified object.
    """
    if isinstance(obj, torch.Tensor):
        return [obj]
    if isinstance(obj, (list, tuple)):
        return itertools.chain(*map(_find_tensors, obj))
    if isinstance(obj, dict):
        return itertools.chain(*map(_find_tensors, obj.values()))
    return []


# In manual optimization, we need to call the reducer's ``prepare_for_backward`` ourselves.
# Note: Keep track of PyTorch DDP and update if there is a change
# https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/parallel/distributed.py#L626-L638
def prepare_for_backward(model: DistributedDataParallel, output: Any):
    """Prepare the DDP reducer for an upcoming ``backward()`` call that Lightning triggers manually."""
    if torch.is_grad_enabled() and model.require_backward_grad_sync:
        model.require_forward_param_sync = True
        # The output object is freeform. We need to find any tensors in it to
        # figure out which parameters were used during this forward pass, so
        # that reduction can be short-circuited for any unused parameters.
        # Only relevant if ``find_unused_parameters`` is set.
        if model.find_unused_parameters:
            model.reducer.prepare_for_backward(list(_find_tensors(output)))
        else:
            model.reducer.prepare_for_backward([])
    else:
        model.require_forward_param_sync = False
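

if __name__ == "__main__":
    # A minimal, illustrative sketch (not part of the Lightning API): it mimics the
    # manual-optimization flow in which ``backward()`` runs before DDP's own forward
    # hook has prepared the reducer, so ``prepare_for_backward`` must be called
    # explicitly. It assumes a single-process "gloo" group and a plain ``nn.Linear``
    # purely for demonstration; real usage goes through the Lightning Trainer and
    # its DDP plugin.
    import os

    import torch.distributed as dist
    from torch import nn

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    ddp_model = DistributedDataParallel(nn.Linear(8, 2), find_unused_parameters=True)

    # Run the forward pass with gradient synchronization disabled so that DDP does
    # not prepare the reducer itself ...
    with ddp_model.no_sync():
        loss = ddp_model(torch.randn(4, 8)).sum()

    # ... then, with sync restored on exit from ``no_sync``, prepare the reducer
    # manually before calling backward.
    prepare_for_backward(ddp_model, loss)
    loss.backward()

    dist.destroy_process_group()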