Fixed reference
This commit is contained in:
parent
1e429bae58
commit
4ae6f0969a
|
@ -21,10 +21,10 @@ class LightningShardedDataParallel(ShardedDataParallel):
|
|||
if self.enable_broadcast_buffers:
|
||||
self.sync_buffers()
|
||||
|
||||
if self.base_model.training:
|
||||
outputs = self.base_model.training_step(*inputs, **kwargs)
|
||||
elif self.base_model.testing:
|
||||
outputs = self.base_model.test_step(*inputs, **kwargs)
|
||||
if self.module.training:
|
||||
outputs = self.module.training_step(*inputs, **kwargs)
|
||||
elif self.module.testing:
|
||||
outputs = self.module.test_step(*inputs, **kwargs)
|
||||
else:
|
||||
outputs = self.base_model.validation_step(*inputs, **kwargs)
|
||||
outputs = self.module.validation_step(*inputs, **kwargs)
|
||||
return outputs
|
||||
|
|
Loading…
Reference in New Issue