From a8bd7ac73f3e4afc0d53c2aad1b295ad4942201c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 5 Jan 2022 09:38:08 +0100 Subject: [PATCH] Fix lr scheduler state not being dumped to checkpoint in deepspeed strategy (#11307) --- CHANGELOG.md | 3 +++ pytorch_lightning/strategies/deepspeed.py | 2 +- tests/strategies/test_deepspeed_strategy.py | 4 ++++ 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 19111ecb3d..c7950d2c57 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -403,6 +403,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a race condition that could result in incorrect (zero) values being observed in prediction writer callbacks ([#11288](https://github.com/PyTorchLightning/pytorch-lightning/pull/11288)) +- Fixed the lr-scheduler state not being dumped to checkpoint when using the deepspeed strategy ([#11307](https://github.com/PyTorchLightning/pytorch-lightning/pull/11307)) + + ## [1.5.7] - 2021-12-21 ### Fixed diff --git a/pytorch_lightning/strategies/deepspeed.py b/pytorch_lightning/strategies/deepspeed.py index 452f3c8e1a..4cd108a998 100644 --- a/pytorch_lightning/strategies/deepspeed.py +++ b/pytorch_lightning/strategies/deepspeed.py @@ -741,7 +741,7 @@ class DeepSpeedStrategy(DDPStrategy): ) # Use deepspeed's internal checkpointing function to handle partitioned weights across processes # dump states as a checkpoint dictionary object - _exclude_keys = ["state_dict", "optimizer_states", "lr_schedulers"] + _exclude_keys = ["state_dict", "optimizer_states"] checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys} self.deepspeed_engine.save_checkpoint(filepath, client_state=checkpoint) diff --git a/tests/strategies/test_deepspeed_strategy.py b/tests/strategies/test_deepspeed_strategy.py index fed7f29de6..251a2b3c83 100644 --- a/tests/strategies/test_deepspeed_strategy.py +++ b/tests/strategies/test_deepspeed_strategy.py @@ -558,6 +558,10 @@ class ModelParallelClassificationModel(LightningModule): if not hasattr(self, "model"): self.configure_sharded_model() + # Lightning saves the lr schedulers, but DeepSpeed saves the optimizer states separately + assert len(checkpoint["lr_schedulers"]) == 1 + assert "optimizer_states" not in checkpoint + class ManualModelParallelClassificationModel(ModelParallelClassificationModel): @property