78 lines
3.0 KiB
Python
78 lines
3.0 KiB
Python
|
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
||
|
import itertools
|
||
|
from typing import Any
|
||
|
|
||
|
import torch
|
||
|
from torch.nn.parallel import DistributedDataParallel
|
||
|
|
||
|
from pytorch_lightning.core.lightning import LightningModule
|
||
|
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
|
||
|
|
||
|
|
||
|
class LightningDistributedModule(_LightningModuleWrapperBase):

    def __init__(self, pl_module: LightningModule):
        """
        Wrap a user's LightningModule so that the wrapper's forward call is redirected to the
        matching step method: ``training_step``, ``validation_step``, ``test_step`` or ``predict``.

        The wrapper is meant to be passed to :class:`~torch.nn.parallel.DistributedDataParallel`,
        for example::

            ddp_model = torch.nn.parallel.DistributedDataParallel(
                module=LightningDistributedModule(lightning_module),
                device_ids=[local_rank],
                ...
            )

        Args:
            pl_module: the LightningModule instance to wrap
        """
        # All wrapping/redirect logic lives in the base class; this subclass only exists
        # to give the DDP use case its own named type.
        super().__init__(pl_module)
|
||
|
|
||
|
|
||
|
def _find_tensors(obj):  # pragma: no-cover
    r"""
    Collect every tensor nested inside ``obj``.

    Lists, tuples and the values of dicts are searched recursively; any other
    non-tensor object contributes nothing.

    Returns:
        ``[obj]`` when ``obj`` itself is a tensor, an ``itertools.chain`` over the
        tensors of the children when ``obj`` is a list/tuple/dict, and ``[]`` otherwise.
    """
    if isinstance(obj, torch.Tensor):
        return [obj]

    # Pick which children to descend into; anything else is a leaf without tensors.
    if isinstance(obj, dict):
        children = obj.values()
    elif isinstance(obj, (list, tuple)):
        children = obj
    else:
        return []

    return itertools.chain.from_iterable(_find_tensors(child) for child in children)
|
||
|
|
||
|
|
||
|
def prepare_for_backward(model: DistributedDataParallel, output: Any):
    """
    Explicitly prepare ``model``'s DDP reducer for the backward pass of ``output``.

    Under manual optimization the reducer's ``prepare_for_backward`` has to be called
    by hand. This reproduces the post-forward logic of PyTorch's DDP; keep it in sync
    with upstream if that changes:
    https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/parallel/distributed.py#L626-L638

    Args:
        model: the ``DistributedDataParallel`` wrapper whose flags and reducer are updated.
        output: the free-form forward output; it is only searched for tensors when
            ``model.find_unused_parameters`` is set.
    """
    if not (torch.is_grad_enabled() and model.require_backward_grad_sync):
        # No gradient sync will happen for this step.
        model.require_forward_param_sync = False
        return

    model.require_forward_param_sync = True
    # The output is passed through untouched, but its tensors tell the reducer which
    # parameters took part in this forward pass, so reduction of unused parameters can
    # be short-circuited. That search is only done when `find_unused_parameters` is set.
    used_tensors = list(_find_tensors(output)) if model.find_unused_parameters else []
    model.reducer.prepare_for_backward(used_tensors)
|