Improved Deepspeed Imports (#13223)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
Atharva Phatak 2022-06-22 11:09:33 -04:00 committed by GitHub
parent 90d2f3e787
commit 63a9ab4ae2
8 changed files with 20 additions and 21 deletions

View File

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Optional, TYPE_CHECKING, Union

 from torch import Tensor
 from torch.nn import Module
@@ -22,12 +22,14 @@ from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
 from pytorch_lightning.utilities import GradClipAlgorithmType
 from pytorch_lightning.utilities.enums import PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE, _DEEPSPEED_GREATER_EQUAL_0_6
+from pytorch_lightning.utilities.imports import _RequirementAvailable
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.warnings import WarningCache

-if _DEEPSPEED_AVAILABLE:
-    from deepspeed import DeepSpeedEngine
+_DEEPSPEED_GREATER_EQUAL_0_6 = _RequirementAvailable("deepspeed>=0.6.0")
+
+if TYPE_CHECKING:
+    if pl.strategies.deepspeed._DEEPSPEED_AVAILABLE:
+        import deepspeed

 warning_cache = WarningCache()
@@ -75,10 +77,12 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin):
                 " the backward logic internally."
             )
         assert model.trainer is not None
-        deepspeed_engine: DeepSpeedEngine = model.trainer.model
+        deepspeed_engine: "deepspeed.DeepSpeedEngine" = model.trainer.model
         deepspeed_engine.backward(closure_loss, *args, **kwargs)

-    def _run_backward(self, tensor: Tensor, model: Optional["DeepSpeedEngine"], *args: Any, **kwargs: Any) -> None:
+    def _run_backward(
+        self, tensor: Tensor, model: Optional["deepspeed.DeepSpeedEngine"], *args: Any, **kwargs: Any
+    ) -> None:
         if model is None:
             raise ValueError("Please provide the model as input to `backward`.")
         model.backward(tensor, *args, **kwargs)
@@ -104,7 +108,7 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin):
                 "Skipping backward by returning `None` from your `training_step` is not supported by `DeepSpeed`"
             )
         # DeepSpeed handles the optimizer step internally
-        deepspeed_engine: DeepSpeedEngine
+        deepspeed_engine: "deepspeed.DeepSpeedEngine"
         if isinstance(model, pl.LightningModule):
             assert model.trainer is not None
             deepspeed_engine = model.trainer.model

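The hunks above adopt a guarded-import pattern: `deepspeed` is only imported under `typing.TYPE_CHECKING`, and annotations refer to `deepspeed.DeepSpeedEngine` as strings, so the precision plugin imports cleanly on machines without DeepSpeed installed. A minimal, self-contained sketch of that pattern (the function below is illustrative, not the plugin's actual code):

```python
from typing import TYPE_CHECKING, Any, Optional

if TYPE_CHECKING:
    # Evaluated only by static type checkers (mypy, IDEs), never at runtime,
    # so a missing `deepspeed` package cannot break importing this module.
    import deepspeed


def run_backward(tensor: Any, model: Optional["deepspeed.DeepSpeedEngine"], *args: Any, **kwargs: Any) -> None:
    # The string annotation is resolved lazily and only by type checkers.
    if model is None:
        raise ValueError("Please provide the model as input to `backward`.")
    model.backward(tensor, *args, **kwargs)
```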
View File

@@ -43,7 +43,7 @@ from pytorch_lightning.utilities.distributed import (
 )
 from pytorch_lightning.utilities.enums import AMPType, PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
+from pytorch_lightning.utilities.imports import _RequirementAvailable
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.utilities.rank_zero import rank_zero_info
@@ -53,6 +53,7 @@ from pytorch_lightning.utilities.warnings import rank_zero_warn, WarningCache
 warning_cache = WarningCache()

+_DEEPSPEED_AVAILABLE: bool = _RequirementAvailable("deepspeed")
 if _DEEPSPEED_AVAILABLE:
     import deepspeed

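`_RequirementAvailable` wraps a pip-style requirement string and only resolves it when evaluated in a boolean context, which is why `_DEEPSPEED_AVAILABLE` can now live in the strategy module and be imported from anywhere without importing DeepSpeed itself. A rough, illustrative re-implementation of such a lazy check (not Lightning's actual implementation; the class and variable names below are made up):

```python
import pkg_resources


class LazyRequirement:
    """Truthy only if the installed distribution satisfies the requirement string."""

    def __init__(self, requirement: str) -> None:
        self.requirement = requirement

    def __bool__(self) -> bool:
        try:
            # Raises DistributionNotFound / VersionConflict when unsatisfied.
            pkg_resources.require(self.requirement)
            return True
        except Exception:
            return False


_DEEPSPEED_INSTALLED = LazyRequirement("deepspeed")
if _DEEPSPEED_INSTALLED:
    import deepspeed  # noqa: F401  # imported only once the requirement is known to be met
```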
View File

@@ -29,7 +29,6 @@ from pytorch_lightning.utilities.grads import grad_norm # noqa: F401
 from pytorch_lightning.utilities.imports import ( # noqa: F401
     _APEX_AVAILABLE,
     _BAGUA_AVAILABLE,
-    _DEEPSPEED_AVAILABLE,
     _FAIRSCALE_AVAILABLE,
     _FAIRSCALE_FULLY_SHARDED_AVAILABLE,
     _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE,

View File

@@ -20,7 +20,7 @@ import os
 import torch

-from pytorch_lightning.utilities import _DEEPSPEED_AVAILABLE
+from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE
 from pytorch_lightning.utilities.types import _PATH

 if _DEEPSPEED_AVAILABLE:

View File

@@ -133,9 +133,6 @@ _TORCH_GREATER_EQUAL_1_12 = _compare_version("torch", operator.ge, "1.12.0", use_base_version=True)
 _APEX_AVAILABLE = _module_available("apex.amp")
 _BAGUA_AVAILABLE = _package_available("bagua")
 _DALI_AVAILABLE = _module_available("nvidia.dali")
-_DEEPSPEED_AVAILABLE = _package_available("deepspeed")
-_DEEPSPEED_GREATER_EQUAL_0_5_9 = _DEEPSPEED_AVAILABLE and _compare_version("deepspeed", operator.ge, "0.5.9")
-_DEEPSPEED_GREATER_EQUAL_0_6 = _DEEPSPEED_AVAILABLE and _compare_version("deepspeed", operator.ge, "0.6.0")
 _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available("fairscale.nn")
 _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.ge, "0.3.3")
 _FAIRSCALE_FULLY_SHARDED_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.ge, "0.3.4")

View File

@@ -20,10 +20,10 @@ import torch
 from packaging.version import Version
 from pkg_resources import get_distribution

+from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE
 from pytorch_lightning.utilities.imports import (
     _APEX_AVAILABLE,
     _BAGUA_AVAILABLE,
-    _DEEPSPEED_AVAILABLE,
     _FAIRSCALE_AVAILABLE,
     _FAIRSCALE_FULLY_SHARDED_AVAILABLE,
     _HIVEMIND_AVAILABLE,

View File

@@ -30,14 +30,11 @@ from pytorch_lightning import LightningDataModule, LightningModule, seed_everything
 from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
 from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
 from pytorch_lightning.plugins import DeepSpeedPrecisionPlugin
+from pytorch_lightning.plugins.precision.deepspeed import _DEEPSPEED_GREATER_EQUAL_0_6
 from pytorch_lightning.strategies import DeepSpeedStrategy
-from pytorch_lightning.strategies.deepspeed import LightningDeepSpeedModule
+from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE, LightningDeepSpeedModule
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import (
-    _DEEPSPEED_AVAILABLE,
-    _DEEPSPEED_GREATER_EQUAL_0_5_9,
-    _DEEPSPEED_GREATER_EQUAL_0_6,
-)
+from pytorch_lightning.utilities.imports import _RequirementAvailable
 from pytorch_lightning.utilities.meta import init_meta_context
 from tests_pytorch.helpers.datamodules import ClassifDataModule
 from tests_pytorch.helpers.datasets import RandomIterableDataset
@@ -47,6 +44,7 @@ if _DEEPSPEED_AVAILABLE:
     import deepspeed
     from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

+    _DEEPSPEED_GREATER_EQUAL_0_5_9 = _RequirementAvailable("deepspeed>=0.5.9")
     if _DEEPSPEED_GREATER_EQUAL_0_5_9:
         from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
     else:

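Since `_DEEPSPEED_AVAILABLE` is now exported from the strategy module, test code can gate DeepSpeed-only tests on it directly, for example with `pytest.mark.skipif` (a usage sketch; the test name below is hypothetical):

```python
import pytest

from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE


@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="requires deepspeed")
def test_deepspeed_module_wrapping():
    ...
```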
View File

@@ -13,10 +13,10 @@
 # limitations under the License.
 import operator

+from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE
 from pytorch_lightning.utilities import (
     _APEX_AVAILABLE,
     _BAGUA_AVAILABLE,
-    _DEEPSPEED_AVAILABLE,
     _FAIRSCALE_AVAILABLE,
     _HOROVOD_AVAILABLE,
     _module_available,