clean up dead code

William Falcon 2019-07-03 16:46:14 -04:00
parent cd0d294236
commit 5bdad8a7b8
1 changed file with 11 additions and 5 deletions


@@ -55,19 +55,25 @@ class LightningDistributedDataParallel(DistributedDataParallel):
         return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
     def forward(self, *inputs, **kwargs):
-        print('in forward')
         self._sync_params()
         if self.device_ids:
             inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
             if len(self.device_ids) == 1:
-                print('a')
-                output = self.module(*inputs[0], **kwargs[0])
+                # --------------
+                # LIGHTNING MOD
+                # --------------
+                # normal
+                # output = self.module(*inputs[0], **kwargs[0])
+                # lightning
+                if self.module.training:
+                    output = self.module.training_step(*inputs[0], **kwargs[0])
+                else:
+                    output = self.module.validation_step(*inputs[0], **kwargs[0])
             else:
-                print('b')
                 outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
                 output = self.gather(outputs, self.output_device)
         else:
-            print('c')
             output = self.module(*inputs, **kwargs)
         if torch.is_grad_enabled():