updated args

This commit is contained in:
William Falcon 2019-06-25 19:46:49 -04:00
parent 158aca26e2
commit 0795e4d51b
1 changed files with 4 additions and 2 deletions

View File

@ -9,7 +9,6 @@ class LightningDataParallel(DataParallel):
"""
def forward(self, *inputs, **kwargs):
pdb.set_trace()
if not self.device_ids:
# -------------
# MAIN CHANGE
@ -27,7 +26,10 @@ class LightningDataParallel(DataParallel):
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if len(self.device_ids) == 1:
return self.module(*inputs[0], **kwargs[0])
if self.module.training:
return self.module.training_step(*inputs[0], **kwargs[0])
else:
return self.module.validation_step(*inputs[0], **kwargs[0])
replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
outputs = self.parallel_apply(replicas, inputs, kwargs)
return self.gather(outputs, self.output_device)