diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py index 05ec35c1cb..1776667334 100644 --- a/pytorch_lightning/overrides/fairscale.py +++ b/pytorch_lightning/overrides/fairscale.py @@ -16,7 +16,6 @@ from pytorch_lightning.utilities import _module_available, NATIVE_AMP_AVALAIBLE if _module_available('fairscale.nn.data_parallel.sharded_ddp') and NATIVE_AMP_AVALAIBLE: from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel - class LightningShardedDataParallel(ShardedDataParallel): def forward(self, *inputs, **kwargs): @@ -31,7 +30,6 @@ if _module_available('fairscale.nn.data_parallel.sharded_ddp') and NATIVE_AMP_AV outputs = self.module.validation_step(*inputs, **kwargs) return outputs - FAIRSCALE_SHARDED_AVAILABLE = True else: FAIRSCALE_SHARDED_AVAILABLE = False