Update DeepSpeedOptimizer import for deepspeed >= 0.14.1 (#20040)
This commit is contained in:
parent
a6562b4ae7
commit
9e6c02a98a
|
@ -43,6 +43,7 @@ if TYPE_CHECKING:
|
|||
from deepspeed import DeepSpeedEngine
|
||||
|
||||
_DEEPSPEED_AVAILABLE = RequirementCache("deepspeed")
|
||||
_DEEPSPEED_GREATER_EQUAL_0_14_1 = RequirementCache("deepspeed>=0.14.1")
|
||||
|
||||
|
||||
# TODO(fabric): Links in the docstrings to PL-specific deepspeed user docs need to be replaced.
|
||||
|
@ -498,7 +499,10 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
|
|||
)
|
||||
engine = engines[0]
|
||||
|
||||
from deepspeed.runtime import DeepSpeedOptimizer
|
||||
if _DEEPSPEED_GREATER_EQUAL_0_14_1:
|
||||
from deepspeed.runtime.base_optimizer import DeepSpeedOptimizer
|
||||
else:
|
||||
from deepspeed.runtime import DeepSpeedOptimizer
|
||||
|
||||
optimzer_state_requested = any(isinstance(item, (Optimizer, DeepSpeedOptimizer)) for item in state.values())
|
||||
|
||||
|
|
Loading…
Reference in New Issue