clean up dead code
This commit is contained in:
parent
cd0d294236
commit
5bdad8a7b8
|
@ -55,19 +55,25 @@ class LightningDistributedDataParallel(DistributedDataParallel):
|
|||
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
|
||||
|
||||
def forward(self, *inputs, **kwargs):
|
||||
print('in forward')
|
||||
self._sync_params()
|
||||
if self.device_ids:
|
||||
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
|
||||
if len(self.device_ids) == 1:
|
||||
print('a')
|
||||
output = self.module(*inputs[0], **kwargs[0])
|
||||
# --------------
|
||||
# LIGHTNING MOD
|
||||
# --------------
|
||||
# normal
|
||||
# output = self.module(*inputs[0], **kwargs[0])
|
||||
|
||||
# lightning
|
||||
if self.module.training:
|
||||
output = self.module.training_step(*inputs[0], **kwargs[0])
|
||||
else:
|
||||
output = self.module.validation_step(*inputs[0], **kwargs[0])
|
||||
else:
|
||||
print('b')
|
||||
outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
|
||||
output = self.gather(outputs, self.output_device)
|
||||
else:
|
||||
print('c')
|
||||
output = self.module(*inputs, **kwargs)
|
||||
|
||||
if torch.is_grad_enabled():
|
||||
|
|
Loading…
Reference in New Issue