updated args

This commit is contained in:
William Falcon 2019-06-25 19:42:15 -04:00
parent c54dd94295
commit 8df13035eb
3 changed files with 34 additions and 0 deletions

3
.gitignore vendored
View File

@@ -116,3 +116,6 @@ ENV/
# mypy
.mypy_cache/
# data
mnist/

View File

@@ -0,0 +1,31 @@
from itertools import chain
from torch.nn import DataParallel
class LightningDataParallel(DataParallel):
    """DataParallel wrapper that routes calls to the wrapped module's
    ``training_step`` or ``validation_step`` instead of ``forward``,
    chosen by the module's train/eval mode.
    """

    def forward(self, *inputs, **kwargs):
        """Dispatch to training_step/validation_step (CPU) or run the
        standard DataParallel scatter/replicate/gather path (GPU).

        :param inputs: positional arguments forwarded to the step method
            or scattered across devices.
        :param kwargs: keyword arguments handled the same way.
        :return: whatever the selected step method (or gathered replica
            outputs) produces.
        :raises RuntimeError: if any parameter or buffer is not on the
            primary device when device ids are configured.
        """
        if not self.device_ids:
            # No GPUs configured: call the appropriate step directly on
            # the wrapped module, picked from its training flag.
            step = (self.module.training_step
                    if self.module.training
                    else self.module.validation_step)
            return step(*inputs, **kwargs)

        # Every parameter/buffer must already live on the primary device
        # before scattering; otherwise replication would silently misbehave.
        for tensor in chain(self.module.parameters(), self.module.buffers()):
            if tensor.device != self.src_device_obj:
                raise RuntimeError("module must have its parameters and buffers "
                                   "on device {} (device_ids[0]) but found one of "
                                   "them on device: {}".format(self.src_device_obj, tensor.device))

        scattered_inputs, scattered_kwargs = self.scatter(inputs, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            # Single device: run in place, no replication needed.
            return self.module(*scattered_inputs[0], **scattered_kwargs[0])
        replicas = self.replicate(self.module, self.device_ids[:len(scattered_inputs)])
        outputs = self.parallel_apply(replicas, scattered_inputs, scattered_kwargs)
        return self.gather(outputs, self.output_device)