lightning/pytorch_lightning/overrides/data_parallel.py

350 lines
12 KiB
Python

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import threading
import warnings
from collections.abc import Iterable, Mapping
from itertools import chain
from typing import Any, Optional
import torch
from torch import Tensor
from torch.cuda._utils import _get_device_index
from torch.nn import DataParallel, Module
from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel._functions import Gather
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities.warnings import WarningCache
def _find_tensors(obj): # pragma: no-cover
r"""
Recursively find all tensors contained in the specified object.
"""
if isinstance(obj, torch.Tensor):
return [obj]
if isinstance(obj, (list, tuple)):
return itertools.chain(*map(_find_tensors, obj))
if isinstance(obj, dict):
return itertools.chain(*map(_find_tensors, obj.values()))
return []
def get_a_var(obj):  # pragma: no-cover
    """
    Return the first tensor found by a depth-first search of ``obj``, or ``None``.

    Lists and tuples are searched element by element. Dicts are searched over
    their ``(key, value)`` pairs, so a tensor key can also be returned.
    """
    if isinstance(obj, torch.Tensor):
        return obj

    if isinstance(obj, (list, tuple)):
        candidates = obj
    elif isinstance(obj, dict):
        candidates = obj.items()
    else:
        return None

    for candidate in candidates:
        hit = get_a_var(candidate)
        if isinstance(hit, torch.Tensor):
            # stop at the first tensor, without visiting the rest
            return hit
    return None
# Module-level warning sink used by the helpers in this file
# (``warn_if_output_is_none`` and ``warn_missing_output`` call ``warning_cache.warn``).
warning_cache = WarningCache()
class LightningDataParallel(DataParallel):
    """
    Override the forward call in lightning so it goes to the training, test or validation step respectively.
    """

    def forward(self, *inputs, **kwargs):
        """
        Scatter the inputs, run the appropriate step on every replica and gather the results.

        Raises:
            RuntimeError: if any parameter or buffer of the wrapped module is not on
                ``self.src_device_obj`` (``device_ids[0]``).
        """
        if not self.device_ids:
            return self.module(*inputs, **kwargs)

        for t in chain(self.module.parameters(), self.module.buffers()):
            if t.device != self.src_device_obj:
                raise RuntimeError("module must have its parameters and buffers "
                                   "on device {} (device_ids[0]) but found one of "
                                   "them on device: {}".format(self.src_device_obj, t.device))

        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)

        if len(self.device_ids) == 1:
            # lightning: a single device needs no replication, so call the step directly
            if self.module.training:
                return self.module.training_step(*inputs[0], **kwargs[0])

            if self.module.testing:
                return self.module.test_step(*inputs[0], **kwargs[0])

            return self.module.validation_step(*inputs[0], **kwargs[0])

        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        outputs = self.parallel_apply(replicas, inputs, kwargs)

        if isinstance(outputs[0], Result):
            outputs = self.__gather_structured_result(outputs)
        else:
            outputs = self.gather(outputs)
        return outputs

    def __gather_structured_result(self, outputs):
        """
        Gather a list of ``Result`` objects into a single instance of the same class.

        The ``'meta'`` entry is taken from the first output, stripped from every
        output before gathering, and re-attached to the gathered result.
        """
        prototype_output = outputs[0]
        original_class = prototype_output.__class__
        outputs = [dict(x) for x in outputs]

        # remove all the meta info
        meta = outputs[0]['meta']
        for output in outputs:
            del output['meta']

        outputs = self.gather(outputs)

        result = original_class()
        result.update(outputs)
        result['meta'] = meta
        return result

    def gather(self, outputs):
        r"""
        Override the gather method to support python scalars as well.

        Tensors are gathered onto ``self.output_device`` along ``self.dim``.
        Mappings and other non-string iterables, including namedtuples, are rebuilt
        with their original type. Any other value is returned as the collection
        of per-replica values.
        """

        def gather_map(outputs):
            elem = outputs[0]
            elem_type = type(elem)

            if isinstance(elem, torch.Tensor):
                return Gather.apply(self.output_device, self.dim, *outputs)

            if elem is None:
                return None

            if isinstance(elem, Mapping):
                if not all(len(elem) == len(d) for d in outputs):
                    raise ValueError('All dicts must have the same number of keys')
                return elem_type((k, gather_map([d[k] for d in outputs])) for k in elem)

            if isinstance(elem, Iterable) and not isinstance(elem, str):
                gathered = map(gather_map, zip(*outputs))
                # namedtuples take their fields positionally, not as a single iterable
                if isinstance(elem, tuple) and hasattr(elem, '_fields'):
                    return elem_type(*gathered)
                return elem_type(gathered)

            return outputs

        # Recursive function calls like this create reference cycles.
        # Setting the function to None clears the refcycle.
        try:
            res = gather_map(outputs)
        finally:
            gather_map = None
        return res

    def parallel_apply(self, replicas, inputs, kwargs):
        # delegate to the module-level lightning-aware ``parallel_apply``
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
class LightningDistributedDataParallel(DistributedDataParallel):
    """
    Deprecated ``DistributedDataParallel`` subclass that wraps the given LightningModule
    in a ``LightningDistributedModule`` before handing it to DDP.
    """

    def __init__(self, module: LightningModule, *args, **kwargs):
        # Emitted on every construction; the message names the removal version.
        warnings.warn(
            "The usage of `LightningDistributedDataParallel` is deprecated since v1.2 and will be removed in v1.4."
            " From now on we recommend to directly subclass `torch.nn.parallel.DistributedDataParallel`.",
            DeprecationWarning
        )
        super().__init__(LightningDistributedModule(module), *args, **kwargs)
class LightningDistributedModule(torch.nn.Module):

    def __init__(self, pl_module: LightningModule):
        """
        Wraps the user's LightningModule and redirects the forward call to the appropriate
        method, either ``training_step``, ``validation_step`` or ``test_step``.

        This class is used in combination with :class:`~torch.nn.parallel.DistributedDataParallel` as
        shown in the example.

        Example:

            ddp_model = DistributedDataParallel(
                module=LightningDistributedModule(lightning_module),
                device_ids=[local_rank],
                ...
            )

        Args:
            pl_module: the model to wrap
        """
        super().__init__()
        self.module = pl_module

    def forward(self, *inputs, **kwargs):
        # Select the step hook from the wrapped module's mode flags: ``training``
        # wins over ``testing``; any other state is treated as validation.
        wrapped = self.module
        if wrapped.training:
            hook_name = "training_step"
        elif wrapped.testing:
            hook_name = "test_step"
        else:
            hook_name = "validation_step"

        output = getattr(wrapped, hook_name)(*inputs, **kwargs)
        warn_if_output_is_none(output, hook_name)
        return output
# In manual_optimization, we need to call reducer prepare_for_backward.
# Note: Keep track of Pytorch DDP and update if there is a change
# https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/parallel/distributed.py#L626-L638
def prepare_for_backward(model: DistributedDataParallel, output: Any):
    """
    Reproduce the post-forward reducer setup from PyTorch DDP (see the link above).

    When grad mode is on and ``model.require_backward_grad_sync`` is truthy, this sets
    ``model.require_forward_param_sync = True`` and calls
    ``model.reducer.prepare_for_backward``. The call receives the tensors found in
    ``output`` when ``model.find_unused_parameters`` is set, and an empty list
    otherwise. In every other case it only sets
    ``model.require_forward_param_sync = False``.
    """
    if not (torch.is_grad_enabled() and model.require_backward_grad_sync):
        model.require_forward_param_sync = False
        return

    model.require_forward_param_sync = True
    # ``output`` is returned to the user untouched, but any tensors inside it tell
    # the reducer which parameters took part in this forward pass. That lets it
    # short-circuit reduction for unused parameters, and is only needed when
    # `find_unused_parameters` is set.
    used_tensors = list(_find_tensors(output)) if model.find_unused_parameters else []
    model.reducer.prepare_for_backward(used_tensors)
def warn_if_output_is_none(output: Any, method_name: str) -> None:
    """Warn via ``warning_cache`` when the step method ``method_name`` returned ``None``.

    Only ``None`` triggers the warning; falsy outputs such as ``0`` or ``{}`` do not.
    """
    if output is not None:
        return
    warning_cache.warn(f'Your {method_name} returned None. Did you forget to return an output?')
def warn_missing_output(fx_called):
    """Warn via ``warning_cache`` when ``training_step`` produced no output.

    Any other hook name passed as ``fx_called`` is silently ignored.
    """
    if fx_called != 'training_step':
        return
    warning_cache.warn("Your training_step returned None. Make sure that was your intention!")
def parallel_apply(
    modules: Module,
    inputs: Tensor,
    kwargs_tup: Optional[tuple] = None,
    devices: Optional[list] = None,
):  # pragma: no-cover
    r"""Applies each `module` in :attr:`modules` in parallel on arguments
    contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
    on each of :attr:`devices`.

    Unlike a plain forward call, each module is dispatched to its
    ``training_step``, ``test_step`` or ``validation_step`` according to its
    ``training`` / ``testing`` flags. Those flags are first copied from
    ``modules[0]`` onto every other module.

    Args:
        modules: modules to be parallelized (indexed, sliced and measured with
            ``len`` below, i.e. used as a sequence despite the ``Module`` annotation)
        inputs: inputs to the modules, one entry per module
        kwargs_tup: per-module keyword arguments; when ``None`` every module gets ``{}``
        devices: CUDA devices

    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
    :attr:`devices` (if given) should all have same length. Moreover, each
    element of :attr:`inputs` can either be a single object as the only argument
    to a module, or a collection of positional arguments.

    Returns:
        A list with one step output per entry of :attr:`inputs`, in input order.

    Raises:
        AssertionError: if the lengths listed above do not match.
        Exception: if a worker raised, the exception stored for the lowest index
            is re-raised here.
    """
    assert len(modules) == len(inputs)
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)
    devices = list(map(lambda x: _get_device_index(x, True), devices))
    # guards writes to ``results`` from the worker threads
    lock = threading.Lock()
    results = {}
    # grad mode is per-thread in PyTorch, so capture it here and re-apply it in each worker
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, kwargs, device=None):
        # Run one module's step and store its output (or the exception) in ``results[i]``.
        torch.set_grad_enabled(grad_enabled)
        # NOTE(review): this lookup runs outside the ``try`` below, so a failure here
        # (e.g. get_a_var returning None) is not recorded in ``results``. Whether None
        # can still reach this point depends on ``_get_device_index`` above; verify.
        if device is None:
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input,)

                module = module.to(device)

                # ---------------
                # CHANGE
                if module.training:
                    output = module.training_step(*input, **kwargs)
                    fx_called = 'training_step'

                elif module.testing:
                    output = module.test_step(*input, **kwargs)
                    fx_called = 'test_step'

                else:
                    output = module.validation_step(*input, **kwargs)
                    fx_called = 'validation_step'

                if output is None:
                    warn_missing_output(fx_called)

                # NOTE(review): the return value is discarded, so only in-place changes made
                # by auto_squeeze_dim_zeros (to mapping outputs) take effect here. A bare
                # tensor output is not unsqueezed. Confirm this is intended.
                if output is not None and module._distrib_type in ('dp', 'ddp2'):
                    auto_squeeze_dim_zeros(output)
                # ---------------

            with lock:
                results[i] = output
        # todo: specify the possible exception
        except Exception as ex:
            # stored instead of raised so the caller can re-raise it after all workers join
            with lock:
                results[i] = ex

    # TODO: fix hack (maybe not a hack)
    # make sure each module knows what training state it's in...
    # fixes weird bug where copies are out of sync
    root_m = modules[0]
    for m in modules[1:]:
        m.training = root_m.training
        m.testing = root_m.testing

    if len(modules) > 1:
        # one thread per module; all are started before any is joined
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device))
                   for i, (module, input, kwargs, device) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        # single module: run inline on the calling thread
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])

    # collect in input order; the first stored exception (by index) is re-raised
    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs
def auto_squeeze_dim_zeros(output):
    """
    In DP or DDP2 we need to unsqueeze dim 0.

    - A tensor ``output`` is returned unsqueezed along dim 0. This is a new tensor;
      the argument itself is not modified.
    - For a mapping, every 0-dim tensor value is replaced in place by its
      unsqueezed version. Other values are left untouched.
    - Any other object (list, tuple, python scalar, ...) is returned unchanged.
      It previously crashed on ``.items()``.

    Args:
        output: the value returned by a step method.

    Returns:
        The unsqueezed tensor, or ``output`` itself (possibly mutated in place).
    """
    if isinstance(output, torch.Tensor):
        return output.unsqueeze(0)

    if not isinstance(output, Mapping):
        return output

    # Re-assigning existing keys does not change the mapping's size, so it is
    # safe to do while iterating over it.
    for key, value in output.items():
        if isinstance(value, torch.Tensor) and value.dim() == 0:
            output[key] = value.unsqueeze(0)
    return output