diff --git a/benchmarks/test_sharded_parity.py b/benchmarks/test_sharded_parity.py
index a5fd6c2eff..231556079e 100644
--- a/benchmarks/test_sharded_parity.py
+++ b/benchmarks/test_sharded_parity.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import os
-import platform
 import time
 from typing import Type
 
@@ -22,14 +21,12 @@ import torch
 
 from pytorch_lightning import seed_everything, Trainer
 from pytorch_lightning.plugins import DDPSpawnShardedPlugin
-from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _NATIVE_AMP_AVAILABLE
 from tests.accelerators import DDPLauncher
 from tests.helpers.boring_model import BoringModel, RandomDataset
+from tests.helpers.runif import RunIf
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
-@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
-@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@RunIf(min_gpus=1, skip_windows=True, fairscale=True)
 def test_ddp_sharded_plugin_correctness_one_gpu():
     plugin_parity_test(
         gpus=1,
@@ -37,10 +34,7 @@ def test_ddp_sharded_plugin_correctness_one_gpu():
     )
 
 
-@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
-@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
-@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@RunIf(min_gpus=1, skip_windows=True, fairscale=True, amp_native=True)
 def test_ddp_sharded_plugin_correctness_amp_one_gpu():
     plugin_parity_test(
         gpus=1,
@@ -50,9 +44,7 @@ def test_ddp_sharded_plugin_correctness_amp_one_gpu():
 
 
 @pytest.mark.skip(reason="Not a critical test, skip till drone CI performance improves.")
-@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
-@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
-@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@RunIf(min_gpus=2, skip_windows=True, fairscale=True)
 def test_ddp_sharded_plugin_correctness_multi_gpu():
     plugin_parity_test(
         gpus=2,
@@ -61,10 +53,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu():
     )
 
 
-@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
-@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
-@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
-@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@RunIf(min_gpus=2, skip_windows=True, fairscale=True, amp_native=True)
 def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
     plugin_parity_test(
         gpus=2,
@@ -74,10 +63,7 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
     )
 
 
-@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
-@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
-@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
-@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@RunIf(min_gpus=2, skip_windows=True, fairscale=True, amp_native=True)
 def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu():
     plugin_parity_test(
         gpus=2,
@@ -87,8 +73,7 @@ def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu():
     )
 
 
-@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
-@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@RunIf(min_gpus=2, fairscale=True)
 @pytest.mark.skipif(
     not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
 )
@@ -101,8 +86,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_ddp(tmpdir, args=None):
     )
 
 
-@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
-@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@RunIf(min_gpus=2, fairscale=True)
 @pytest.mark.skipif(
     not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
 )
@@ -116,9 +100,7 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None):
 
 
 @pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.")
-@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
-@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
-@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@RunIf(min_gpus=2, skip_windows=True, fairscale=True)
 def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
     """
     Ensures same results using multiple optimizers across multiple GPUs
@@ -131,9 +113,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
 
 
 @pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.")
-@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
-@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
-@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@RunIf(min_gpus=2, skip_windows=True, fairscale=True)
 def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir):
     """
     Ensures using multiple optimizers across multiple GPUs with manual optimization
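
The diff replaces each stack of `@pytest.mark.skipif` decorators with a single `RunIf` marker imported from `tests.helpers.runif`. For reference, below is a minimal sketch of how such a marker could be assembled from the same conditions the diff removes (GPU count, Windows, FairScale, native AMP). This is only an illustration of the pattern, not the actual `tests/helpers/runif.py` implementation, which supports more options.

```python
# Illustrative sketch only: bundles the skip conditions removed in this diff
# into one decorator. Not the real tests/helpers/runif.py.
import platform

import pytest
import torch

from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _NATIVE_AMP_AVAILABLE


class RunIf:
    """Combine several @pytest.mark.skipif checks into a single marker.

    Example:
        @RunIf(min_gpus=2, skip_windows=True, fairscale=True)
        def test_something():
            ...
    """

    def __new__(
        cls,
        min_gpus: int = 0,
        skip_windows: bool = False,
        fairscale: bool = False,
        amp_native: bool = False,
    ):
        conditions = []
        reasons = []

        if min_gpus:
            conditions.append(torch.cuda.device_count() < min_gpus)
            reasons.append(f"at least {min_gpus} GPU(s)")

        if skip_windows:
            conditions.append(platform.system() == "Windows")
            reasons.append("a non-Windows platform (distributed training is unsupported on Windows)")

        if fairscale:
            conditions.append(not _FAIRSCALE_AVAILABLE)
            reasons.append("FairScale")

        if amp_native:
            conditions.append(not _NATIVE_AMP_AVAILABLE)
            reasons.append("native AMP")

        # Returning the MarkDecorator from __new__ lets @RunIf(...) act as a
        # plain pytest marker: the test is skipped if any requirement fails.
        return pytest.mark.skipif(any(conditions), reason="Requires: " + ", ".join(reasons))
```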