updated args
This commit is contained in:
parent
158aca26e2
commit
0795e4d51b
|
@ -9,7 +9,6 @@ class LightningDataParallel(DataParallel):
|
|||
"""
|
||||
|
||||
def forward(self, *inputs, **kwargs):
|
||||
pdb.set_trace()
|
||||
if not self.device_ids:
|
||||
# -------------
|
||||
# MAIN CHANGE
|
||||
|
@ -27,7 +26,10 @@ class LightningDataParallel(DataParallel):
|
|||
|
||||
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
|
||||
if len(self.device_ids) == 1:
|
||||
return self.module(*inputs[0], **kwargs[0])
|
||||
if self.module.training:
|
||||
return self.module.training_step(*inputs[0], **kwargs[0])
|
||||
else:
|
||||
return self.module.validation_step(*inputs[0], **kwargs[0])
|
||||
replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
|
||||
outputs = self.parallel_apply(replicas, inputs, kwargs)
|
||||
return self.gather(outputs, self.output_device)
|
||||
|
|
Loading…
Reference in New Issue