Fix deepspeed default precision plugin `amp_level` to O2 (#13897)
Co-authored-by: Akihiro Nitta <nitta@akihironitta.com>
parent aefb9ab43f
commit 0f6caffa57
@@ -396,6 +396,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue that caused the learning rate finder to set the model's learning rate to None when no suggestion was possible ([#13845](https://github.com/Lightning-AI/lightning/pull/13845))
 
+- Fixed default `amp_level` for `DeepSpeedPrecisionPlugin` to `O2` ([#13897](https://github.com/PyTorchLightning/pytorch-lightning/pull/13897))
 
 
 ## [1.6.5] - 2022-07-13
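For context, a minimal sketch of what the changelog entry above means in practice. It is illustrative rather than part of this diff and assumes a CUDA machine with `deepspeed` and NVIDIA `apex` installed: selecting the Apex backend for a DeepSpeed run without specifying an `amp_level` now resolves to "O2" instead of None.

from pytorch_lightning import Trainer

# No amp_level is passed here; the DeepSpeed precision plugin created for this
# configuration now defaults it to "O2" rather than leaving it as None.
trainer = Trainer(
    strategy="deepspeed",
    accelerator="gpu",
    devices=1,
    precision=16,
    amp_backend="apex",
)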
@@ -35,7 +35,7 @@ class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
     def __init__(self, amp_level: str = "O2") -> None:
         if not _APEX_AVAILABLE:
             raise MisconfigurationException(
-                "You have asked for Apex AMP but you have not installed it."
+                "You have asked for Apex AMP but `apex` is not installed."
                 " Install `apex` using this guide: https://github.com/NVIDIA/apex"
             )
         super().__init__()
@@ -20,9 +20,9 @@ from torch.optim import LBFGS, Optimizer
 import pytorch_lightning as pl
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
 from pytorch_lightning.utilities import GradClipAlgorithmType
-from pytorch_lightning.utilities.enums import PrecisionType
+from pytorch_lightning.utilities.enums import AMPType, PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _RequirementAvailable
+from pytorch_lightning.utilities.imports import _APEX_AVAILABLE, _RequirementAvailable
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.warnings import WarningCache
@@ -51,6 +51,15 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin):
     """
 
     def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optional[str] = None) -> None:
+        if amp_type == AMPType.APEX:
+            if not _APEX_AVAILABLE:
+                raise MisconfigurationException(
+                    "You have asked for Apex AMP but `apex` is not installed."
+                    " Install `apex` using this guide: https://github.com/NVIDIA/apex"
+                )
+
+            amp_level = amp_level or "O2"
+
         supported_precision = (PrecisionType.HALF, PrecisionType.FLOAT, PrecisionType.BFLOAT, PrecisionType.MIXED)
         if precision not in supported_precision:
             raise ValueError(
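The fallback above only runs on the Apex branch, and an explicitly passed level still takes precedence over the default. A minimal sketch of the resulting behaviour, assuming `apex` is installed (otherwise the constructor raises the `MisconfigurationException` shown above):

from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin

# Omitting amp_level with the Apex backend now yields the "O2" default.
plugin = DeepSpeedPrecisionPlugin(precision=16, amp_type="apex")
assert plugin.amp_level == "O2"

# An explicit level is kept as-is; only a missing (None) value is replaced.
plugin = DeepSpeedPrecisionPlugin(precision=16, amp_type="apex", amp_level="O1")
assert plugin.amp_level == "O1"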
@@ -11,11 +11,29 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from unittest import mock
+
 import pytest
 
 from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 
 def test_invalid_precision_with_deepspeed_precision():
     with pytest.raises(ValueError, match="is not supported. `precision` must be one of"):
         DeepSpeedPrecisionPlugin(precision=64, amp_type="native")
+
+
+def test_deepspeed_precision_apex_not_installed(monkeypatch):
+    import pytorch_lightning.plugins.precision.deepspeed as deepspeed_apex
+
+    monkeypatch.setattr(deepspeed_apex, "_APEX_AVAILABLE", False)
+    with pytest.raises(MisconfigurationException, match="You have asked for Apex AMP but `apex` is not installed."):
+        DeepSpeedPrecisionPlugin(precision=16, amp_type="apex")
+
+
+@mock.patch("pytorch_lightning.plugins.precision.deepspeed._APEX_AVAILABLE", return_value=True)
+def test_deepspeed_precision_apex_default_level(_):
+    precision_plugin = DeepSpeedPrecisionPlugin(precision=16, amp_type="apex")
+    assert isinstance(precision_plugin, DeepSpeedPrecisionPlugin)
+    assert precision_plugin.amp_level == "O2"
@@ -289,5 +289,5 @@ def test_precision_selection_raises(monkeypatch):
     monkeypatch.setattr(apex, "_APEX_AVAILABLE", False)
     with mock.patch("pytorch_lightning.utilities.device_parser.num_cuda_devices", return_value=1), mock.patch(
         "pytorch_lightning.utilities.device_parser.is_cuda_available", return_value=True
-    ), pytest.raises(MisconfigurationException, match="asked for Apex AMP but you have not installed it"):
+    ), pytest.raises(MisconfigurationException, match="asked for Apex AMP but `apex` is not installed"):
         Trainer(amp_backend="apex", precision=16, accelerator="gpu", devices=1)