Update Sharded test with RunIf (#6384)

Kaushik B 2021-03-07 04:54:34 +05:30 committed by GitHub
parent 34b733b35e
commit 966184a452
1 changed file with 10 additions and 30 deletions
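
RunIf comes from tests/helpers/runif.py (imported in the hunk below) and folds the machine requirements that were previously spelled out as stacked pytest.mark.skipif decorators into a single marker. The following is a minimal sketch of how such a marker can be built, using the keyword names that appear in this diff (min_gpus, skip_windows, fairscale, amp_native); it is an illustrative approximation, not the actual Lightning implementation.

import platform

import pytest
import torch


class RunIf:
    """Fold several environment requirements into one pytest.mark.skipif marker."""

    def __new__(cls, min_gpus: int = 0, skip_windows: bool = False,
                fairscale: bool = False, amp_native: bool = False):
        conditions = []
        reasons = []

        if min_gpus:
            # Skip when the machine has fewer GPUs than the test needs.
            conditions.append(torch.cuda.device_count() < min_gpus)
            reasons.append(f"at least {min_gpus} GPU(s)")

        if skip_windows:
            # Distributed training is not supported on Windows.
            conditions.append(platform.system() == "Windows")
            reasons.append("not Windows")

        if fairscale:
            # Require the optional fairscale dependency.
            try:
                import fairscale  # noqa: F401
                has_fairscale = True
            except ImportError:
                has_fairscale = False
            conditions.append(not has_fairscale)
            reasons.append("fairscale")

        if amp_native:
            # Require native automatic mixed precision (torch.cuda.amp).
            conditions.append(not hasattr(torch.cuda, "amp"))
            reasons.append("native AMP")

        # Returning the marker from __new__ makes @RunIf(...) usable directly
        # as a decorator, exactly like @pytest.mark.skipif(...).
        return pytest.mark.skipif(any(conditions), reason="Requires: " + ", ".join(reasons))

With a helper along these lines, each block of three or four stacked skipif decorators in the hunks below collapses into the single @RunIf(...) line that replaces it.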


@@ -13,7 +13,6 @@
# limitations under the License.
import os
import platform
import time
from typing import Type
@@ -22,14 +21,12 @@ import torch
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.plugins import DDPSpawnShardedPlugin
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _NATIVE_AMP_AVAILABLE
from tests.accelerators import DDPLauncher
from tests.helpers.boring_model import BoringModel, RandomDataset
from tests.helpers.runif import RunIf
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@RunIf(min_gpus=1, skip_windows=True, fairscale=True)
def test_ddp_sharded_plugin_correctness_one_gpu():
plugin_parity_test(
gpus=1,
@@ -37,10 +34,7 @@ def test_ddp_sharded_plugin_correctness_one_gpu():
)
@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@RunIf(min_gpus=1, skip_windows=True, fairscale=True, amp_native=True)
def test_ddp_sharded_plugin_correctness_amp_one_gpu():
plugin_parity_test(
gpus=1,
@@ -50,9 +44,7 @@ def test_ddp_sharded_plugin_correctness_amp_one_gpu():
@pytest.mark.skip(reason="Not a critical test, skip till drone CI performance improves.")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@RunIf(min_gpus=2, skip_windows=True, fairscale=True)
def test_ddp_sharded_plugin_correctness_multi_gpu():
plugin_parity_test(
gpus=2,
@@ -61,10 +53,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu():
)
@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@RunIf(min_gpus=2, skip_windows=True, fairscale=True, amp_native=True)
def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
plugin_parity_test(
gpus=2,
@@ -74,10 +63,7 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
)
@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@RunIf(min_gpus=2, skip_windows=True, fairscale=True, amp_native=True)
def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu():
plugin_parity_test(
gpus=2,
@@ -87,8 +73,7 @@ def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu():
)
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@RunIf(min_gpus=2, fairscale=True)
@pytest.mark.skipif(
not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
)
@@ -101,8 +86,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_ddp(tmpdir, args=None):
)
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@RunIf(min_gpus=2, fairscale=True)
@pytest.mark.skipif(
not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
)
@@ -116,9 +100,7 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None):
@pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@RunIf(min_gpus=2, skip_windows=True, fairscale=True)
def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
"""
Ensures same results using multiple optimizers across multiple GPUs
@@ -131,9 +113,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
@pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@RunIf(min_gpus=2, skip_windows=True, fairscale=True)
def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir):
"""
Ensures using multiple optimizers across multiple GPUs with manual optimization
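
The two *_multi_gpu_ddp parity tests above keep their PL_RUNNING_SPECIAL_TESTS guard as a plain skipif. As a hypothetical illustration of the same consolidation pattern (not something this commit does), that guard could be wrapped in its own reusable marker as well:

import os

import pytest


def _special_tests_disabled() -> bool:
    # Mirrors the retained skipif condition: these tests run only when the
    # environment exports PL_RUNNING_SPECIAL_TESTS=1.
    return os.getenv("PL_RUNNING_SPECIAL_TESTS", "0") != "1"


def run_if_special():
    # Same idea as RunIf: hand back a single skipif marker to stack on a test.
    return pytest.mark.skipif(_special_tests_disabled(), reason="test should be run outside of pytest")


@run_if_special()
def test_runs_only_under_the_special_launcher():
    assert True

The run_if_special name and the standalone test are illustrative only; the diff itself leaves these guards untouched.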