Update Sharded test with RunIf (#6384)
parent 34b733b35e
commit 966184a452
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import os
-import platform
 import time
 from typing import Type
 
@@ -22,14 +21,12 @@ import torch
 
 from pytorch_lightning import seed_everything, Trainer
 from pytorch_lightning.plugins import DDPSpawnShardedPlugin
-from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _NATIVE_AMP_AVAILABLE
 from tests.accelerators import DDPLauncher
 from tests.helpers.boring_model import BoringModel, RandomDataset
+from tests.helpers.runif import RunIf
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
-@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
-@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@RunIf(min_gpus=1, skip_windows=True, fairscale=True)
 def test_ddp_sharded_plugin_correctness_one_gpu():
     plugin_parity_test(
         gpus=1,
@@ -37,10 +34,7 @@ def test_ddp_sharded_plugin_correctness_one_gpu():
     )
 
 
-@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
-@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
-@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@RunIf(min_gpus=1, skip_windows=True, fairscale=True, amp_native=True)
 def test_ddp_sharded_plugin_correctness_amp_one_gpu():
     plugin_parity_test(
         gpus=1,
@@ -50,9 +44,7 @@ def test_ddp_sharded_plugin_correctness_amp_one_gpu():
 
 
 @pytest.mark.skip(reason="Not a critical test, skip till drone CI performance improves.")
-@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
-@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
-@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@RunIf(min_gpus=2, skip_windows=True, fairscale=True)
 def test_ddp_sharded_plugin_correctness_multi_gpu():
     plugin_parity_test(
         gpus=2,
@@ -61,10 +53,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu():
     )
 
 
-@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
-@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
-@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
-@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@RunIf(min_gpus=2, skip_windows=True, fairscale=True, amp_native=True)
 def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
     plugin_parity_test(
         gpus=2,
@@ -74,10 +63,7 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
     )
 
 
-@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
-@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
-@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
-@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@RunIf(min_gpus=2, skip_windows=True, fairscale=True, amp_native=True)
 def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu():
     plugin_parity_test(
         gpus=2,
@@ -87,8 +73,7 @@ def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu():
     )
 
 
-@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
-@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@RunIf(min_gpus=2, fairscale=True)
 @pytest.mark.skipif(
     not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
 )
@@ -101,8 +86,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_ddp(tmpdir, args=None):
     )
 
 
-@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
-@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@RunIf(min_gpus=2, fairscale=True)
 @pytest.mark.skipif(
     not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
 )
@@ -116,9 +100,7 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None):
 
 
 @pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.")
-@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
-@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
-@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@RunIf(min_gpus=2, skip_windows=True, fairscale=True)
 def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
     """
     Ensures same results using multiple optimizers across multiple GPUs
@@ -131,9 +113,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
 
 
 @pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.")
-@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
-@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
-@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@RunIf(min_gpus=2, skip_windows=True, fairscale=True)
 def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir):
     """
     Ensures using multiple optimizers across multiple GPUs with manual optimization
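Note: the `@RunIf(...)` marker imported from `tests.helpers.runif` bundles the hardware and dependency checks that the stacked `@pytest.mark.skipif` decorators expressed before this change. The snippet below is a minimal illustrative sketch of how such a marker can be built on top of `pytest.mark.skipif`, assuming the availability flags from `pytorch_lightning.utilities`; it is not the project's actual `tests/helpers/runif.py` implementation.

import platform

import pytest
import torch

from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _NATIVE_AMP_AVAILABLE


class RunIf:
    """Sketch of a combined skip marker, e.g. @RunIf(min_gpus=2, skip_windows=True, fairscale=True)."""

    def __new__(cls, min_gpus: int = 0, skip_windows: bool = False, fairscale: bool = False, amp_native: bool = False):
        conditions = []
        reasons = []
        if min_gpus:
            # stands in for: @pytest.mark.skipif(torch.cuda.device_count() < N, ...)
            conditions.append(torch.cuda.device_count() < min_gpus)
            reasons.append(f"requires at least {min_gpus} GPU(s)")
        if skip_windows:
            # stands in for: @pytest.mark.skipif(platform.system() == "Windows", ...)
            conditions.append(platform.system() == "Windows")
            reasons.append("Distributed training is not supported on Windows")
        if fairscale:
            # stands in for: @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, ...)
            conditions.append(not _FAIRSCALE_AVAILABLE)
            reasons.append("Fairscale is not available")
        if amp_native:
            # stands in for: @pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, ...)
            conditions.append(not _NATIVE_AMP_AVAILABLE)
            reasons.append("Requires native AMP")
        # A single skipif marker carries all accumulated conditions and reasons.
        return pytest.mark.skipif(any(conditions), reason=", ".join(reasons))

Because the marker resolves to a single `pytest.mark.skipif`, it composes with the markers left in place by the diff (the plain `@pytest.mark.skip` reasons and the `PL_RUNNING_SPECIAL_TESTS` guard) without changing how pytest collects these tests.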