From 936f42aa1c72b9bb1f7f4108ada0a77cdf256220 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 16 Feb 2021 20:06:47 +0100 Subject: [PATCH] clean AMP logic (#5994) * clean AMP logic * cleaning * ... * ... * Even apex --- .../accelerators/accelerator_connector.py | 50 +++++++++---------- tests/models/test_amp.py | 15 +----- 2 files changed, 25 insertions(+), 40 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py index cfa9545ad6..644b382b6b 100644 --- a/pytorch_lightning/accelerators/accelerator_connector.py +++ b/pytorch_lightning/accelerators/accelerator_connector.py @@ -298,46 +298,42 @@ class BackendConnector(object): if self.on_tpu: return TPUHalfPrecisionPlugin() - if self.amp_type == "native": - if not _NATIVE_AMP_AVAILABLE: - rank_zero_warn( - "You have asked for native AMP but your PyTorch version does not support it." - " Consider upgrading with `pip install torch>=1.6`." - " We will attempt to use NVIDIA Apex for this session." - ) - if not _APEX_AVAILABLE and self.on_cpu: - raise MisconfigurationException( - "You have asked for native AMP on CPU, but AMP is only available on GPU." - ) - self.amp_type = "apex" - elif self.on_cpu: + self.amp_type = AMPType(self.amp_type) + if self.amp_type == AMPType.NATIVE: + if self.on_cpu: raise MisconfigurationException( "You have asked for native AMP on CPU, but AMP is only available on GPU." ) + elif not _NATIVE_AMP_AVAILABLE: + msg = "You have asked for native AMP but your PyTorch version does not support it." \ + " Consider upgrading with `pip install torch>=1.6`." + if _APEX_AVAILABLE: + self.amp_type = AMPType.APEX + msg += " We will attempt to use NVIDIA Apex for this session." + rank_zero_warn(msg) + else: + raise MisconfigurationException(msg) else: log.info("Using native 16bit precision.") if isinstance(self.training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin)): return ShardedNativeMixedPrecisionPlugin() - self.amp_type = AMPType.NATIVE return NativeMixedPrecisionPlugin() - if self.amp_type == "apex": + if self.amp_type == AMPType.APEX: if not _APEX_AVAILABLE: - rank_zero_warn( + raise MisconfigurationException( "You have asked for Apex AMP but you have not installed it yet." " Install apex first using this guide: https://github.com/NVIDIA/apex#linux" ) - else: - if isinstance(self.training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin)): - raise MisconfigurationException( - "Sharded Plugin is not supported with Apex AMP, " - "please using native AMP for 16-bit precision." - ) - log.info("Using APEX 16bit precision.") - self.amp_type = AMPType.APEX - return ApexMixedPrecisionPlugin(self.amp_level) - else: - raise NotImplementedError("We only support precisions 32 and 16!") + if isinstance(self.training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin)): + raise MisconfigurationException( + "Sharded Plugin is not supported with Apex AMP," + " please using native AMP for 16-bit precision." + ) + log.info("Using APEX 16bit precision.") + return ApexMixedPrecisionPlugin(self.amp_level) + + raise NotImplementedError("We only support precisions 32 and 16!") def select_training_type_plugin(self) -> TrainingTypePlugin: if self.use_ddp2: diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py index ff623af963..2dd6c9d997 100644 --- a/tests/models/test_amp.py +++ b/tests/models/test_amp.py @@ -18,7 +18,6 @@ import pytest import torch from torch import optim -import tests.helpers.pipelines as tpipes import tests.helpers.utils as tutils from pytorch_lightning import Trainer from pytorch_lightning.plugins.environments import SLURMEnvironment @@ -155,21 +154,11 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir): assert generated == 'abc23' +@pytest.mark.skipif(torch.cuda.is_available(), reason="test is restricted only on CPU") def test_cpu_model_with_amp(tmpdir): """Make sure model trains on CPU.""" - trainer_options = dict( - default_root_dir=tmpdir, - progress_bar_refresh_rate=0, - max_epochs=1, - limit_train_batches=0.4, - limit_val_batches=0.4, - precision=16, - ) - - model = BoringModel() - with pytest.raises(MisconfigurationException, match="AMP is only available on GPU"): - tpipes.run_model_test(trainer_options, model, on_gpu=False) + Trainer(precision=16) @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})