diff --git a/benchmarks/test_sharded_correctness.py b/benchmarks/test_sharded_parity.py
similarity index 87%
rename from benchmarks/test_sharded_correctness.py
rename to benchmarks/test_sharded_parity.py
index c6472f0b01..78eb98e148 100644
--- a/benchmarks/test_sharded_correctness.py
+++ b/benchmarks/test_sharded_parity.py
@@ -10,6 +10,7 @@ from torch.utils.data.distributed import DistributedSampler
 from pytorch_lightning import Trainer, seed_everything
 from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin
 from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE, NATIVE_AMP_AVALAIBLE
+from tests.backends.launcher import DDPLauncher
 from tests.base.boring_model import BoringModel, RandomDataset
 
 
@@ -55,6 +56,26 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
     run_sharded_correctness(gpus=2, precision=16, accelerator='ddp_spawn')
 
 
+@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
+                    reason="test should be run outside of pytest")
+@DDPLauncher.run("--distributed_backend ddp --gpus 2 --precision 32")
+def test_ddp_sharded_plugin_correctness_multi_gpu_ddp(tmpdir, args=None):
+    run_sharded_correctness(gpus=args.gpus, precision=args.precision, accelerator=args.distributed_backend)
+
+
+@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
+                    reason="test should be run outside of pytest")
+@DDPLauncher.run("--distributed_backend ddp --gpus 2 --precision 16")
+def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None):
+    run_sharded_correctness(gpus=args.gpus, precision=args.precision, accelerator=args.distributed_backend)
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 @pytest.mark.skipif(platform.system() == "Windows",
                     reason="Distributed training is not supported on Windows")
diff --git a/pytorch_lightning/trainer/connectors/precision_connector.py b/pytorch_lightning/trainer/connectors/precision_connector.py
index e4ea99bdf2..71d884bdfe 100644
--- a/pytorch_lightning/trainer/connectors/precision_connector.py
+++ b/pytorch_lightning/trainer/connectors/precision_connector.py
@@ -73,8 +73,8 @@ class PrecisionConnector:
                 rank_zero_warn('You have asked for Apex AMP but you have not installed it yet.'
                                ' Install apex first using this guide: https://github.com/NVIDIA/apex#linux')
             elif using_sharded_plugin:
-                rank_zero_warn(
-                    'Sharded Plugin is not supported with Apex AMP, please using native AMP for 16 bit precision.')
+                raise MisconfigurationException('Sharded Plugin is not supported with Apex AMP, '
+                                                'please use native AMP for 16 bit precision.')
             else:
                 log.info('Using APEX 16bit precision.')
                 self.trainer.amp_backend = AMPType.APEX
diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py
index 87d568a3cd..6d93929761 100644
--- a/tests/plugins/test_sharded_plugin.py
+++ b/tests/plugins/test_sharded_plugin.py
@@ -9,7 +9,8 @@ from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.plugins.sharded_native_amp_plugin import ShardedNativeAMPPlugin
 from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin, FAIRSCALE_AVAILABLE
-from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
+from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE, APEX_AVAILABLE
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base.boring_model import BoringModel
 
 
@@ -54,6 +55,26 @@ def test_ddp_choice_sharded(tmpdir, ddp_backend, gpus, num_processes):
     trainer.fit(model)
 
 
+@pytest.mark.skipif(not APEX_AVAILABLE, reason="test requires apex")
+@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+def test_invalid_apex_sharded(tmpdir):
+    """
+    Test to ensure that we raise an error when we try to use Apex with the sharded plugin
+    """
+
+    model = BoringModel()
+    with pytest.raises(MisconfigurationException, match='Sharded Plugin is not supported with Apex AMP'):
+        trainer = Trainer(
+            fast_dev_run=True,
+            distributed_backend='ddp_spawn',
+            plugins=[DDPShardedPlugin()],
+            precision=16,
+            amp_backend='apex'
+        )
+
+        trainer.fit(model)
+
+
 @mock.patch.dict(
     os.environ,
     {