Improved Deepspeed Imports (#13223)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
parent 90d2f3e787
commit 63a9ab4ae2
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Optional, TYPE_CHECKING, Union
 
 from torch import Tensor
 from torch.nn import Module
@@ -22,12 +22,14 @@ from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
 from pytorch_lightning.utilities import GradClipAlgorithmType
 from pytorch_lightning.utilities.enums import PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE, _DEEPSPEED_GREATER_EQUAL_0_6
+from pytorch_lightning.utilities.imports import _RequirementAvailable
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.warnings import WarningCache
 
-if _DEEPSPEED_AVAILABLE:
-    from deepspeed import DeepSpeedEngine
+_DEEPSPEED_GREATER_EQUAL_0_6 = _RequirementAvailable("deepspeed>=0.6.0")
+if TYPE_CHECKING:
+    if pl.strategies.deepspeed._DEEPSPEED_AVAILABLE:
+        import deepspeed
 
 warning_cache = WarningCache()
 
@@ -75,10 +77,12 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin):
                 " the backward logic internally."
             )
         assert model.trainer is not None
-        deepspeed_engine: DeepSpeedEngine = model.trainer.model
+        deepspeed_engine: "deepspeed.DeepSpeedEngine" = model.trainer.model
         deepspeed_engine.backward(closure_loss, *args, **kwargs)
 
-    def _run_backward(self, tensor: Tensor, model: Optional["DeepSpeedEngine"], *args: Any, **kwargs: Any) -> None:
+    def _run_backward(
+        self, tensor: Tensor, model: Optional["deepspeed.DeepSpeedEngine"], *args: Any, **kwargs: Any
+    ) -> None:
         if model is None:
             raise ValueError("Please provide the model as input to `backward`.")
         model.backward(tensor, *args, **kwargs)
@@ -104,7 +108,7 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin):
                 "Skipping backward by returning `None` from your `training_step` is not supported by `DeepSpeed`"
             )
         # DeepSpeed handles the optimizer step internally
-        deepspeed_engine: DeepSpeedEngine
+        deepspeed_engine: "deepspeed.DeepSpeedEngine"
         if isinstance(model, pl.LightningModule):
             assert model.trainer is not None
             deepspeed_engine = model.trainer.model

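The hunks above replace the runtime `from deepspeed import DeepSpeedEngine` with quoted annotations such as `"deepspeed.DeepSpeedEngine"`, so the package only has to be importable for static type checkers. A minimal sketch of that pattern outside Lightning (the `_run_backward` signature mirrors the diff; everything else here is illustrative):

    from typing import TYPE_CHECKING, Any, Optional

    from torch import Tensor

    if TYPE_CHECKING:
        # Evaluated only by static type checkers (mypy, pyright), never at runtime,
        # so importing this module succeeds even when deepspeed is not installed.
        import deepspeed


    def _run_backward(tensor: Tensor, model: Optional["deepspeed.DeepSpeedEngine"], *args: Any, **kwargs: Any) -> None:
        # The quoted annotation stays a plain string at runtime; only the type
        # checker resolves it through the TYPE_CHECKING import above.
        if model is None:
            raise ValueError("Please provide the model as input to `backward`.")
        model.backward(tensor, *args, **kwargs)
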
@@ -43,7 +43,7 @@ from pytorch_lightning.utilities.distributed import (
 )
 from pytorch_lightning.utilities.enums import AMPType, PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
+from pytorch_lightning.utilities.imports import _RequirementAvailable
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.utilities.rank_zero import rank_zero_info
@@ -53,6 +53,7 @@ from pytorch_lightning.utilities.warnings import rank_zero_warn, WarningCache
 
 warning_cache = WarningCache()
 
+_DEEPSPEED_AVAILABLE: bool = _RequirementAvailable("deepspeed")
 if _DEEPSPEED_AVAILABLE:
     import deepspeed
 

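Here `_DEEPSPEED_AVAILABLE` becomes a `_RequirementAvailable("deepspeed")` check defined next to the strategy instead of an eagerly computed boolean in `pytorch_lightning.utilities.imports`. Roughly, such a requirement-based availability flag can be built on `importlib.metadata` and `packaging`; the class below is an illustrative stand-in, not Lightning's actual `_RequirementAvailable` implementation:

    from importlib.metadata import PackageNotFoundError, version

    from packaging.requirements import Requirement
    from packaging.version import Version


    class RequirementAvailable:
        """Truthy when a requirement string such as "deepspeed>=0.6.0" is satisfied."""

        def __init__(self, requirement: str) -> None:
            self.requirement = Requirement(requirement)

        def __bool__(self) -> bool:
            try:
                installed = Version(version(self.requirement.name))
            except PackageNotFoundError:
                return False  # the distribution is not installed at all
            # An empty specifier (e.g. "deepspeed") only requires the package to exist.
            return self.requirement.specifier.contains(installed, prereleases=True)


    DEEPSPEED_AVAILABLE = RequirementAvailable("deepspeed")
    DEEPSPEED_GREATER_EQUAL_0_6 = RequirementAvailable("deepspeed>=0.6.0")

Usage then mirrors the diff: `if DEEPSPEED_AVAILABLE: import deepspeed`.
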
@@ -29,7 +29,6 @@ from pytorch_lightning.utilities.grads import grad_norm  # noqa: F401
 from pytorch_lightning.utilities.imports import (  # noqa: F401
     _APEX_AVAILABLE,
     _BAGUA_AVAILABLE,
-    _DEEPSPEED_AVAILABLE,
     _FAIRSCALE_AVAILABLE,
     _FAIRSCALE_FULLY_SHARDED_AVAILABLE,
     _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE,

@@ -20,7 +20,7 @@ import os
 
 import torch
 
-from pytorch_lightning.utilities import _DEEPSPEED_AVAILABLE
+from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE
 from pytorch_lightning.utilities.types import _PATH
 
 if _DEEPSPEED_AVAILABLE:

@@ -133,9 +133,6 @@ _TORCH_GREATER_EQUAL_1_12 = _compare_version("torch", operator.ge, "1.12.0", use
 _APEX_AVAILABLE = _module_available("apex.amp")
 _BAGUA_AVAILABLE = _package_available("bagua")
 _DALI_AVAILABLE = _module_available("nvidia.dali")
-_DEEPSPEED_AVAILABLE = _package_available("deepspeed")
-_DEEPSPEED_GREATER_EQUAL_0_5_9 = _DEEPSPEED_AVAILABLE and _compare_version("deepspeed", operator.ge, "0.5.9")
-_DEEPSPEED_GREATER_EQUAL_0_6 = _DEEPSPEED_AVAILABLE and _compare_version("deepspeed", operator.ge, "0.6.0")
 _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available("fairscale.nn")
 _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.ge, "0.3.3")
 _FAIRSCALE_FULLY_SHARDED_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.ge, "0.3.4")

@@ -20,10 +20,10 @@ import torch
 from packaging.version import Version
 from pkg_resources import get_distribution
 
+from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE
 from pytorch_lightning.utilities.imports import (
     _APEX_AVAILABLE,
     _BAGUA_AVAILABLE,
-    _DEEPSPEED_AVAILABLE,
     _FAIRSCALE_AVAILABLE,
     _FAIRSCALE_FULLY_SHARDED_AVAILABLE,
     _HIVEMIND_AVAILABLE,

@@ -30,14 +30,11 @@ from pytorch_lightning import LightningDataModule, LightningModule, seed_everyth
 from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
 from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
 from pytorch_lightning.plugins import DeepSpeedPrecisionPlugin
+from pytorch_lightning.plugins.precision.deepspeed import _DEEPSPEED_GREATER_EQUAL_0_6
 from pytorch_lightning.strategies import DeepSpeedStrategy
-from pytorch_lightning.strategies.deepspeed import LightningDeepSpeedModule
+from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE, LightningDeepSpeedModule
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import (
-    _DEEPSPEED_AVAILABLE,
-    _DEEPSPEED_GREATER_EQUAL_0_5_9,
-    _DEEPSPEED_GREATER_EQUAL_0_6,
-)
+from pytorch_lightning.utilities.imports import _RequirementAvailable
 from pytorch_lightning.utilities.meta import init_meta_context
 from tests_pytorch.helpers.datamodules import ClassifDataModule
 from tests_pytorch.helpers.datasets import RandomIterableDataset
@@ -47,6 +44,7 @@ if _DEEPSPEED_AVAILABLE:
     import deepspeed
     from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
 
+    _DEEPSPEED_GREATER_EQUAL_0_5_9 = _RequirementAvailable("deepspeed>=0.5.9")
     if _DEEPSPEED_GREATER_EQUAL_0_5_9:
         from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
     else:

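With `_DEEPSPEED_GREATER_EQUAL_0_5_9` now computed inside the test module itself, such flags are typically consumed as skip conditions. A hedged example of that pattern in plain pytest (the test name is made up for illustration; Lightning's own suite wraps similar checks in its `RunIf` marker):

    import pytest

    from pytorch_lightning.utilities.imports import _RequirementAvailable

    _DEEPSPEED_GREATER_EQUAL_0_5_9 = _RequirementAvailable("deepspeed>=0.5.9")


    @pytest.mark.skipif(not _DEEPSPEED_GREATER_EQUAL_0_5_9, reason="requires deepspeed>=0.5.9")
    def test_zero_optimizer_wrapping():  # hypothetical test name, for illustration only
        # Only runs when the requirement is met, so the import below is safe.
        from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer

        assert DeepSpeedZeroOptimizer is not None
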
@@ -13,10 +13,10 @@
 # limitations under the License.
 import operator
 
+from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE
 from pytorch_lightning.utilities import (
     _APEX_AVAILABLE,
     _BAGUA_AVAILABLE,
-    _DEEPSPEED_AVAILABLE,
     _FAIRSCALE_AVAILABLE,
     _HOROVOD_AVAILABLE,
     _module_available,