diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py
index 4ab1933e4a..8d105fe896 100644
--- a/pytorch_lightning/overrides/fairscale.py
+++ b/pytorch_lightning/overrides/fairscale.py
@@ -18,7 +18,6 @@ except (ModuleNotFoundError, ImportError):
 else:
     FAIRSCALE_SHARDED_AVAILABLE = True
 
-
 class LightningShardedDataParallel(ShardedDataParallel):
 
     def forward(self, *inputs, **kwargs):
diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py
index eb6bd56aff..8a0156d93c 100644
--- a/tests/plugins/test_sharded_plugin.py
+++ b/tests/plugins/test_sharded_plugin.py
@@ -522,7 +522,8 @@ def run_sharded_correctness(
         assert torch.equal(ddp_param, shard_param), 'Model parameters are different between DDP and Sharded plugin'
 
     # Assert speed parity by ensuring percentage difference between sharded/ddp is below threshold
-    percent_diff = (abs(sharded_time - ddp_time) / sharded_time)
+    percent_diff = (sharded_time - ddp_time) / sharded_time
+
     assert percent_diff <= max_percent_speed_diff, \
         f'Sharded plugin was too slow compared to DDP, Sharded Time: {sharded_time}, DDP Time: {ddp_time}'
 