Update DeepSpeedOptimizer import for deepspeed >= 0.14.1 (#20040)

This commit is contained in:
Leon Lin 2024-07-08 15:45:17 +08:00 committed by GitHub
parent a6562b4ae7
commit 9e6c02a98a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 5 additions and 1 deletions

View File

@ -43,6 +43,7 @@ if TYPE_CHECKING:
from deepspeed import DeepSpeedEngine
_DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
_DEEPSPEED_GREATER_EQUAL_0_14_1 = RequirementCache("deepspeed>=0.14.1")
# TODO(fabric): Links in the docstrings to PL-specific deepspeed user docs need to be replaced.
@ -498,7 +499,10 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
)
engine = engines[0]
from deepspeed.runtime import DeepSpeedOptimizer
if _DEEPSPEED_GREATER_EQUAL_0_14_1:
from deepspeed.runtime.base_optimizer import DeepSpeedOptimizer
else:
from deepspeed.runtime import DeepSpeedOptimizer
optimzer_state_requested = any(isinstance(item, (Optimizer, DeepSpeedOptimizer)) for item in state.values())