diff --git a/pytorch_lightning/pt_overrides/override_data_parallel.py b/pytorch_lightning/pt_overrides/override_data_parallel.py
index 8d6b89fe03..98cc0ed871 100644
--- a/pytorch_lightning/pt_overrides/override_data_parallel.py
+++ b/pytorch_lightning/pt_overrides/override_data_parallel.py
@@ -2,34 +2,101 @@ from itertools import chain
 from torch.nn import DataParallel
 import pdb
+import threading
+import torch
+from torch.cuda._utils import _get_device_index
+
+
+def get_a_var(obj):
+    if isinstance(obj, torch.Tensor):
+        return obj
+
+    if isinstance(obj, list) or isinstance(obj, tuple):
+        for result in map(get_a_var, obj):
+            if isinstance(result, torch.Tensor):
+                return result
+    if isinstance(obj, dict):
+        for result in map(get_a_var, obj.items()):
+            if isinstance(result, torch.Tensor):
+                return result
+    return None
+
 
 class LightningDataParallel(DataParallel):
     """
     Override the forward call in lightning so it goes to training and validation step respectively
     """
 
-    def forward(self, *inputs, **kwargs):
-        if not self.device_ids:
-            # -------------
-            # MAIN CHANGE
-            if self.module.training:
-                return self.module.training_step(*inputs, **kwargs)
-            else:
-                return self.module.validation_step(*inputs, **kwargs)
-            # -------------
+    def parallel_apply(self, replicas, inputs, kwargs):
+        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
 
-        for t in chain(self.module.parameters(), self.module.buffers()):
-            if t.device != self.src_device_obj:
-                raise RuntimeError("module must have its parameters and buffers "
-                                   "on device {} (device_ids[0]) but found one of "
-                                   "them on device: {}".format(self.src_device_obj, t.device))
-        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
-        if len(self.device_ids) == 1:
-            if self.module.training:
-                return self.module.training_step(*inputs[0], **kwargs[0])
-            else:
-                return self.module.validation_step(*inputs[0], **kwargs[0])
-        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
-        outputs = self.parallel_apply(replicas, inputs, kwargs)
-        return self.gather(outputs, self.output_device)
+
+def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
+    r"""Applies each `module` in :attr:`modules` in parallel on arguments
+    contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
+    on each of :attr:`devices`.
+
+    Args:
+        modules (Module): modules to be parallelized
+        inputs (tensor): inputs to the modules
+        devices (list of int or torch.device): CUDA devices
+
+    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
+    :attr:`devices` (if given) should all have same length. Moreover, each
+    element of :attr:`inputs` can either be a single object as the only argument
+    to a module, or a collection of positional arguments.
+ """ + assert len(modules) == len(inputs) + if kwargs_tup is not None: + assert len(modules) == len(kwargs_tup) + else: + kwargs_tup = ({},) * len(modules) + if devices is not None: + assert len(modules) == len(devices) + else: + devices = [None] * len(modules) + devices = list(map(lambda x: _get_device_index(x, True), devices)) + lock = threading.Lock() + results = {} + grad_enabled = torch.is_grad_enabled() + + def _worker(i, module, input, kwargs, device=None): + torch.set_grad_enabled(grad_enabled) + if device is None: + device = get_a_var(input).get_device() + try: + with torch.cuda.device(device): + # this also avoids accidental slicing of `input` if it is a Tensor + if not isinstance(input, (list, tuple)): + input = (input,) + + if module.training: + return module.training_step(*input, **kwargs) + else: + return module.validation_step(*input, **kwargs) + with lock: + results[i] = output + except Exception as e: + with lock: + results[i] = e + + if len(modules) > 1: + threads = [threading.Thread(target=_worker, + args=(i, module, input, kwargs, device)) + for i, (module, input, kwargs, device) in + enumerate(zip(modules, inputs, kwargs_tup, devices))] + + for thread in threads: + thread.start() + for thread in threads: + thread.join() + else: + _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0]) + + outputs = [] + for i in range(len(inputs)): + output = results[i] + if isinstance(output, Exception): + raise output + outputs.append(output) + return outputs \ No newline at end of file