# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import threading
import warnings
from collections.abc import Iterable, Mapping
from itertools import chain
from typing import Any, Optional

import torch
from torch import Tensor
from torch.cuda._utils import _get_device_index
from torch.nn import DataParallel, Module
from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel._functions import Gather

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities.warnings import WarningCache


def _find_tensors(obj):  # pragma: no-cover
    r"""
    Recursively find all tensors contained in the specified object.

    Lists, tuples and dict values are searched; any other non-tensor object
    contributes nothing.

    Returns:
        An iterable (list or ``itertools.chain``) over the tensors found.
    """
    if isinstance(obj, torch.Tensor):
        return [obj]
    if isinstance(obj, (list, tuple)):
        return itertools.chain.from_iterable(map(_find_tensors, obj))
    if isinstance(obj, dict):
        return itertools.chain.from_iterable(map(_find_tensors, obj.values()))
    return []


def get_a_var(obj):  # pragma: no-cover
    """
    Return the first tensor found in ``obj``, or ``None`` if there is none.

    Lists, tuples and dicts are searched recursively.
    """
    if isinstance(obj, torch.Tensor):
        return obj

    if isinstance(obj, (list, tuple)):
        for result in map(get_a_var, obj):
            if isinstance(result, torch.Tensor):
                return result
    if isinstance(obj, dict):
        # each item is a ``(key, value)`` tuple, which the tuple branch above searches
        for result in map(get_a_var, obj.items()):
            if isinstance(result, torch.Tensor):
                return result
    return None


warning_cache = WarningCache()


class LightningDataParallel(DataParallel):
    """
    Override the forward call in lightning so it goes to training and validation step respectively
    """

    def forward(self, *inputs, **kwargs):
        """
        Scatter the inputs and dispatch to ``training_step``, ``test_step`` or
        ``validation_step`` of the wrapped module, then gather the outputs.

        Raises:
            RuntimeError: if a parameter or buffer of the wrapped module is not on
                ``self.src_device_obj``.
        """
        if not self.device_ids:
            return self.module(*inputs, **kwargs)

        for t in chain(self.module.parameters(), self.module.buffers()):
            if t.device != self.src_device_obj:
                raise RuntimeError("module must have its parameters and buffers "
                                   "on device {} (device_ids[0]) but found one of "
                                   "them on device: {}".format(self.src_device_obj, t.device))

        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            # lightning
            # single device: call the step method directly, no replication or gather
            if self.module.training:
                return self.module.training_step(*inputs[0], **kwargs[0])
            if self.module.testing:
                return self.module.test_step(*inputs[0], **kwargs[0])
            return self.module.validation_step(*inputs[0], **kwargs[0])

        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        outputs = self.parallel_apply(replicas, inputs, kwargs)

        if isinstance(outputs[0], Result):
            outputs = self.__gather_structured_result(outputs)
        else:
            outputs = self.gather(outputs)
        return outputs

    def __gather_structured_result(self, outputs):
        """
        Gather a list of ``Result`` objects into a single object of the same class.

        The ``'meta'`` entry is excluded from the gather; the one from the first
        output is re-attached to the gathered result.
        """
        prototype_output = outputs[0]
        original_class = prototype_output.__class__
        outputs = [dict(x) for x in outputs]

        # remove all the meta info
        meta = outputs[0]['meta']
        for output in outputs:
            del output['meta']

        outputs = self.gather(outputs)

        result = original_class()
        result.update(outputs)
        result['meta'] = meta
        return result

    def gather(self, outputs):
        r"""
        Override the gather method to support python scalars as well.

        Tensors are gathered with ``Gather.apply`` onto ``self.output_device`` along
        ``self.dim``. Mappings and non-string iterables are gathered element-wise and
        rebuilt with their original type. ``None`` stays ``None``. Anything else
        (e.g. python scalars) is returned as the list of per-replica values.

        Raises:
            ValueError: if the per-replica mappings do not all have the same number of keys.
        """
        def gather_map(outputs):
            elem = outputs[0]
            elem_type = type(elem)

            if isinstance(elem, torch.Tensor):
                return Gather.apply(self.output_device, self.dim, *outputs)

            if elem is None:
                return None

            if isinstance(elem, Mapping):
                if not all((len(elem) == len(d) for d in outputs)):
                    raise ValueError('All dicts must have the same number of keys')
                return elem_type(((k, gather_map([d[k] for d in outputs]))
                                  for k in elem))

            if isinstance(elem, Iterable) and not isinstance(elem, str):
                return elem_type(map(gather_map, zip(*outputs)))

            return outputs

        # Recursive function calls like this create reference cycles.
        # Setting the function to None clears the refcycle.
        try:
            res = gather_map(outputs)
        finally:
            gather_map = None
        return res

    def parallel_apply(self, replicas, inputs, kwargs):
        """Run the module-level ``parallel_apply`` on as many devices as there are replicas."""
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])


class LightningDistributedDataParallel(DistributedDataParallel):
    """
    Deprecated ``DistributedDataParallel`` subclass that wraps the given module in a
    :class:`LightningDistributedModule`.
    """

    def __init__(self, module: LightningModule, *args, **kwargs):
        # stacklevel=2 attributes the warning to the code that instantiates this class
        warnings.warn(
            "The usage of `LightningDistributedDataParallel` is deprecated since v1.2 and will be removed in v1.4."
            " From now on we recommend to directly sublcass `torch.nn.parallel.DistributedDataParallel`.",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(LightningDistributedModule(module), *args, **kwargs)


class LightningDistributedModule(torch.nn.Module):

    def __init__(self, pl_module: LightningModule):
        """
        Wraps the user's LightningModule and redirects the forward call to the appropriate
        method, either ``training_step``, ``validation_step`` or ``test_step``.
        This class is used in combination with :class:`~torch.nn.parallel.DistributedDataParallel` as
        shown in the example.

        Example:

            ddp_model = DistributedDataParallel(
                module=LightningDistributedModule(lightning_module),
                device_ids=[local_rank],
                ...
            )

        Args:
            pl_module: the model to wrap

        """
        super().__init__()
        self.module = pl_module

    def forward(self, *inputs, **kwargs):
        """
        Call the step method matching the wrapped module's state and return its output.

        A cached warning is emitted when the step method returns ``None``.
        """
        if self.module.training:
            output = self.module.training_step(*inputs, **kwargs)
            warn_if_output_is_none(output, "training_step")
        elif self.module.testing:
            output = self.module.test_step(*inputs, **kwargs)
            warn_if_output_is_none(output, "test_step")
        else:
            output = self.module.validation_step(*inputs, **kwargs)
            warn_if_output_is_none(output, "validation_step")
        return output


# In manual_optimization, we need to call reducer prepare_for_backward.
# Note: Keep track of Pytorch DDP and update if there is a change
# https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/parallel/distributed.py#L626-L638
def prepare_for_backward(model: DistributedDataParallel, output: Any):
    """
    Prepare the reducer of ``model`` for the backward pass.

    Args:
        model: the DDP-wrapped model whose reducer is prepared
        output: the forward output; searched for tensors only when
            ``model.find_unused_parameters`` is set
    """
    if torch.is_grad_enabled() and model.require_backward_grad_sync:
        model.require_forward_param_sync = True
        # We'll return the output object verbatim since it is a freeform
        # object. We need to find any tensors in this object, though,
        # because we need to figure out which parameters were used during
        # this forward pass, to ensure we short circuit reduction for any
        # unused parameters. Only if `find_unused_parameters` is set.
        if model.find_unused_parameters:
            model.reducer.prepare_for_backward(list(_find_tensors(output)))
        else:
            model.reducer.prepare_for_backward([])
    else:
        model.require_forward_param_sync = False


def warn_if_output_is_none(output: Any, method_name: str) -> None:
    """Emit a cached warning naming ``method_name`` when ``output`` is ``None``."""
    if output is None:
        warning_cache.warn(f'Your {method_name} returned None. Did you forget to return an output?')


def warn_missing_output(fx_called):
    """Emit a cached warning when the step that returned ``None`` was ``training_step``."""
    if fx_called == 'training_step':
        warning_cache.warn("Your training_step returned None. Make sure that was your intention!")


def parallel_apply(
        modules: Module,
        inputs: Tensor,
        kwargs_tup: Optional[tuple] = None,
        devices: Optional[list] = None,
):  # pragma: no-cover
    r"""Applies each `module` in :attr:`modules` in parallel on arguments
    contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
    on each of :attr:`devices`.

    Instead of ``forward``, the ``training_step``, ``test_step`` or ``validation_step``
    of each module is called, depending on its ``training`` / ``testing`` flags.

    Args:
        modules: modules to be parallelized
        inputs: inputs to the modules
        kwargs_tup: keyword arguments for the modules, one dict per module
        devices: CUDA devices

    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
    :attr:`devices` (if given) should all have same length. Moreover, each
    element of :attr:`inputs` can either be a single object as the only argument
    to a module, or a collection of positional arguments.

    Returns:
        The list of per-module outputs, in the order of :attr:`modules`.

    Raises:
        Exception: the first exception (by module index) raised inside a worker.
    """
    assert len(modules) == len(inputs)
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)
    devices = list(map(lambda x: _get_device_index(x, True), devices))
    lock = threading.Lock()
    results = {}
    # captured here and re-applied inside each worker via ``torch.set_grad_enabled``
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, kwargs, device=None):
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input,)

                module = module.to(device)

                # ---------------
                # CHANGE
                if module.training:
                    output = module.training_step(*input, **kwargs)
                    fx_called = 'training_step'
                elif module.testing:
                    output = module.test_step(*input, **kwargs)
                    fx_called = 'test_step'
                else:
                    output = module.validation_step(*input, **kwargs)
                    fx_called = 'validation_step'

                if output is None:
                    warn_missing_output(fx_called)

                if output is not None and module._distrib_type in ('dp', 'ddp2'):
                    # keep the return value: a bare tensor cannot be unsqueezed in place
                    output = auto_squeeze_dim_zeros(output)
                # ---------------

            with lock:
                results[i] = output

        # todo: specify the possible exception
        except Exception as ex:
            # stored instead of raised so it can be re-raised by the caller below
            with lock:
                results[i] = ex

    # TODO: fix hack (maybe not a hack)
    # make sure each module knows what training state it's in...
    # fixes weird bug where copies are out of sync
    root_m = modules[0]
    for m in modules[1:]:
        m.training = root_m.training
        m.testing = root_m.testing

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device))
                   for i, (module, input, kwargs, device) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs


def auto_squeeze_dim_zeros(output):
    """
    In DP or DDP2 we need to unsqueeze dim 0

    Zero-dimensional tensors are unsqueezed to shape ``(1,)``; tensors that already
    have a dimension are left unchanged. For a mapping, only its direct tensor values
    are considered and the mapping is updated in place.

    Args:
        output: a tensor, a mapping, or any other object (returned unchanged)

    Returns:
        The processed output. For a mapping this is the same (mutated) object.
    """
    if isinstance(output, torch.Tensor):
        return output.unsqueeze(0) if output.dim() == 0 else output

    if isinstance(output, Mapping):
        # only values are reassigned, so iterating over ``items()`` here is safe
        for k, v in output.items():
            if isinstance(v, torch.Tensor) and v.dim() == 0:
                output[k] = v.unsqueeze(0)
    return output