lightning/pytorch_lightning/pt_overrides/override_data_parallel.py

188 lines
6.9 KiB
Python
Raw Normal View History

2019-06-25 23:42:15 +00:00
from torch.nn import DataParallel
2019-07-03 19:09:49 +00:00
from torch.nn.parallel import DistributedDataParallel
2019-07-03 20:43:05 +00:00
import itertools
2019-07-18 15:39:06 +00:00
from itertools import chain
2019-06-25 23:42:15 +00:00
2019-06-25 23:52:26 +00:00
import threading
import torch
from torch.cuda._utils import _get_device_index
2019-06-26 00:03:27 +00:00
import pdb
2019-06-25 23:52:26 +00:00
2019-07-24 22:32:48 +00:00
def _find_tensors(obj): # pragma: no cover
2019-07-03 20:43:05 +00:00
r"""
Recursively find all tensors contained in the specified object.
"""
if isinstance(obj, torch.Tensor):
return [obj]
if isinstance(obj, (list, tuple)):
return itertools.chain(*map(_find_tensors, obj))
if isinstance(obj, dict):
return itertools.chain(*map(_find_tensors, obj.values()))
return []
2019-07-24 22:32:48 +00:00
def get_a_var(obj):  # pragma: no cover
    """
    Return the first tensor reachable from ``obj`` in depth-first order, or None.

    Lists and tuples are searched element by element. Dicts are searched by
    their ``(key, value)`` pairs, so a tensor key is found before its value.
    """
    if isinstance(obj, torch.Tensor):
        return obj
    if isinstance(obj, (list, tuple)):
        candidates = obj
    elif isinstance(obj, dict):
        candidates = obj.items()
    else:
        return None
    # Stop at the first candidate whose subtree yields a tensor.
    for candidate in candidates:
        hit = get_a_var(candidate)
        if isinstance(hit, torch.Tensor):
            return hit
    return None
2019-06-25 23:42:15 +00:00
class LightningDataParallel(DataParallel):
    """
    DataParallel wrapper whose forward call is redirected to the wrapped
    module's ``training_step`` (when ``module.training`` is set) or
    ``validation_step`` (otherwise) instead of ``module.forward``.
    """

    def forward(self, *inputs, **kwargs):
        # With no devices configured, behave like a plain call to the module.
        if not self.device_ids:
            return self.module(*inputs, **kwargs)

        # All parameters and buffers must already sit on device_ids[0].
        for tensor in chain(self.module.parameters(), self.module.buffers()):
            if tensor.device == self.src_device_obj:
                continue
            raise RuntimeError("module must have its parameters and buffers "
                               "on device {} (device_ids[0]) but found one of "
                               "them on device: {}".format(self.src_device_obj, tensor.device))

        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)

        if len(self.device_ids) == 1:
            # Single device: no replication, call the Lightning step directly.
            step = self.module.training_step if self.module.training else self.module.validation_step
            return step(*inputs[0], **kwargs[0])

        active_devices = self.device_ids[:len(inputs)]
        replicas = self.replicate(self.module, active_devices)
        per_device_outputs = self.parallel_apply(replicas, inputs, kwargs)
        return self.gather(per_device_outputs, self.output_device)

    def parallel_apply(self, replicas, inputs, kwargs):
        """Run each replica's Lightning step on the matching leading device id."""
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
2019-07-03 20:44:18 +00:00
class LightningDistributedDataParallel(DistributedDataParallel):
    """
    DistributedDataParallel wrapper whose forward call is redirected to the
    wrapped module's ``training_step`` (when ``module.training`` is set) or
    ``validation_step`` (otherwise) instead of ``module.forward``.
    """

    def parallel_apply(self, replicas, inputs, kwargs):
        """Run each replica's Lightning step on the matching leading device id."""
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

    def forward(self, *inputs, **kwargs):  # pragma: no cover
        # Private DDP hook, called before any compute.
        # NOTE(review): presumably syncs params across replicas; confirm it
        # still exists in the installed torch version.
        self._sync_params()

        if not self.device_ids:
            output = self.module(*inputs, **kwargs)
        else:
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            if len(self.device_ids) == 1:
                # Lightning modification: stock DDP would call
                # self.module(*inputs[0], **kwargs[0]); call the step instead.
                if self.module.training:
                    output = self.module.training_step(*inputs[0], **kwargs[0])
                else:
                    output = self.module.validation_step(*inputs[0], **kwargs[0])
            else:
                copies = self._module_copies[:len(inputs)]
                per_device_outputs = self.parallel_apply(copies, inputs, kwargs)
                output = self.gather(per_device_outputs, self.output_device)

        if torch.is_grad_enabled():
            # The output is a free-form object returned verbatim. When
            # find_unused_parameters is set, the reducer is told which tensors
            # it produced so parameters unused in this forward pass can be
            # short-circuited during reduction; otherwise it gets an empty list.
            used_tensors = list(_find_tensors(output)) if self.find_unused_parameters else []
            self.reducer.prepare_for_backward(used_tensors)
        return output
2019-06-25 23:52:26 +00:00
2019-07-24 23:29:51 +00:00
def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):  # pragma: no cover
    r"""Run each module's Lightning step on its own device, in parallel.

    For every index ``i`` this calls ``modules[i].training_step`` when
    ``modules[i].training`` is set, otherwise ``modules[i].validation_step``,
    with positional arguments ``inputs[i]``, keyword arguments
    ``kwargs_tup[i]`` and ``devices[i]`` as the active CUDA device. More than
    one module means one thread per module; a single module runs on the
    calling thread.

    Args:
        modules (Module): modules to be parallelized
        inputs (tensor): inputs to the modules; each element is either a
            single object (passed as the only argument) or a list/tuple of
            positional arguments
        kwargs_tup: optional per-module keyword-argument dicts
        devices (list of int or torch.device): CUDA devices

    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
    :attr:`devices` (if given) should all have the same length.

    Returns:
        list: one output per module, in index order.

    Raises:
        The exception raised by the lowest-index module that failed.
    """
    assert len(modules) == len(inputs)
    if kwargs_tup is None:
        kwargs_tup = ({},) * len(modules)
    else:
        assert len(modules) == len(kwargs_tup)
    if devices is None:
        devices = [None] * len(modules)
    else:
        assert len(modules) == len(devices)
    # optional=True lets None entries resolve to a concrete device index
    # (presumably the current CUDA device; see torch's _get_device_index).
    devices = [_get_device_index(dev, True) for dev in devices]

    lock = threading.Lock()
    results = {}
    grad_enabled = torch.is_grad_enabled()

    def _run_step(module, args, kwargs):
        # Lightning change: dispatch to the step method instead of module(...).
        if module.training:
            return module.training_step(*args, **kwargs)
        return module.validation_step(*args, **kwargs)

    def _worker(i, module, module_input, kwargs, device=None):
        # Autograd grad mode is thread-local; re-apply the caller's setting.
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(module_input).get_device()
        try:
            with torch.cuda.device(device):
                # Wrapping a lone argument also avoids accidentally
                # unpacking/slicing it when it is a Tensor.
                args = module_input if isinstance(module_input, (list, tuple)) else (module_input,)
                output = _run_step(module, args, kwargs)
            with lock:
                results[i] = output
        except Exception as e:
            # Stored rather than raised so the caller can re-raise it in order.
            with lock:
                results[i] = e

    if len(modules) <= 1:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])
    else:
        jobs = enumerate(zip(modules, inputs, kwargs_tup, devices))
        threads = [
            threading.Thread(target=_worker, args=(i, module, module_input, kwargs, device))
            for i, (module, module_input, kwargs, device) in jobs
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

    outputs = []
    for index in range(len(inputs)):
        outcome = results[index]
        if isinstance(outcome, Exception):
            raise outcome
        outputs.append(outcome)
    return outputs