diff --git a/benchmarks/test_sharded_parity.py b/benchmarks/test_sharded_parity.py
index 449c9d070f..b96af667e1 100644
--- a/benchmarks/test_sharded_parity.py
+++ b/benchmarks/test_sharded_parity.py
@@ -1,13 +1,13 @@
 import os
 import platform
-import time
 from unittest import mock
 
 import pytest
 import torch
 from torch.utils.data.distributed import DistributedSampler
 
-from pytorch_lightning import Trainer, seed_everything
+from benchmarks.utilities import plugin_parity_test
+from pytorch_lightning import seed_everything
 from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin
 from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE, NATIVE_AMP_AVAILABLE
 from tests.backends.launcher import DDPLauncher
@@ -19,7 +19,12 @@ from tests.base.boring_model import BoringModel, RandomDataset
 @pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
 def test_ddp_sharded_plugin_correctness_one_device():
     # Allow slightly slower speed due to one CPU doing additional sequential memory saving calls
-    run_sharded_correctness(accelerator='ddp_cpu', max_percent_speed_diff=0.5)
+    plugin_parity_test(
+        accelerator='ddp_cpu',
+        max_percent_speed_diff=0.5,
+        plugin=DDPShardedPlugin(),
+        model_cls=SeedTrainLoaderModel
+    )
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
@@ -27,7 +32,12 @@ def test_ddp_sharded_plugin_correctness_one_device():
                     reason="Distributed training is not supported on Windows")
 @pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
 def test_ddp_sharded_plugin_correctness_one_gpu():
-    run_sharded_correctness(gpus=1, accelerator='ddp_spawn')
+    plugin_parity_test(
+        gpus=1,
+        accelerator='ddp_spawn',
+        plugin=DDPShardedPlugin(),
+        model_cls=SeedTrainLoaderModel
+    )
 
 
 @pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
@@ -36,7 +46,13 @@ def test_ddp_sharded_plugin_correctness_one_gpu():
                     reason="Distributed training is not supported on Windows")
 @pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
 def test_ddp_sharded_plugin_correctness_amp_one_gpu():
-    run_sharded_correctness(gpus=1, precision=16, accelerator='ddp_spawn')
+    plugin_parity_test(
+        gpus=1,
+        precision=16,
+        accelerator='ddp_spawn',
+        plugin=DDPShardedPlugin(),
+        model_cls=SeedTrainLoaderModel
+    )
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@@ -44,7 +60,12 @@ def test_ddp_sharded_plugin_correctness_amp_one_gpu():
                     reason="Distributed training is not supported on Windows")
 @pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
 def test_ddp_sharded_plugin_correctness_multi_gpu():
-    run_sharded_correctness(gpus=2, accelerator='ddp_spawn')
+    plugin_parity_test(
+        gpus=2,
+        accelerator='ddp_spawn',
+        plugin=DDPShardedPlugin(),
+        model_cls=SeedTrainLoaderModel
+    )
 
 
 @pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
@@ -53,7 +74,13 @@ def test_ddp_sharded_plugin_correctness_multi_gpu():
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 @pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
 def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
-    run_sharded_correctness(gpus=2, precision=16, accelerator='ddp_spawn')
+    plugin_parity_test(
+        gpus=2,
+        precision=16,
+        accelerator='ddp_spawn',
+        plugin=DDPShardedPlugin(),
+        model_cls=SeedTrainLoaderModel
+    )
 
 
 @pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
available") @@ -63,7 +90,13 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu(): reason="test should be run outside of pytest") @DDPLauncher.run("--distributed_backend ddp --gpus 2 --precision 32") def test_ddp_sharded_plugin_correctness_multi_gpu_ddp(tmpdir, args=None): - run_sharded_correctness(gpus=args.gpus, precision=args.precision, accelerator=args.distributed_backend) + plugin_parity_test( + gpus=args.gpus, + precision=args.precision, + accelerator=args.distributed_backend, + plugin=DDPShardedPlugin(), + model_cls=SeedTrainLoaderModel + ) @pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") @@ -73,7 +106,13 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_ddp(tmpdir, args=None): reason="test should be run outside of pytest") @DDPLauncher.run("--distributed_backend ddp --gpus 2 --precision 16") def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None): - run_sharded_correctness(gpus=args.gpus, precision=args.precision, accelerator=args.distributed_backend) + plugin_parity_test( + gpus=args.gpus, + precision=args.precision, + accelerator=args.distributed_backend, + plugin=DDPShardedPlugin(), + model_cls=SeedTrainLoaderModel + ) @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @@ -84,7 +123,8 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim(): """ Ensures same results using multiple optimizers across multiple GPUs """ - run_sharded_correctness( + plugin_parity_test( + plugin=DDPShardedPlugin(), gpus=2, accelerator='ddp_spawn', model_cls=SeedTrainLoaderMultipleOptimizersModel, @@ -102,7 +142,8 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir): """ Ensures using multiple optimizers across multiple GPUs with manual optimization """ - run_sharded_correctness( + plugin_parity_test( + plugin=DDPShardedPlugin(), gpus=2, accelerator='ddp_spawn', model_cls=SeedTrainLoaderManualModel, @@ -167,108 +208,3 @@ class SeedTrainLoaderMultipleOptimizersModel(SeedTrainLoaderModel): optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1) return optimizer, optimizer_2 - - -def record_ddp_fit_model_stats(trainer, model, gpus): - """ - Helper to calculate wall clock time for fit + max allocated memory. - - Args: - trainer: The trainer object. - model: The LightningModule. - gpus: Number of GPUs in test. - - Returns: - Max Memory if using GPUs, and total wall clock time. - - """ - max_memory = None - - time_start = time.perf_counter() - if gpus > 0: - torch.cuda.reset_peak_memory_stats() - torch.cuda.synchronize() - - trainer.fit(model) - - if gpus > 0: - torch.cuda.synchronize() - max_memory = torch.cuda.max_memory_allocated() / 2 ** 20 - - total_time = time.perf_counter() - time_start - - return max_memory, total_time - - -def run_sharded_correctness( - accelerator='ddp_spawn', - gpus=0, - precision=32, - max_percent_speed_diff=0.25, - model_cls=SeedTrainLoaderModel): - """ - Ensures that the trained model is identical to the standard DDP implementation. - Also checks for speed/memory regressions, we should expect always less memory but performance to fluctuate. - - Args: - accelerator: Accelerator type for test. - gpus: Number of GPUS to enable. - precision: Whether to use AMP or normal FP32 training. - max_percent_speed_diff: The maximum speed difference compared to normal DDP training. - This is more a safety net for variability in CI which can vary in speed, not for benchmarking. 
-        model_cls: Model class to use for test.
-
-    """
-
-    # Train normal DDP
-    seed_everything(42)
-    ddp_model = model_cls()
-
-    trainer = Trainer(
-        fast_dev_run=True,
-        max_epochs=1,
-        gpus=gpus,
-        precision=precision,
-        accelerator=accelerator,
-    )
-
-    max_ddp_memory, ddp_time = record_ddp_fit_model_stats(
-        trainer=trainer,
-        model=ddp_model,
-        gpus=gpus
-    )
-
-    # Reset and train sharded DDP
-    seed_everything(42)
-    sharded_model = model_cls()
-
-    trainer = Trainer(
-        fast_dev_run=True,
-        max_epochs=1,
-        gpus=gpus,
-        precision=precision,
-        accelerator=accelerator,
-        plugins=[DDPShardedPlugin()],
-    )
-
-    max_sharded_memory, sharded_time = record_ddp_fit_model_stats(
-        trainer=trainer,
-        model=sharded_model,
-        gpus=gpus
-    )
-
-    # Assert model parameters are identical after fit
-    for ddp_param, shard_param in zip(ddp_model.parameters(), sharded_model.parameters()):
-        assert torch.equal(ddp_param, shard_param), 'Model parameters are different between DDP and Sharded plugin'
-
-    # Assert speed parity by ensuring percentage difference between sharded/ddp is below threshold
-    percent_diff = (sharded_time - ddp_time) / sharded_time
-
-    assert percent_diff <= max_percent_speed_diff, \
-        f'Sharded plugin was too slow compared to DDP, Sharded Time: {sharded_time}, DDP Time: {ddp_time}'
-
-    if gpus > 0:
-        # Assert CUDA memory parity
-        assert max_sharded_memory <= max_ddp_memory, \
-            f'Sharded plugin used too much memory compared to DDP,' \
-            f'Sharded Mem: {max_sharded_memory}, DDP Mem: {max_ddp_memory}'
diff --git a/benchmarks/utilities.py b/benchmarks/utilities.py
new file mode 100644
index 0000000000..1c07d2274c
--- /dev/null
+++ b/benchmarks/utilities.py
@@ -0,0 +1,117 @@
+import time
+from typing import Callable
+
+import torch
+from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
+
+from pytorch_lightning import Trainer
+from pytorch_lightning.utilities.seed import seed_everything
+
+
+def record_ddp_fit_model_stats(trainer, model, use_cuda):
+    """
+    Helper to calculate wall clock time for fit + max allocated memory.
+
+    Args:
+        trainer: The trainer object.
+        model: The model to fit.
+        use_cuda: Whether to sync CUDA kernels.
+
+    Returns:
+        Max Memory if using GPUs, and total wall clock time.
+    """
+    max_memory = None
+
+    time_start = time.perf_counter()
+    if use_cuda:
+        torch.cuda.reset_peak_memory_stats()
+        torch.cuda.synchronize()
+
+    trainer.fit(model)
+
+    if use_cuda:
+        torch.cuda.synchronize()
+        max_memory = torch.cuda.max_memory_allocated() / 2 ** 20
+
+    total_time = time.perf_counter() - time_start
+
+    return max_memory, total_time
+
+
+def plugin_parity_test(
+        model_cls: Callable,
+        plugin: DDPPlugin,
+        seed: int = 42,
+        accelerator: str = 'ddp_spawn',
+        gpus: int = 0,
+        precision: int = 32,
+        max_percent_speed_diff: float = 0.25):
+    """
+    Ensures that the trained model is identical to the standard DDP implementation.
+    Also checks for speed/memory regressions, we should expect always less memory but performance to fluctuate.
+
+    Args:
+        model_cls: Model class to use for test.
+        plugin: Plugin to parity test.
+        seed: Seed for generators. Note that this does not handle the seed for data-loading on multi-process.
+        accelerator: Accelerator type for test.
+        gpus: Number of GPUS to enable.
+        precision: Whether to use AMP or normal FP32 training.
+        max_percent_speed_diff: The maximum speed difference compared to normal DDP training.
+        This is more a safety net for variability in CI which can vary in speed, not for benchmarking.
+ + """ + + # Train normal DDP + seed_everything(seed) + ddp_model = model_cls() + use_cuda = gpus > 0 + + trainer = Trainer( + fast_dev_run=True, + max_epochs=1, + gpus=gpus, + precision=precision, + accelerator=accelerator, + ) + + max_memory_ddp, ddp_time = record_ddp_fit_model_stats( + trainer=trainer, + model=ddp_model, + use_cuda=use_cuda + ) + + # Reset and train Custom DDP + seed_everything(seed) + custom_plugin_model = model_cls() + + trainer = Trainer( + fast_dev_run=True, + max_epochs=1, + gpus=gpus, + precision=precision, + accelerator=accelerator, + plugins=[plugin], + ) + + max_memory_custom, custom_model_time = record_ddp_fit_model_stats( + trainer=trainer, + model=custom_plugin_model, + use_cuda=use_cuda + ) + + # Assert model parameters are identical after fit + for ddp_param, custom_param in zip(ddp_model.parameters(), custom_plugin_model.parameters()): + assert torch.equal(ddp_param, custom_param), 'Model parameters are different between DDP and Custom plugin' + + # Assert speed parity by ensuring percentage difference between custom/ddp is below threshold + percent_diff = (custom_model_time - ddp_time) / custom_model_time + + assert percent_diff <= max_percent_speed_diff, \ + f'Custom DDP plugin was too slow compared to DDP, Custom Plugin Time: {custom_model_time}, DDP Time: {ddp_time}' + + if use_cuda: + # Assert CUDA memory parity + assert max_memory_custom <= max_memory_ddp, \ + f'Custom plugin used too much memory compared to DDP,' \ + f'Custom Mem: {max_memory_custom}, DDP Mem: {max_memory_ddp}'