[BUG] `estimated_stepping_batches` requires distributed comms in `configure_optimizers` for `DeepSpeedStrategy` (#13350)

Sean Naren 2022-06-21 17:48:27 +01:00 committed by GitHub
parent 749709fb4f
commit 89e2e69b01
3 changed files with 28 additions and 0 deletions


@@ -238,6 +238,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed issue where the CLI fails with certain torch objects ([#13153](https://github.com/PyTorchLightning/pytorch-lightning/pull/13153))
- Fixed `estimated_stepping_batches` requiring distributed comms in `configure_optimizers` for the `DeepSpeedStrategy` ([#13350](https://github.com/PyTorchLightning/pytorch-lightning/pull/13350))


@@ -357,6 +357,8 @@ class DeepSpeedStrategy(DDPStrategy):
    def setup(self, trainer: "pl.Trainer") -> None:
        self.accelerator.setup(trainer)
        # we set the device so that optimizers can be created with distributed comms.
        self.lightning_module._device = self.root_device
        self.setup_optimizers(trainer)
        self.setup_precision_plugin()
        optimizers_to_device(self.optimizers, self.root_device)
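
For context on why the device is set before `setup_optimizers`: `estimated_stepping_batches` can trigger distributed communication to count batches across ranks, which requires the module's device to be initialized first. Below is a minimal sketch of the user pattern this fix enables; the module and the OneCycleLR scheduler are illustrative assumptions, not part of the commit:

import torch
import pytorch_lightning as pl


class IllustrativeModel(pl.LightningModule):
    """Hypothetical module that queries ``estimated_stepping_batches`` in ``configure_optimizers``."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        return self(batch).sum()

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        # Before this fix, querying estimated_stepping_batches here could fail
        # under DeepSpeed because the device was not yet set for distributed comms.
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=0.1, total_steps=self.trainer.estimated_stepping_batches
        )
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]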


@@ -1337,3 +1337,26 @@ def test_error_with_invalid_accelerator(tmpdir):
    model = BoringModel()
    with pytest.raises(MisconfigurationException, match="DeepSpeed strategy is only supported on GPU"):
        trainer.fit(model)


@RunIf(min_cuda_gpus=2, deepspeed=True, standalone=True)
def test_deepspeed_configure_optimizer_device_set(tmpdir):
    """Test to ensure that the LM has access to the device within the ``configure_optimizer`` function, and
    estimated_stepping_batches works correctly as a result."""

    class TestModel(BoringModel):
        def configure_optimizers(self):
            assert self.trainer.estimated_stepping_batches == 1
            assert self.device.type == "cuda"
            raise SystemExit

    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
        accelerator="gpu",
        devices=2,
        strategy=DeepSpeedStrategy(),
    )
    with pytest.raises(SystemExit):
        trainer.fit(model)