From 63a9ab4ae298f1e32e6a3a289f67bf49a60786fc Mon Sep 17 00:00:00 2001 From: Atharva Phatak Date: Wed, 22 Jun 2022 11:09:33 -0400 Subject: [PATCH] Improved Deepspeed Imports (#13223) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ Co-authored-by: Jirka --- .../plugins/precision/deepspeed.py | 18 +++++++++++------- src/pytorch_lightning/strategies/deepspeed.py | 3 ++- src/pytorch_lightning/utilities/__init__.py | 1 - src/pytorch_lightning/utilities/deepspeed.py | 2 +- src/pytorch_lightning/utilities/imports.py | 3 --- tests/tests_pytorch/helpers/runif.py | 2 +- .../strategies/test_deepspeed_strategy.py | 10 ++++------ tests/tests_pytorch/utilities/test_imports.py | 2 +- 8 files changed, 20 insertions(+), 21 deletions(-) diff --git a/src/pytorch_lightning/plugins/precision/deepspeed.py b/src/pytorch_lightning/plugins/precision/deepspeed.py index cf69ec548f..2421c171fb 100644 --- a/src/pytorch_lightning/plugins/precision/deepspeed.py +++ b/src/pytorch_lightning/plugins/precision/deepspeed.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, TYPE_CHECKING, Union from torch import Tensor from torch.nn import Module @@ -22,12 +22,14 @@ from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin from pytorch_lightning.utilities import GradClipAlgorithmType from pytorch_lightning.utilities.enums import PrecisionType from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE, _DEEPSPEED_GREATER_EQUAL_0_6 +from pytorch_lightning.utilities.imports import _RequirementAvailable from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.warnings import WarningCache -if _DEEPSPEED_AVAILABLE: - from deepspeed import DeepSpeedEngine +_DEEPSPEED_GREATER_EQUAL_0_6 = _RequirementAvailable("deepspeed>=0.6.0") +if TYPE_CHECKING: + if pl.strategies.deepspeed._DEEPSPEED_AVAILABLE: + import deepspeed warning_cache = WarningCache() @@ -75,10 +77,12 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin): " the backward logic internally." ) assert model.trainer is not None - deepspeed_engine: DeepSpeedEngine = model.trainer.model + deepspeed_engine: "deepspeed.DeepSpeedEngine" = model.trainer.model deepspeed_engine.backward(closure_loss, *args, **kwargs) - def _run_backward(self, tensor: Tensor, model: Optional["DeepSpeedEngine"], *args: Any, **kwargs: Any) -> None: + def _run_backward( + self, tensor: Tensor, model: Optional["deepspeed.DeepSpeedEngine"], *args: Any, **kwargs: Any + ) -> None: if model is None: raise ValueError("Please provide the model as input to `backward`.") model.backward(tensor, *args, **kwargs) @@ -104,7 +108,7 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin): "Skipping backward by returning `None` from your `training_step` is not supported by `DeepSpeed`" ) # DeepSpeed handles the optimizer step internally - deepspeed_engine: DeepSpeedEngine + deepspeed_engine: "deepspeed.DeepSpeedEngine" if isinstance(model, pl.LightningModule): assert model.trainer is not None deepspeed_engine = model.trainer.model diff --git a/src/pytorch_lightning/strategies/deepspeed.py b/src/pytorch_lightning/strategies/deepspeed.py index bae617561f..9b4d3513c1 100644 --- a/src/pytorch_lightning/strategies/deepspeed.py +++ b/src/pytorch_lightning/strategies/deepspeed.py @@ -43,7 +43,7 @@ from pytorch_lightning.utilities.distributed import ( ) from pytorch_lightning.utilities.enums import AMPType, PrecisionType from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE +from pytorch_lightning.utilities.imports import _RequirementAvailable from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.optimizer import optimizers_to_device from pytorch_lightning.utilities.rank_zero import rank_zero_info @@ -53,6 +53,7 @@ from pytorch_lightning.utilities.warnings import rank_zero_warn, WarningCache warning_cache = WarningCache() +_DEEPSPEED_AVAILABLE: bool = _RequirementAvailable("deepspeed") if _DEEPSPEED_AVAILABLE: import deepspeed diff --git a/src/pytorch_lightning/utilities/__init__.py b/src/pytorch_lightning/utilities/__init__.py index 139caed97e..8182eca905 100644 --- a/src/pytorch_lightning/utilities/__init__.py +++ b/src/pytorch_lightning/utilities/__init__.py @@ -29,7 +29,6 @@ from pytorch_lightning.utilities.grads import grad_norm # noqa: F401 from pytorch_lightning.utilities.imports import ( # noqa: F401 _APEX_AVAILABLE, _BAGUA_AVAILABLE, - _DEEPSPEED_AVAILABLE, _FAIRSCALE_AVAILABLE, _FAIRSCALE_FULLY_SHARDED_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, diff --git a/src/pytorch_lightning/utilities/deepspeed.py b/src/pytorch_lightning/utilities/deepspeed.py index 3492165f88..d671be1e6c 100644 --- a/src/pytorch_lightning/utilities/deepspeed.py +++ b/src/pytorch_lightning/utilities/deepspeed.py @@ -20,7 +20,7 @@ import os import torch -from pytorch_lightning.utilities import _DEEPSPEED_AVAILABLE +from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE from pytorch_lightning.utilities.types import _PATH if _DEEPSPEED_AVAILABLE: diff --git a/src/pytorch_lightning/utilities/imports.py b/src/pytorch_lightning/utilities/imports.py index e47da3051e..ce3d23e6d7 100644 --- a/src/pytorch_lightning/utilities/imports.py +++ b/src/pytorch_lightning/utilities/imports.py @@ -133,9 +133,6 @@ _TORCH_GREATER_EQUAL_1_12 = _compare_version("torch", operator.ge, "1.12.0", use _APEX_AVAILABLE = _module_available("apex.amp") _BAGUA_AVAILABLE = _package_available("bagua") _DALI_AVAILABLE = _module_available("nvidia.dali") -_DEEPSPEED_AVAILABLE = _package_available("deepspeed") -_DEEPSPEED_GREATER_EQUAL_0_5_9 = _DEEPSPEED_AVAILABLE and _compare_version("deepspeed", operator.ge, "0.5.9") -_DEEPSPEED_GREATER_EQUAL_0_6 = _DEEPSPEED_AVAILABLE and _compare_version("deepspeed", operator.ge, "0.6.0") _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available("fairscale.nn") _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.ge, "0.3.3") _FAIRSCALE_FULLY_SHARDED_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.ge, "0.3.4") diff --git a/tests/tests_pytorch/helpers/runif.py b/tests/tests_pytorch/helpers/runif.py index 6ad86653fb..1fc7ca893b 100644 --- a/tests/tests_pytorch/helpers/runif.py +++ b/tests/tests_pytorch/helpers/runif.py @@ -20,10 +20,10 @@ import torch from packaging.version import Version from pkg_resources import get_distribution +from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE from pytorch_lightning.utilities.imports import ( _APEX_AVAILABLE, _BAGUA_AVAILABLE, - _DEEPSPEED_AVAILABLE, _FAIRSCALE_AVAILABLE, _FAIRSCALE_FULLY_SHARDED_AVAILABLE, _HIVEMIND_AVAILABLE, diff --git a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py index 601e559ab3..41faee02f3 100644 --- a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py +++ b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py @@ -30,14 +30,11 @@ from pytorch_lightning import LightningDataModule, LightningModule, seed_everyth from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset from pytorch_lightning.plugins import DeepSpeedPrecisionPlugin +from pytorch_lightning.plugins.precision.deepspeed import _DEEPSPEED_GREATER_EQUAL_0_6 from pytorch_lightning.strategies import DeepSpeedStrategy -from pytorch_lightning.strategies.deepspeed import LightningDeepSpeedModule +from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE, LightningDeepSpeedModule from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import ( - _DEEPSPEED_AVAILABLE, - _DEEPSPEED_GREATER_EQUAL_0_5_9, - _DEEPSPEED_GREATER_EQUAL_0_6, -) +from pytorch_lightning.utilities.imports import _RequirementAvailable from pytorch_lightning.utilities.meta import init_meta_context from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.datasets import RandomIterableDataset @@ -47,6 +44,7 @@ if _DEEPSPEED_AVAILABLE: import deepspeed from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict + _DEEPSPEED_GREATER_EQUAL_0_5_9 = _RequirementAvailable("deepspeed>=0.5.9") if _DEEPSPEED_GREATER_EQUAL_0_5_9: from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer else: diff --git a/tests/tests_pytorch/utilities/test_imports.py b/tests/tests_pytorch/utilities/test_imports.py index 45736ae36f..b72e68740a 100644 --- a/tests/tests_pytorch/utilities/test_imports.py +++ b/tests/tests_pytorch/utilities/test_imports.py @@ -13,10 +13,10 @@ # limitations under the License. import operator +from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE from pytorch_lightning.utilities import ( _APEX_AVAILABLE, _BAGUA_AVAILABLE, - _DEEPSPEED_AVAILABLE, _FAIRSCALE_AVAILABLE, _HOROVOD_AVAILABLE, _module_available,