Removed line, dont abs

This commit is contained in:
SeanNaren 2020-11-25 20:38:04 +00:00
parent cf7a7f7b8d
commit 9215908fed
2 changed files with 2 additions and 2 deletions

View File

@ -18,7 +18,6 @@ except (ModuleNotFoundError, ImportError):
else:
FAIRSCALE_SHARDED_AVAILABLE = True
class LightningShardedDataParallel(ShardedDataParallel):
def forward(self, *inputs, **kwargs):

View File

@ -522,7 +522,8 @@ def run_sharded_correctness(
assert torch.equal(ddp_param, shard_param), 'Model parameters are different between DDP and Sharded plugin'
# Assert speed parity by ensuring percentage difference between sharded/ddp is below threshold
percent_diff = (abs(sharded_time - ddp_time) / sharded_time)
percent_diff = (sharded_time - ddp_time) / sharded_time
assert percent_diff <= max_percent_speed_diff, \
f'Sharded plugin was too slow compared to DDP, Sharded Time: {sharded_time}, DDP Time: {ddp_time}'