Fixed reference
This commit is contained in:
parent
1e429bae58
commit
4ae6f0969a
|
@ -21,10 +21,10 @@ class LightningShardedDataParallel(ShardedDataParallel):
|
||||||
if self.enable_broadcast_buffers:
|
if self.enable_broadcast_buffers:
|
||||||
self.sync_buffers()
|
self.sync_buffers()
|
||||||
|
|
||||||
if self.base_model.training:
|
if self.module.training:
|
||||||
outputs = self.base_model.training_step(*inputs, **kwargs)
|
outputs = self.module.training_step(*inputs, **kwargs)
|
||||||
elif self.base_model.testing:
|
elif self.module.testing:
|
||||||
outputs = self.base_model.test_step(*inputs, **kwargs)
|
outputs = self.module.test_step(*inputs, **kwargs)
|
||||||
else:
|
else:
|
||||||
outputs = self.base_model.validation_step(*inputs, **kwargs)
|
outputs = self.module.validation_step(*inputs, **kwargs)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
Loading…
Reference in New Issue