Fix lr scheduler state not being dumped to checkpoint in deepspeed strategy (#11307)

This commit is contained in:
Adrian Wälchli 2022-01-05 09:38:08 +01:00 committed by GitHub
parent 7eab379da2
commit a8bd7ac73f
3 changed files with 8 additions and 1 deletion


@@ -403,6 +403,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed a race condition that could result in incorrect (zero) values being observed in prediction writer callbacks ([#11288](https://github.com/PyTorchLightning/pytorch-lightning/pull/11288))
+- Fixed the lr-scheduler state not being dumped to checkpoint when using the deepspeed strategy ([#11307](https://github.com/PyTorchLightning/pytorch-lightning/pull/11307))
 ## [1.5.7] - 2021-12-21
 ### Fixed


@@ -741,7 +741,7 @@ class DeepSpeedStrategy(DDPStrategy):
             )
         # Use deepspeed's internal checkpointing function to handle partitioned weights across processes
         # dump states as a checkpoint dictionary object
-        _exclude_keys = ["state_dict", "optimizer_states", "lr_schedulers"]
+        _exclude_keys = ["state_dict", "optimizer_states"]
         checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
         self.deepspeed_engine.save_checkpoint(filepath, client_state=checkpoint)
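
A minimal, self-contained sketch of what this one-line change does (the checkpoint contents below are illustrative placeholders, not the real payload): only the keys that the DeepSpeed engine persists itself are filtered out of the dict passed as `client_state`, so the `lr_schedulers` entry is now retained.

```python
# Illustrative sketch only -- not the actual Lightning code path.
checkpoint = {
    "epoch": 3,
    "state_dict": {"layer.weight": "saved by the DeepSpeed engine"},
    "optimizer_states": ["saved by the DeepSpeed engine"],
    "lr_schedulers": [{"last_epoch": 3}],
}

# Before the fix: the scheduler state was excluded and therefore lost.
old_exclude = ["state_dict", "optimizer_states", "lr_schedulers"]
old_client_state = {k: v for k, v in checkpoint.items() if k not in old_exclude}
assert "lr_schedulers" not in old_client_state

# After the fix: only keys DeepSpeed handles itself are excluded.
new_exclude = ["state_dict", "optimizer_states"]
new_client_state = {k: v for k, v in checkpoint.items() if k not in new_exclude}
assert new_client_state["lr_schedulers"] == [{"last_epoch": 3}]

# The strategy then hands this dict to DeepSpeed, roughly:
#   self.deepspeed_engine.save_checkpoint(filepath, client_state=new_client_state)
```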


@@ -558,6 +558,10 @@ class ModelParallelClassificationModel(LightningModule):
         if not hasattr(self, "model"):
             self.configure_sharded_model()
+        # Lightning saves the lr schedulers, but DeepSpeed saves the optimizer states separately
+        assert len(checkpoint["lr_schedulers"]) == 1
+        assert "optimizer_states" not in checkpoint
 class ManualModelParallelClassificationModel(ModelParallelClassificationModel):
     @property
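
For reference, the assertions above run inside the test model's `on_load_checkpoint` hook. A sketch of the same check in a user-defined module might look like this (the class name, layer sizes, and optimizer setup are made up for illustration; the assertions only hold for checkpoints written with the DeepSpeed strategy after this fix):

```python
from typing import Any, Dict

import torch
from pytorch_lightning import LightningModule


class BoringDeepSpeedModel(LightningModule):
    # Hypothetical module, only to illustrate the checkpoint layout checked by the test.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self(x), y)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [scheduler]

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # With this fix, the lr scheduler state travels in DeepSpeed's client_state,
        # while optimizer states are saved and restored by the DeepSpeed engine itself.
        assert len(checkpoint["lr_schedulers"]) == 1
        assert "optimizer_states" not in checkpoint
```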