From 9e6c02a98a867d13005d6cd8841dd518ca2b6096 Mon Sep 17 00:00:00 2001 From: Leon Lin Date: Mon, 8 Jul 2024 15:45:17 +0800 Subject: [PATCH] Update DeepSpeedOptimizer import for deepspeed >= 0.14.1 (#20040) --- src/lightning/fabric/strategies/deepspeed.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index 07e01e7674..93a17f10c8 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -43,6 +43,7 @@ if TYPE_CHECKING: from deepspeed import DeepSpeedEngine _DEEPSPEED_AVAILABLE = RequirementCache("deepspeed") +_DEEPSPEED_GREATER_EQUAL_0_14_1 = RequirementCache("deepspeed>=0.14.1") # TODO(fabric): Links in the docstrings to PL-specific deepspeed user docs need to be replaced. @@ -498,7 +499,10 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded): ) engine = engines[0] - from deepspeed.runtime import DeepSpeedOptimizer + if _DEEPSPEED_GREATER_EQUAL_0_14_1: + from deepspeed.runtime.base_optimizer import DeepSpeedOptimizer + else: + from deepspeed.runtime import DeepSpeedOptimizer optimzer_state_requested = any(isinstance(item, (Optimizer, DeepSpeedOptimizer)) for item in state.values())