Fix deepspeed default precision plugin `amp_level` to O2 (#13897)

Co-authored-by: Akihiro Nitta <nitta@akihironitta.com>
Rohit Gupta 2022-07-30 02:06:51 +05:30 committed by GitHub
parent aefb9ab43f
commit 0f6caffa57
5 changed files with 34 additions and 4 deletions


@@ -396,6 +396,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue that caused the learning rate finder to set the model's learning rate to None when no suggestion was possible ([#13845](https://github.com/Lightning-AI/lightning/pull/13845))
+- Fixed default `amp_level` for `DeepSpeedPrecisionPlugin` to `O2` ([#13897](https://github.com/PyTorchLightning/pytorch-lightning/pull/13897))

 ## [1.6.5] - 2022-07-13
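For context, a minimal sketch (not code from this commit) of the user-facing setup the entry above affects: with the DeepSpeed strategy and the Apex AMP backend, omitting `amp_level` previously left it as `None`; after this change it falls back to `"O2"`. The sketch assumes `deepspeed` and NVIDIA `apex` are installed and a CUDA GPU is available.

```python
# Sketch only: the configuration whose default changes with this commit.
# Assumes `deepspeed` and NVIDIA `apex` are installed and a CUDA GPU is available.
from pytorch_lightning import Trainer

trainer = Trainer(
    strategy="deepspeed",
    accelerator="gpu",
    devices=1,
    precision=16,
    amp_backend="apex",  # no `amp_level` passed; it now defaults to "O2"
)
print(trainer.precision_plugin.amp_level)  # expected: "O2"
```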


@@ -35,7 +35,7 @@ class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
     def __init__(self, amp_level: str = "O2") -> None:
         if not _APEX_AVAILABLE:
             raise MisconfigurationException(
-                "You have asked for Apex AMP but you have not installed it."
+                "You have asked for Apex AMP but `apex` is not installed."
                 " Install `apex` using this guide: https://github.com/NVIDIA/apex"
             )
         super().__init__()
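The guard above relies on the module-level `_APEX_AVAILABLE` flag imported from `pytorch_lightning.utilities.imports`. As a rough sketch of how such a flag can be derived (an illustrative assumption, not the library's actual implementation):

```python
# Illustrative only: one way to compute an availability flag like `_APEX_AVAILABLE`.
# The real flag lives in `pytorch_lightning.utilities.imports` and may be derived differently.
import importlib.util


def _package_available(name: str) -> bool:
    """Return True if `name` is importable in the current environment."""
    return importlib.util.find_spec(name) is not None


_APEX_AVAILABLE = _package_available("apex")
```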


@@ -20,9 +20,9 @@ from torch.optim import LBFGS, Optimizer

 import pytorch_lightning as pl
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
 from pytorch_lightning.utilities import GradClipAlgorithmType
-from pytorch_lightning.utilities.enums import PrecisionType
+from pytorch_lightning.utilities.enums import AMPType, PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _RequirementAvailable
+from pytorch_lightning.utilities.imports import _APEX_AVAILABLE, _RequirementAvailable
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.warnings import WarningCache
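The newly imported `AMPType` is what the next hunk compares the `amp_type` argument against. A small illustration of the `str`-backed enum pattern that lets `amp_type == AMPType.APEX` succeed when `amp_type` is the plain string `"apex"` (a sketch of the pattern only, not Lightning's exact enum class):

```python
# Pattern sketch (not Lightning's code): a str-backed Enum compares equal to its value,
# so the plain string "apex" matches the APEX member.
from enum import Enum


class AMPTypeSketch(str, Enum):
    NATIVE = "native"
    APEX = "apex"


assert AMPTypeSketch.APEX == "apex"
assert "apex" == AMPTypeSketch.APEX
```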
@@ -51,6 +51,15 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin):
     """

     def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optional[str] = None) -> None:
+        if amp_type == AMPType.APEX:
+            if not _APEX_AVAILABLE:
+                raise MisconfigurationException(
+                    "You have asked for Apex AMP but `apex` is not installed."
+                    " Install `apex` using this guide: https://github.com/NVIDIA/apex"
+                )
+
+            amp_level = amp_level or "O2"
+
         supported_precision = (PrecisionType.HALF, PrecisionType.FLOAT, PrecisionType.BFLOAT, PrecisionType.MIXED)
         if precision not in supported_precision:
             raise ValueError(
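Taken together, the new branch means the plugin can be constructed with the Apex backend and no explicit level. A minimal sketch of the resulting behavior (assumes NVIDIA `apex` is installed, otherwise the `MisconfigurationException` above is raised instead; not part of the commit itself):

```python
# Sketch of the behavior added above.
from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin

plugin = DeepSpeedPrecisionPlugin(precision=16, amp_type="apex")
assert plugin.amp_level == "O2"  # default now applied

plugin = DeepSpeedPrecisionPlugin(precision=16, amp_type="apex", amp_level="O1")
assert plugin.amp_level == "O1"  # an explicit level still takes precedence
```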


@@ -11,11 +11,29 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from unittest import mock
+
 import pytest

 from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin
+from pytorch_lightning.utilities.exceptions import MisconfigurationException


 def test_invalid_precision_with_deepspeed_precision():
     with pytest.raises(ValueError, match="is not supported. `precision` must be one of"):
         DeepSpeedPrecisionPlugin(precision=64, amp_type="native")
+
+
+def test_deepspeed_precision_apex_not_installed(monkeypatch):
+    import pytorch_lightning.plugins.precision.deepspeed as deepspeed_apex
+
+    monkeypatch.setattr(deepspeed_apex, "_APEX_AVAILABLE", False)
+    with pytest.raises(MisconfigurationException, match="You have asked for Apex AMP but `apex` is not installed."):
+        DeepSpeedPrecisionPlugin(precision=16, amp_type="apex")
+
+
+@mock.patch("pytorch_lightning.plugins.precision.deepspeed._APEX_AVAILABLE", return_value=True)
+def test_deepspeed_precision_apex_default_level(_):
+    precision_plugin = DeepSpeedPrecisionPlugin(precision=16, amp_type="apex")
+    assert isinstance(precision_plugin, DeepSpeedPrecisionPlugin)
+    assert precision_plugin.amp_level == "O2"
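A note on the patching style in the last test: `mock.patch(..., return_value=True)` replaces the module attribute with a `MagicMock`, which is truthy, so the availability check passes. An equivalent way to stub the flag is to pin it to a literal boolean, sketched below (an illustration, not part of the commit; the test name here is hypothetical):

```python
# Sketch only: stubbing `_APEX_AVAILABLE` with a literal True via patch.object.
from unittest import mock

import pytorch_lightning.plugins.precision.deepspeed as deepspeed_precision
from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin


def test_deepspeed_default_level_with_literal_patch():
    with mock.patch.object(deepspeed_precision, "_APEX_AVAILABLE", True):
        plugin = DeepSpeedPrecisionPlugin(precision=16, amp_type="apex")
        assert plugin.amp_level == "O2"
```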


@@ -289,5 +289,5 @@ def test_precision_selection_raises(monkeypatch):
     monkeypatch.setattr(apex, "_APEX_AVAILABLE", False)
     with mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=1), mock.patch(
         "pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True
-    ), pytest.raises(MisconfigurationException, match="asked for Apex AMP but you have not installed it"):
+    ), pytest.raises(MisconfigurationException, match="asked for Apex AMP but `apex` is not installed"):
         Trainer(amp_backend="apex", precision=16, accelerator="gpu", devices=1)