updated args
This commit is contained in:
parent
c54dd94295
commit
8df13035eb
|
@ -116,3 +116,6 @@ ENV/
|
||||||
|
|
||||||
# mypy
|
# mypy
|
||||||
.mypy_cache/
|
.mypy_cache/
|
||||||
|
|
||||||
|
# data
|
||||||
|
mnist/
|
|
@ -0,0 +1,31 @@
|
||||||
|
from itertools import chain
|
||||||
|
from torch.nn import DataParallel
|
||||||
|
|
||||||
|
|
||||||
|
class LightningDataParallel(DataParallel):
|
||||||
|
"""
|
||||||
|
Override the forward call in lightning so it goes to training and validation step respectively
|
||||||
|
"""
|
||||||
|
|
||||||
|
def forward(self, *inputs, **kwargs):
|
||||||
|
if not self.device_ids:
|
||||||
|
# -------------
|
||||||
|
# MAIN CHANGE
|
||||||
|
if self.module.training:
|
||||||
|
return self.module.training_step(*inputs, **kwargs)
|
||||||
|
else:
|
||||||
|
return self.module.validation_step(*inputs, **kwargs)
|
||||||
|
# -------------
|
||||||
|
|
||||||
|
for t in chain(self.module.parameters(), self.module.buffers()):
|
||||||
|
if t.device != self.src_device_obj:
|
||||||
|
raise RuntimeError("module must have its parameters and buffers "
|
||||||
|
"on device {} (device_ids[0]) but found one of "
|
||||||
|
"them on device: {}".format(self.src_device_obj, t.device))
|
||||||
|
|
||||||
|
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
|
||||||
|
if len(self.device_ids) == 1:
|
||||||
|
return self.module(*inputs[0], **kwargs[0])
|
||||||
|
replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
|
||||||
|
outputs = self.parallel_apply(replicas, inputs, kwargs)
|
||||||
|
return self.gather(outputs, self.output_device)
|
Loading…
Reference in New Issue