Fixed reference

This commit is contained in:
SeanNaren 2020-11-22 13:59:02 +00:00
parent 1e429bae58
commit 4ae6f0969a
1 changed files with 5 additions and 5 deletions

View File

@ -21,10 +21,10 @@ class LightningShardedDataParallel(ShardedDataParallel):
if self.enable_broadcast_buffers: if self.enable_broadcast_buffers:
self.sync_buffers() self.sync_buffers()
if self.base_model.training: if self.module.training:
outputs = self.base_model.training_step(*inputs, **kwargs) outputs = self.module.training_step(*inputs, **kwargs)
elif self.base_model.testing: elif self.module.testing:
outputs = self.base_model.test_step(*inputs, **kwargs) outputs = self.module.test_step(*inputs, **kwargs)
else: else:
outputs = self.base_model.validation_step(*inputs, **kwargs) outputs = self.module.validation_step(*inputs, **kwargs)
return outputs return outputs