[BUG] `estimated_stepping_batches` requires distributed comms in `configure_optimizers` for `DeepSpeedStrategy` (#13350)
This commit is contained in:
parent
749709fb4f
commit
89e2e69b01
|
@ -238,6 +238,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed issue where the CLI fails with certain torch objects ([#13153](https://github.com/PyTorchLightning/pytorch-lightning/pull/13153))
|
||||
|
||||
|
||||
- Fixed `estimated_stepping_batches` requiring distributed comms in `configure_optimizers` for the `DeepSpeedStrategy` ([#13350](https://github.com/PyTorchLightning/pytorch-lightning/pull/13350))
|
||||
|
||||
|
||||
-
|
||||
|
||||
|
||||
|
|
|
@ -357,6 +357,8 @@ class DeepSpeedStrategy(DDPStrategy):
|
|||
|
||||
def setup(self, trainer: "pl.Trainer") -> None:
|
||||
self.accelerator.setup(trainer)
|
||||
# we set the device so that optimizers can be created with distributed comms.
|
||||
self.lightning_module._device = self.root_device
|
||||
self.setup_optimizers(trainer)
|
||||
self.setup_precision_plugin()
|
||||
optimizers_to_device(self.optimizers, self.root_device)
|
||||
|
|
|
@ -1337,3 +1337,26 @@ def test_error_with_invalid_accelerator(tmpdir):
|
|||
model = BoringModel()
|
||||
with pytest.raises(MisconfigurationException, match="DeepSpeed strategy is only supported on GPU"):
|
||||
trainer.fit(model)
|
||||
|
||||
|
||||
@RunIf(min_cuda_gpus=2, deepspeed=True, standalone=True)
|
||||
def test_deepspeed_configure_optimizer_device_set(tmpdir):
|
||||
"""Test to ensure that the LM has access to the device within the ``configure_optimizer`` function, and
|
||||
estimated_stepping_batches works correctly as a result."""
|
||||
|
||||
class TestModel(BoringModel):
|
||||
def configure_optimizers(self):
|
||||
assert self.trainer.estimated_stepping_batches == 1
|
||||
assert self.device.type == "cuda"
|
||||
raise SystemExit
|
||||
|
||||
model = TestModel()
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
fast_dev_run=True,
|
||||
accelerator="gpu",
|
||||
devices=2,
|
||||
strategy=DeepSpeedStrategy(),
|
||||
)
|
||||
with pytest.raises(SystemExit):
|
||||
trainer.fit(model)
|
||||
|
|
Loading…
Reference in New Issue