Fix logic and add test for apex check, rename file, add DDP launcher tests

SeanNaren 2020-11-26 22:45:21 +00:00
parent 8dc857c38d
commit fc9b2bf015
3 changed files with 45 additions and 3 deletions


@@ -10,6 +10,7 @@ from torch.utils.data.distributed import DistributedSampler
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin
from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE, NATIVE_AMP_AVALAIBLE
from tests.backends.launcher import DDPLauncher
from tests.base.boring_model import BoringModel, RandomDataset
@@ -55,6 +56,26 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
    run_sharded_correctness(gpus=2, precision=16, accelerator='ddp_spawn')


@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
                    reason="test should be run outside of pytest")
@DDPLauncher.run("--distributed_backend ddp --gpus 2 --precision 32")
def test_ddp_sharded_plugin_correctness_multi_gpu_ddp(tmpdir, args=None):
    run_sharded_correctness(gpus=args.gpus, precision=args.precision, accelerator=args.distributed_backend)


@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
                    reason="test should be run outside of pytest")
@DDPLauncher.run("--distributed_backend ddp --gpus 2 --precision 16")
def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None):
    run_sharded_correctness(gpus=args.gpus, precision=args.precision, accelerator=args.distributed_backend)


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
                    reason="Distributed training is not supported on Windows")


@@ -73,8 +73,8 @@ class PrecisionConnector:
            rank_zero_warn('You have asked for Apex AMP but you have not installed it yet.'
                           ' Install apex first using this guide: https://github.com/NVIDIA/apex#linux')
        elif using_sharded_plugin:
            rank_zero_warn(
                'Sharded Plugin is not supported with Apex AMP, please use native AMP for 16 bit precision.')
            raise MisconfigurationException('Sharded Plugin is not supported with Apex AMP, '
                                            'please use native AMP for 16 bit precision.')
        else:
            log.info('Using APEX 16bit precision.')
            self.trainer.amp_backend = AMPType.APEX
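
In this hunk the rank_zero_warn call for the sharded-plugin case is the line being removed and the raise MisconfigurationException is its replacement, matching the "Fix logic ... for apex check" part of the commit title. A self-contained sketch of that check in isolation follows; validate_apex_backend is a hypothetical helper written for illustration, not part of PrecisionConnector, and only MisconfigurationException comes from the library.

# Hypothetical helper isolating the apex/sharded check that this commit changes
# from a warning into a hard error.
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def validate_apex_backend(apex_available: bool, using_sharded_plugin: bool) -> bool:
    """Return True if Apex 16-bit precision can be enabled for this configuration."""
    if not apex_available:
        # Apex not installed: the connector warns and falls back to 32-bit.
        return False
    if using_sharded_plugin:
        # New behaviour: fail fast instead of warning and silently continuing.
        raise MisconfigurationException('Sharded Plugin is not supported with Apex AMP, '
                                        'please use native AMP for 16 bit precision.')
    return True

The test_invalid_apex_sharded test added in the next file exercises this path end to end via a Trainer configured with the sharded plugin and amp_backend='apex'.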


@@ -9,7 +9,8 @@ from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins.sharded_native_amp_plugin import ShardedNativeAMPPlugin
from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin, FAIRSCALE_AVAILABLE
from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE, APEX_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base.boring_model import BoringModel
@@ -54,6 +55,26 @@ def test_ddp_choice_sharded(tmpdir, ddp_backend, gpus, num_processes):
    trainer.fit(model)


@pytest.mark.skipif(not APEX_AVAILABLE, reason="test requires apex")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_invalid_apex_sharded(tmpdir):
    """
    Test to ensure that we raise an error when we try to use apex and sharded
    """
    model = BoringModel()
    with pytest.raises(MisconfigurationException, match='Sharded Plugin is not supported with Apex AMP'):
        trainer = Trainer(
            fast_dev_run=True,
            distributed_backend='ddp_spawn',
            plugins=[DDPShardedPlugin()],
            precision=16,
            amp_backend='apex'
        )
        trainer.fit(model)


@mock.patch.dict(
    os.environ,
    {