Support DDP with LRFinder (#15304)

* Support DDP for LRFinder
* Apply suggestions from code review
* rank 0 is the decision maker

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Rohit Gupta 2022-11-08 23:25:15 +05:30 committed by GitHub
parent 750f62f6c3
commit 1a8f2e8516
5 changed files with 42 additions and 3 deletions


@@ -270,7 +270,13 @@ initial learning rate.

 .. warning::

     For the moment, this feature only works with models having a single optimizer.
-    LR Finder support for DDP and any of its variations is not implemented yet. It is coming soon.
+
+.. note::
+
+    With DDP: since all processes run in isolation, only the process with ``global_rank=0`` will make the decision to stop the
+    learning rate finder and broadcast its results to all other ranks. That means, at the end of the LR finder, each process will be
+    running with the learning rate found on ``global_rank=0``.

 Using Lightning's built-in LR finder

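In user code the synchronisation is transparent. A minimal sketch of running the LR finder under DDP, where `MyModel` and `MyDataModule` are placeholder names for your own `LightningModule` and `LightningDataModule` (not part of this diff):

```python
from pytorch_lightning import Trainer

# MyModel / MyDataModule are placeholders for your own classes.
trainer = Trainer(strategy="ddp", devices=2, accelerator="cpu", max_epochs=1)
lr_finder = trainer.tuner.lr_find(MyModel(), datamodule=MyDataModule())

# Every process receives rank 0's results, so the suggestion is
# identical on all ranks.
new_lr = lr_finder.suggestion()
```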

@@ -10,6 +10,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

+- Added support for DDP with `LRFinder` ([#15304](https://github.com/Lightning-AI/lightning/pull/15304))
+
 - Added utilities to migrate checkpoints from one Lightning version to another ([#15237](https://github.com/Lightning-AI/lightning/pull/15237))
 - Added back the accidentally removed `pytorch_lightning.utilities.distributed.rank_zero_only` function ([#15536](https://github.com/Lightning-AI/lightning/pull/15536))


@@ -217,6 +217,7 @@ def lr_find(

     # Save initial model, that is loaded after learning rate is found
     ckpt_path = os.path.join(trainer.default_root_dir, f".lr_find_{uuid.uuid4()}.ckpt")
+    ckpt_path = trainer.strategy.broadcast(ckpt_path)
     trainer.save_checkpoint(ckpt_path)

     # Arguments we adjust during the lr finder, save for restoring
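The new broadcast matters because `uuid.uuid4()` is evaluated independently on every process: without it, each rank would save, and later try to restore, a differently named checkpoint. `Strategy.broadcast` defaults to source rank 0, so afterwards all ranks hold rank 0's path. The same idea in self-contained form with raw `torch.distributed` (the helper name is illustrative; assumes an already-initialised process group):

```python
import os
import uuid

import torch.distributed as dist


def shared_ckpt_path(root_dir: str) -> str:
    # Each rank rolls its own uuid, so the raw paths diverge across processes.
    path = [os.path.join(root_dir, f".lr_find_{uuid.uuid4()}.ckpt")]
    # broadcast_object_list overwrites the list on every non-source rank with
    # rank 0's object, so all processes save to and restore from the same file.
    dist.broadcast_object_list(path, src=0)
    return path[0]
```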
@@ -252,6 +253,7 @@ def lr_find(
         trainer.progress_bar_callback.enable()

     # Update lr attr if required
+    lr_finder.results = trainer.strategy.broadcast(lr_finder.results)
     if update_attr:
         lr = lr_finder.suggestion()
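Broadcasting `lr_finder.results` before calling `suggestion()` is what makes `update_attr` write the same learning rate on every rank: each process computes the suggestion from rank 0's loss curve. For intuition, the suggestion heuristic is roughly the following (a simplified sketch, not the exact library code):

```python
import numpy as np


def suggest_lr(lrs: list, losses: list, skip_begin: int = 10, skip_end: int = 1) -> float:
    # Pick the lr at the steepest downward slope of the (smoothed) loss curve.
    losses_arr = np.array(losses[skip_begin:-skip_end])
    lrs_arr = np.array(lrs[skip_begin:-skip_end])
    return float(lrs_arr[np.gradient(losses_arr).argmin()])
```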
@@ -311,6 +313,7 @@ def __lr_finder_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) -> None:
     loop = trainer.fit_loop
     loop.load_state_dict(deepcopy(params["loop_state_dict"]))
     loop.restarting = False
+    trainer.should_stop = False


 class _LRCallback(Callback):
@@ -380,10 +383,12 @@ class _LRCallback(Callback):
             # Check if we diverging
             if self.early_stop_threshold is not None:
                 if current_step > 1 and smoothed_loss > self.early_stop_threshold * self.best_loss:
-                    trainer.fit_loop.max_steps = current_step  # stop signal
+                    trainer.should_stop = True  # stop signal
                     if self.progress_bar:
                         self.progress_bar.close()

+        trainer.should_stop = trainer.strategy.broadcast(trainer.should_stop)
+
         # Save best loss for diverging checking
         if smoothed_loss < self.best_loss or current_step == 1:
             self.best_loss = smoothed_loss
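The switch from capping `fit_loop.max_steps` to a broadcast `should_stop` flag is what keeps DDP from deadlocking: the divergence check runs on each rank's local smoothed loss, so ranks could disagree on when to stop, and a rank that stops alone leaves the others hanging in their next collective. Broadcasting after every batch makes rank 0's verdict authoritative. The bare pattern with raw `torch.distributed` (a sketch, assuming a gloo/CPU process group):

```python
import torch
import torch.distributed as dist


def sync_should_stop(local_should_stop: bool) -> bool:
    # Encode the local flag as a tensor and overwrite it with rank 0's value,
    # so every rank leaves the training loop on the same step.
    flag = torch.tensor([1.0 if local_should_stop else 0.0])
    dist.broadcast(flag, src=0)
    return bool(flag.item())
```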


@@ -59,7 +59,7 @@ class Tuner:
         self.trainer.strategy.connect(model)

-        is_tuning = self.trainer.auto_scale_batch_size or self.trainer.auto_lr_find
+        is_tuning = self.trainer.auto_scale_batch_size
         if self.trainer._accelerator_connector.is_distributed and is_tuning:
             raise MisconfigurationException(
                 "`trainer.tune()` is currently not supported with"


@@ -93,6 +93,7 @@ def test_trainer_reset_correctly(tmpdir):
         "max_steps",
         "fit_loop.max_steps",
         "strategy.setup_optimizers",
+        "should_stop",
     ]

     expected = {ca: getattr_recursive(trainer, ca) for ca in changed_attributes}
@@ -438,3 +439,27 @@ def test_if_lr_finder_callback_already_configured():
     with pytest.raises(MisconfigurationException, match="Trainer is already configured with a .* callback"):
         trainer.tune(model)
+
+
+@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
+@RunIf(standalone=True)
+def test_lr_finder_with_ddp(tmpdir):
+    seed_everything(7)
+    init_lr = 1e-4
+    dm = ClassifDataModule()
+    model = ClassificationModel(lr=init_lr)
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        strategy="ddp",
+        devices=2,
+        accelerator="cpu",
+    )
+
+    trainer.tuner.lr_find(model, datamodule=dm, update_attr=True, num_training=20)
+
+    lr = trainer.lightning_module.lr
+    lr = trainer.strategy.broadcast(lr)
+    assert trainer.lightning_module.lr == lr
+    assert lr != init_lr
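A remark on the assertions: broadcasting `lr` from rank 0 and comparing it with the local attribute is how the test proves that every process ended up with rank 0's learning rate. An equivalent check could gather the values from all ranks instead (a sketch, assuming an already-initialised process group):

```python
import torch.distributed as dist


def all_ranks_agree(value: float) -> bool:
    # Collect every rank's value and verify they are all identical.
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, value)
    return len(set(gathered)) == 1
```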