diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index 07e01e7674..93a17f10c8 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -43,6 +43,7 @@ if TYPE_CHECKING: from deepspeed import DeepSpeedEngine _DEEPSPEED_AVAILABLE = RequirementCache("deepspeed") +_DEEPSPEED_GREATER_EQUAL_0_14_1 = RequirementCache("deepspeed>=0.14.1") # TODO(fabric): Links in the docstrings to PL-specific deepspeed user docs need to be replaced. @@ -498,7 +499,10 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded): ) engine = engines[0] - from deepspeed.runtime import DeepSpeedOptimizer + if _DEEPSPEED_GREATER_EQUAL_0_14_1: + from deepspeed.runtime.base_optimizer import DeepSpeedOptimizer + else: + from deepspeed.runtime import DeepSpeedOptimizer optimzer_state_requested = any(isinstance(item, (Optimizer, DeepSpeedOptimizer)) for item in state.values())