Moved common functions into utilities

This commit is contained in:
SeanNaren 2020-11-27 12:25:44 +00:00
parent bde2a12990
commit 10d41fb4ea
2 changed files with 169 additions and 116 deletions

View File

@ -1,13 +1,13 @@
import os
import platform
import time
from unittest import mock
import pytest
import torch
from torch.utils.data.distributed import DistributedSampler
from pytorch_lightning import Trainer, seed_everything
from benchmarks.utilities import plugin_parity_test
from pytorch_lightning import seed_everything
from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin
from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE, NATIVE_AMP_AVAILABLE
from tests.backends.launcher import DDPLauncher
@ -19,7 +19,12 @@ from tests.base.boring_model import BoringModel, RandomDataset
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_one_device():
# Allow slightly slower speed due to one CPU doing additional sequential memory saving calls
run_sharded_correctness(accelerator='ddp_cpu', max_percent_speed_diff=0.5)
plugin_parity_test(
accelerator='ddp_cpu',
max_percent_speed_diff=0.5,
plugin=DDPShardedPlugin(),
model_cls=SeedTrainLoaderModel
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
@ -27,7 +32,12 @@ def test_ddp_sharded_plugin_correctness_one_device():
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_one_gpu():
run_sharded_correctness(gpus=1, accelerator='ddp_spawn')
plugin_parity_test(
gpus=1,
accelerator='ddp_spawn',
plugin=DDPShardedPlugin(),
model_cls=SeedTrainLoaderModel
)
@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
@ -36,7 +46,13 @@ def test_ddp_sharded_plugin_correctness_one_gpu():
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_amp_one_gpu():
run_sharded_correctness(gpus=1, precision=16, accelerator='ddp_spawn')
plugin_parity_test(
gpus=1,
precision=16,
accelerator='ddp_spawn',
plugin=DDPShardedPlugin(),
model_cls=SeedTrainLoaderModel
)
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@ -44,7 +60,12 @@ def test_ddp_sharded_plugin_correctness_amp_one_gpu():
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_multi_gpu():
run_sharded_correctness(gpus=2, accelerator='ddp_spawn')
plugin_parity_test(
gpus=2,
accelerator='ddp_spawn',
plugin=DDPShardedPlugin(),
model_cls=SeedTrainLoaderModel
)
@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
@ -53,7 +74,13 @@ def test_ddp_sharded_plugin_correctness_multi_gpu():
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
run_sharded_correctness(gpus=2, precision=16, accelerator='ddp_spawn')
plugin_parity_test(
gpus=2,
precision=16,
accelerator='ddp_spawn',
plugin=DDPShardedPlugin(),
model_cls=SeedTrainLoaderModel
)
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@ -63,7 +90,13 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
reason="test should be run outside of pytest")
@DDPLauncher.run("--distributed_backend ddp --gpus 2 --precision 32")
def test_ddp_sharded_plugin_correctness_multi_gpu_ddp(tmpdir, args=None):
run_sharded_correctness(gpus=args.gpus, precision=args.precision, accelerator=args.distributed_backend)
plugin_parity_test(
gpus=args.gpus,
precision=args.precision,
accelerator=args.distributed_backend,
plugin=DDPShardedPlugin(),
model_cls=SeedTrainLoaderModel
)
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@ -73,7 +106,13 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_ddp(tmpdir, args=None):
reason="test should be run outside of pytest")
@DDPLauncher.run("--distributed_backend ddp --gpus 2 --precision 16")
def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None):
run_sharded_correctness(gpus=args.gpus, precision=args.precision, accelerator=args.distributed_backend)
plugin_parity_test(
gpus=args.gpus,
precision=args.precision,
accelerator=args.distributed_backend,
plugin=DDPShardedPlugin(),
model_cls=SeedTrainLoaderModel
)
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@ -84,7 +123,8 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
"""
Ensures same results using multiple optimizers across multiple GPUs
"""
run_sharded_correctness(
plugin_parity_test(
plugin=DDPShardedPlugin(),
gpus=2,
accelerator='ddp_spawn',
model_cls=SeedTrainLoaderMultipleOptimizersModel,
@ -102,7 +142,8 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir):
"""
Ensures using multiple optimizers across multiple GPUs with manual optimization
"""
run_sharded_correctness(
plugin_parity_test(
plugin=DDPShardedPlugin(),
gpus=2,
accelerator='ddp_spawn',
model_cls=SeedTrainLoaderManualModel,
@ -167,108 +208,3 @@ class SeedTrainLoaderMultipleOptimizersModel(SeedTrainLoaderModel):
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
return optimizer, optimizer_2
def record_ddp_fit_model_stats(trainer, model, gpus):
"""
Helper to calculate wall clock time for fit + max allocated memory.
Args:
trainer: The trainer object.
model: The LightningModule.
gpus: Number of GPUs in test.
Returns:
Max Memory if using GPUs, and total wall clock time.
"""
max_memory = None
time_start = time.perf_counter()
if gpus > 0:
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
trainer.fit(model)
if gpus > 0:
torch.cuda.synchronize()
max_memory = torch.cuda.max_memory_allocated() / 2 ** 20
total_time = time.perf_counter() - time_start
return max_memory, total_time
def run_sharded_correctness(
accelerator='ddp_spawn',
gpus=0,
precision=32,
max_percent_speed_diff=0.25,
model_cls=SeedTrainLoaderModel):
"""
Ensures that the trained model is identical to the standard DDP implementation.
Also checks for speed/memory regressions, we should expect always less memory but performance to fluctuate.
Args:
accelerator: Accelerator type for test.
gpus: Number of GPUS to enable.
precision: Whether to use AMP or normal FP32 training.
max_percent_speed_diff: The maximum speed difference compared to normal DDP training.
This is more a safety net for variability in CI which can vary in speed, not for benchmarking.
model_cls: Model class to use for test.
"""
# Train normal DDP
seed_everything(42)
ddp_model = model_cls()
trainer = Trainer(
fast_dev_run=True,
max_epochs=1,
gpus=gpus,
precision=precision,
accelerator=accelerator,
)
max_ddp_memory, ddp_time = record_ddp_fit_model_stats(
trainer=trainer,
model=ddp_model,
gpus=gpus
)
# Reset and train sharded DDP
seed_everything(42)
sharded_model = model_cls()
trainer = Trainer(
fast_dev_run=True,
max_epochs=1,
gpus=gpus,
precision=precision,
accelerator=accelerator,
plugins=[DDPShardedPlugin()],
)
max_sharded_memory, sharded_time = record_ddp_fit_model_stats(
trainer=trainer,
model=sharded_model,
gpus=gpus
)
# Assert model parameters are identical after fit
for ddp_param, shard_param in zip(ddp_model.parameters(), sharded_model.parameters()):
assert torch.equal(ddp_param, shard_param), 'Model parameters are different between DDP and Sharded plugin'
# Assert speed parity by ensuring percentage difference between sharded/ddp is below threshold
percent_diff = (sharded_time - ddp_time) / sharded_time
assert percent_diff <= max_percent_speed_diff, \
f'Sharded plugin was too slow compared to DDP, Sharded Time: {sharded_time}, DDP Time: {ddp_time}'
if gpus > 0:
# Assert CUDA memory parity
assert max_sharded_memory <= max_ddp_memory, \
f'Sharded plugin used too much memory compared to DDP,' \
f'Sharded Mem: {max_sharded_memory}, DDP Mem: {max_ddp_memory}'

117
benchmarks/utilities.py Normal file
View File

@ -0,0 +1,117 @@
import time
from typing import Callable
import torch
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.seed import seed_everything
def record_ddp_fit_model_stats(trainer, model, use_cuda):
"""
Helper to calculate wall clock time for fit + max allocated memory.
Args:
trainer: The trainer object.
model: The model to fit.
use_cuda: Whether to sync CUDA kernels.
Returns:
Max Memory if using GPUs, and total wall clock time.
"""
max_memory = None
time_start = time.perf_counter()
if use_cuda:
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
trainer.fit(model)
if use_cuda:
torch.cuda.synchronize()
max_memory = torch.cuda.max_memory_allocated() / 2 ** 20
total_time = time.perf_counter() - time_start
return max_memory, total_time
def plugin_parity_test(
model_cls: Callable,
plugin: DDPPlugin,
seed: int = 42,
accelerator: str = 'ddp_spawn',
gpus: int = 0,
precision: int = 32,
max_percent_speed_diff: float = 0.25):
"""
Ensures that the trained model is identical to the standard DDP implementation.
Also checks for speed/memory regressions, we should expect always less memory but performance to fluctuate.
Args:
model_cls: Model class to use for test.
plugin: Plugin to parity test.
seed: Seed for generators. Note that this does not handle the seed for data-loading on multi-process.
accelerator: Accelerator type for test.
gpus: Number of GPUS to enable.
precision: Whether to use AMP or normal FP32 training.
max_percent_speed_diff: The maximum speed difference compared to normal DDP training.
This is more a safety net for variability in CI which can vary in speed, not for benchmarking.
"""
# Train normal DDP
seed_everything(seed)
ddp_model = model_cls()
use_cuda = gpus > 0
trainer = Trainer(
fast_dev_run=True,
max_epochs=1,
gpus=gpus,
precision=precision,
accelerator=accelerator,
)
max_memory_ddp, ddp_time = record_ddp_fit_model_stats(
trainer=trainer,
model=ddp_model,
use_cuda=use_cuda
)
# Reset and train Custom DDP
seed_everything(seed)
custom_plugin_model = model_cls()
trainer = Trainer(
fast_dev_run=True,
max_epochs=1,
gpus=gpus,
precision=precision,
accelerator=accelerator,
plugins=[plugin],
)
max_memory_custom, custom_model_time = record_ddp_fit_model_stats(
trainer=trainer,
model=custom_plugin_model,
use_cuda=use_cuda
)
# Assert model parameters are identical after fit
for ddp_param, custom_param in zip(ddp_model.parameters(), custom_plugin_model.parameters()):
assert torch.equal(ddp_param, custom_param), 'Model parameters are different between DDP and Custom plugin'
# Assert speed parity by ensuring percentage difference between custom/ddp is below threshold
percent_diff = (custom_model_time - ddp_time) / custom_model_time
assert percent_diff <= max_percent_speed_diff, \
f'Custom DDP plugin was too slow compared to DDP, Custom Plugin Time: {custom_model_time}, DDP Time: {ddp_time}'
if use_cuda:
# Assert CUDA memory parity
assert max_memory_custom <= max_memory_ddp, \
f'Custom plugin used too much memory compared to DDP,' \
f'Custom Mem: {max_memory_custom}, DDP Mem: {max_memory_ddp}'