From 4ae6f0969a582a2fdcabbc03442832eb99f071e8 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Sun, 22 Nov 2020 13:59:02 +0000 Subject: [PATCH] Fixed reference --- pytorch_lightning/overrides/fairscale.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py index 9691465103..73d9a6e6fb 100644 --- a/pytorch_lightning/overrides/fairscale.py +++ b/pytorch_lightning/overrides/fairscale.py @@ -21,10 +21,10 @@ class LightningShardedDataParallel(ShardedDataParallel): if self.enable_broadcast_buffers: self.sync_buffers() - if self.base_model.training: - outputs = self.base_model.training_step(*inputs, **kwargs) - elif self.base_model.testing: - outputs = self.base_model.test_step(*inputs, **kwargs) + if self.module.training: + outputs = self.module.training_step(*inputs, **kwargs) + elif self.module.testing: + outputs = self.module.test_step(*inputs, **kwargs) else: - outputs = self.base_model.validation_step(*inputs, **kwargs) + outputs = self.module.validation_step(*inputs, **kwargs) return outputs