Fix lr scheduler state not being dumped to checkpoint in deepspeed strategy (#11307)
parent 7eab379da2
commit a8bd7ac73f
@@ -403,6 +403,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed a race condition that could result in incorrect (zero) values being observed in prediction writer callbacks ([#11288](https://github.com/PyTorchLightning/pytorch-lightning/pull/11288))
 
+- Fixed the lr-scheduler state not being dumped to checkpoint when using the deepspeed strategy ([#11307](https://github.com/PyTorchLightning/pytorch-lightning/pull/11307))
+
 
 ## [1.5.7] - 2021-12-21
 
 ### Fixed
@@ -741,7 +741,7 @@ class DeepSpeedStrategy(DDPStrategy):
         )
         # Use deepspeed's internal checkpointing function to handle partitioned weights across processes
         # dump states as a checkpoint dictionary object
-        _exclude_keys = ["state_dict", "optimizer_states", "lr_schedulers"]
+        _exclude_keys = ["state_dict", "optimizer_states"]
         checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
         self.deepspeed_engine.save_checkpoint(filepath, client_state=checkpoint)
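In plain terms, the one-line change above stops filtering `lr_schedulers` out of the `client_state` dict handed to DeepSpeed's `save_checkpoint`; `state_dict` and `optimizer_states` stay excluded because DeepSpeed persists the partitioned weights and optimizer shards itself. A minimal sketch of the filtering with dummy values (illustrative only, not the strategy's real checkpoint contents):

```python
# Sketch of the post-fix filtering; the dict values are made-up placeholders.
checkpoint = {
    "state_dict": {"layer.weight": "..."},  # DeepSpeed saves partitioned weights itself
    "optimizer_states": [{"step": 100}],    # DeepSpeed saves optimizer shards itself
    "lr_schedulers": [{"last_epoch": 10}],  # previously dropped, now kept
    "epoch": 3,
}
_exclude_keys = ["state_dict", "optimizer_states"]  # "lr_schedulers" no longer excluded
client_state = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
assert "lr_schedulers" in client_state  # scheduler state now reaches the checkpoint
```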
@@ -558,6 +558,10 @@ class ModelParallelClassificationModel(LightningModule):
         if not hasattr(self, "model"):
             self.configure_sharded_model()
 
+        # Lightning saves the lr schedulers, but DeepSpeed saves the optimizer states separately
+        assert len(checkpoint["lr_schedulers"]) == 1
+        assert "optimizer_states" not in checkpoint
+
 
 class ManualModelParallelClassificationModel(ModelParallelClassificationModel):
     @property
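The assertions in the test hunk receive Lightning's checkpoint dict, most likely inside the model's `on_load_checkpoint` hook (an assumption; the enclosing method signature is not visible in this excerpt). A user module could make the same end-to-end check after the fix, assuming a DeepSpeed-backed `Trainer`:

```python
import pytorch_lightning as pl


class SchedulerStateCheck(pl.LightningModule):
    """Toy module verifying that scheduler state survives a DeepSpeed checkpoint."""

    def on_load_checkpoint(self, checkpoint: dict) -> None:
        # With the fix, the restored client_state keeps the scheduler entries,
        # while optimizer states live in DeepSpeed's own shard files.
        assert "lr_schedulers" in checkpoint
        assert "optimizer_states" not in checkpoint
```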