Removed line, dont abs
This commit is contained in:
parent
cf7a7f7b8d
commit
9215908fed
|
@ -18,7 +18,6 @@ except (ModuleNotFoundError, ImportError):
|
|||
else:
|
||||
FAIRSCALE_SHARDED_AVAILABLE = True
|
||||
|
||||
|
||||
class LightningShardedDataParallel(ShardedDataParallel):
|
||||
|
||||
def forward(self, *inputs, **kwargs):
|
||||
|
|
|
@ -522,7 +522,8 @@ def run_sharded_correctness(
|
|||
assert torch.equal(ddp_param, shard_param), 'Model parameters are different between DDP and Sharded plugin'
|
||||
|
||||
# Assert speed parity by ensuring percentage difference between sharded/ddp is below threshold
|
||||
percent_diff = (abs(sharded_time - ddp_time) / sharded_time)
|
||||
percent_diff = (sharded_time - ddp_time) / sharded_time
|
||||
|
||||
assert percent_diff <= max_percent_speed_diff, \
|
||||
f'Sharded plugin was too slow compared to DDP, Sharded Time: {sharded_time}, DDP Time: {ddp_time}'
|
||||
|
||||
|
|
Loading…
Reference in New Issue