updated args
This commit is contained in:
parent
0795e4d51b
commit
c941649532
|
@ -2,34 +2,101 @@ from itertools import chain
|
|||
from torch.nn import DataParallel
|
||||
import pdb
|
||||
|
||||
import threading
|
||||
import torch
|
||||
from torch.cuda._utils import _get_device_index
|
||||
|
||||
|
||||
def get_a_var(obj):
    """Return the first ``torch.Tensor`` found in ``obj``, or ``None``.

    ``obj`` may be a tensor itself, or a list / tuple / dict that is searched
    recursively in iteration order. Dicts are searched through their
    ``(key, value)`` item pairs. Any other kind of object yields ``None``.
    """
    if isinstance(obj, torch.Tensor):
        return obj

    # Pick what to descend into; anything that is not a supported container
    # cannot hold a tensor as far as this helper is concerned.
    if isinstance(obj, (list, tuple)):
        children = obj
    elif isinstance(obj, dict):
        children = obj.items()
    else:
        return None

    for child in children:
        found = get_a_var(child)
        if isinstance(found, torch.Tensor):
            return found
    return None
|
||||
|
||||
|
||||
class LightningDataParallel(DataParallel):
    """
    Override the forward call in lightning so it goes to training and validation step respectively
    """

    def forward(self, *inputs, **kwargs):
        """Dispatch a batch to ``training_step`` or ``validation_step``.

        Which step is called depends on ``module.training``. With no
        ``device_ids`` the wrapped module is called directly; with exactly
        one device the scattered inputs are passed to the wrapped module;
        otherwise the module is replicated, run through
        ``self.parallel_apply`` and the outputs are gathered on
        ``self.output_device``.

        Raises:
            RuntimeError: if any parameter or buffer of the wrapped module is
                not on ``self.src_device_obj``.
        """
        if not self.device_ids:
            # -------------
            # MAIN CHANGE
            if self.module.training:
                return self.module.training_step(*inputs, **kwargs)
            else:
                return self.module.validation_step(*inputs, **kwargs)
            # -------------

        for t in chain(self.module.parameters(), self.module.buffers()):
            if t.device != self.src_device_obj:
                raise RuntimeError("module must have its parameters and buffers "
                                   "on device {} (device_ids[0]) but found one of "
                                   "them on device: {}".format(self.src_device_obj, t.device))

        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            # Single device: no replication needed, use the first (only) shard.
            if self.module.training:
                return self.module.training_step(*inputs[0], **kwargs[0])
            else:
                return self.module.validation_step(*inputs[0], **kwargs[0])

        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        outputs = self.parallel_apply(replicas, inputs, kwargs)
        return self.gather(outputs, self.output_device)

    def parallel_apply(self, replicas, inputs, kwargs):
        """Run the replicas through the module-level ``parallel_apply``.

        Only as many ``device_ids`` as there are replicas are passed along.
        """
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
|
||||
def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
    r"""Applies each `module` in :attr:`modules` in parallel on arguments
    contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
    on each of :attr:`devices`.

    Instead of calling the module itself, each module's ``training_step`` is
    invoked when ``module.training`` is true and ``validation_step`` otherwise.

    Args:
        modules (Module): modules to be parallelized
        inputs (tensor): inputs to the modules
        kwargs_tup (tuple of dict, optional): keyword arguments, one dict per
            module; defaults to an empty dict for every module
        devices (list of int or torch.device): CUDA devices

    Returns:
        list: one step output per module, in the order of :attr:`modules`.

    Raises:
        Exception: the first exception (by module index) raised inside a
            worker is re-raised in the calling thread.

    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
    :attr:`devices` (if given) should all have same length. Moreover, each
    element of :attr:`inputs` can either be a single object as the only argument
    to a module, or a collection of positional arguments.
    """
    assert len(modules) == len(inputs)
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)
    devices = [_get_device_index(device, True) for device in devices]
    lock = threading.Lock()
    results = {}
    # Captured here so every worker thread runs with the caller's grad mode.
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, kwargs, device=None):
        torch.set_grad_enabled(grad_enabled)
        try:
            # Kept inside the ``try`` so a failed lookup is recorded and
            # re-raised to the caller instead of being lost in the thread.
            if device is None:
                device = get_a_var(input).get_device()
            with torch.cuda.device(device):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input,)

                # The step result must be stored, not returned: a value
                # returned from a thread target is discarded and
                # ``results[i]`` would never be filled in.
                if module.training:
                    output = module.training_step(*input, **kwargs)
                else:
                    output = module.validation_step(*input, **kwargs)
            with lock:
                results[i] = output
        except Exception as e:
            # Stored and re-raised by the caller after all workers finish.
            with lock:
                results[i] = e

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device))
                   for i, (module, input, kwargs, device) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs
|
Loading…
Reference in New Issue