Fix logic and add test for apex check, rename file, add DDP launcher tests
parent 8dc857c38d
commit fc9b2bf015

@@ -10,6 +10,7 @@ from torch.utils.data.distributed import DistributedSampler
 from pytorch_lightning import Trainer, seed_everything
 from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin
 from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE, NATIVE_AMP_AVALAIBLE
+from tests.backends.launcher import DDPLauncher
 from tests.base.boring_model import BoringModel, RandomDataset


@@ -55,6 +56,26 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
     run_sharded_correctness(gpus=2, precision=16, accelerator='ddp_spawn')


+@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
+                    reason="test should be run outside of pytest")
+@DDPLauncher.run("--distributed_backend ddp --gpus 2 --precision 32")
+def test_ddp_sharded_plugin_correctness_multi_gpu_ddp(tmpdir, args=None):
+    run_sharded_correctness(gpus=args.gpus, precision=args.precision, accelerator=args.distributed_backend)
+
+
+@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
+                    reason="test should be run outside of pytest")
+@DDPLauncher.run("--distributed_backend ddp --gpus 2 --precision 16")
+def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None):
+    run_sharded_correctness(gpus=args.gpus, precision=args.precision, accelerator=args.distributed_backend)
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 @pytest.mark.skipif(platform.system() == "Windows",
                     reason="Distributed training is not supported on Windows")

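The two launcher-driven tests above rely on DDPLauncher.run from tests.backends.launcher, whose implementation is not part of this diff. As a rough, hypothetical sketch (illustration only, not the real launcher, which also has to start the script under torch.distributed), the decorator mainly needs to parse the flag string into a namespace and hand it to the wrapped test as args:

import argparse
import functools


class DDPLauncherSketch:
    """Hypothetical stand-in for tests.backends.launcher.DDPLauncher (illustration only)."""

    @staticmethod
    def run(cli_args: str):
        def decorator(test_fn):
            @functools.wraps(test_fn)
            def wrapper(tmpdir):
                # Parse the flag string into the attributes the decorated tests read:
                # args.distributed_backend, args.gpus, args.precision.
                parser = argparse.ArgumentParser()
                parser.add_argument("--distributed_backend", type=str, default="ddp")
                parser.add_argument("--gpus", type=int, default=1)
                parser.add_argument("--precision", type=int, default=32)
                args = parser.parse_args(cli_args.split())
                return test_fn(tmpdir, args=args)
            return wrapper
        return decorator
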
@@ -73,8 +73,8 @@ class PrecisionConnector:
             rank_zero_warn('You have asked for Apex AMP but you have not installed it yet.'
                            ' Install apex first using this guide: https://github.com/NVIDIA/apex#linux')
         elif using_sharded_plugin:
-            rank_zero_warn(
-                'Sharded Plugin is not supported with Apex AMP, please using native AMP for 16 bit precision.')
+            raise MisconfigurationException('Sharded Plugin is not supported with Apex AMP, '
+                                            'please use native AMP for 16 bit precision.')
         else:
             log.info('Using APEX 16bit precision.')
             self.trainer.amp_backend = AMPType.APEX

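The user-facing effect of the connector change, as a hedged sketch grounded in the test added further below (not an exact trace of library behaviour): requesting apex AMP together with the sharded plugin now raises a MisconfigurationException instead of only emitting a warning.

from pytorch_lightning import Trainer
from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base.boring_model import BoringModel  # assumes the repo's test helper model is importable

try:
    # apex AMP + sharded plugin is rejected once the precision backend is configured
    # (at Trainer construction or early in fit, hence both are inside the try block)
    trainer = Trainer(
        fast_dev_run=True,
        distributed_backend='ddp_spawn',
        plugins=[DDPShardedPlugin()],
        precision=16,
        amp_backend='apex',
    )
    trainer.fit(BoringModel())
except MisconfigurationException as err:
    print(err)  # Sharded Plugin is not supported with Apex AMP, please use native AMP for 16 bit precision.
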
@@ -9,7 +9,8 @@ from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.plugins.sharded_native_amp_plugin import ShardedNativeAMPPlugin
 from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin, FAIRSCALE_AVAILABLE
-from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
+from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE, APEX_AVAILABLE
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base.boring_model import BoringModel


@@ -54,6 +55,26 @@ def test_ddp_choice_sharded(tmpdir, ddp_backend, gpus, num_processes):
     trainer.fit(model)


+@pytest.mark.skipif(not APEX_AVAILABLE, reason="test requires apex")
+@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+def test_invalid_apex_sharded(tmpdir):
+    """
+    Test to ensure that we raise an error when we try to use apex and sharded
+    """
+
+    model = BoringModel()
+    with pytest.raises(MisconfigurationException, match='Sharded Plugin is not supported with Apex AMP'):
+        trainer = Trainer(
+            fast_dev_run=True,
+            distributed_backend='ddp_spawn',
+            plugins=[DDPShardedPlugin()],
+            precision=16,
+            amp_backend='apex'
+        )
+
+        trainer.fit(model)
+
+
 @mock.patch.dict(
     os.environ,
     {