Support DDP with LRFinder (#15304)

* Support DDP for LRFinder
* Apply suggestions from code review
* rank 0 is the decision maker

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

parent 750f62f6c3
commit 1a8f2e8516
@@ -270,7 +270,13 @@ initial learning rate.

 .. warning::

     For the moment, this feature only works with models having a single optimizer.
-    LR Finder support for DDP and any of its variations is not implemented yet. It is coming soon.
+
+.. note::
+
+    With DDP: Since all the processes run in isolation, only the process with ``global_rank=0`` will make the decision to stop the
+    learning rate finder and broadcast its results to all other ranks. That means that, at the end of the LR finder, each process will be running with
+    the learning rate found on ``global_rank=0``.

 Using Lightning's built-in LR finder
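
Conceptually, the broadcast referenced in the note makes every rank adopt rank 0's copy of a value. A minimal sketch of that pattern with plain ``torch.distributed`` (the helper name is ours, not Lightning's implementation):

```python
import torch.distributed as dist


def broadcast_from_rank_zero(value, src: int = 0):
    """Illustrative helper: every rank returns rank ``src``'s copy of ``value``."""
    if not (dist.is_available() and dist.is_initialized()):
        return value  # single-process fallback
    container = [value]  # broadcast_object_list mutates the list in place
    dist.broadcast_object_list(container, src=src)
    return container[0]
```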
@@ -10,6 +10,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

+- Added support for DDP with `LRFinder` ([#15304](https://github.com/Lightning-AI/lightning/pull/15304))
+
 - Added utilities to migrate checkpoints from one Lightning version to another ([#15237](https://github.com/Lightning-AI/lightning/pull/15237))

 - Added back the accidentally removed `pytorch_lightning.utilities.distributed.rank_zero_only` function ([#15536](https://github.com/Lightning-AI/lightning/pull/15536))
@@ -217,6 +217,7 @@ def lr_find(

     # Save initial model, that is loaded after learning rate is found
     ckpt_path = os.path.join(trainer.default_root_dir, f".lr_find_{uuid.uuid4()}.ckpt")
+    ckpt_path = trainer.strategy.broadcast(ckpt_path)
     trainer.save_checkpoint(ckpt_path)

     # Arguments we adjust during the lr finder, save for restoring
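
The one added line here matters because ``uuid.uuid4()`` is evaluated independently on each process; without the broadcast, every rank would save and later try to restore a differently named checkpoint. A small self-contained illustration of the mismatch it prevents (the path root is arbitrary):

```python
import os
import uuid


def make_ckpt_path(root: str) -> str:
    # a fresh random name on every call, so two DDP ranks calling this
    # independently would disagree about where the checkpoint lives
    return os.path.join(root, f".lr_find_{uuid.uuid4()}.ckpt")


a, b = make_ckpt_path("/tmp"), make_ckpt_path("/tmp")
assert a != b  # hence the broadcast: all ranks adopt rank 0's path
```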
@@ -252,6 +253,7 @@ def lr_find(
         trainer.progress_bar_callback.enable()

     # Update lr attr if required
+    lr_finder.results = trainer.strategy.broadcast(lr_finder.results)
     if update_attr:
         lr = lr_finder.suggestion()
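
Broadcasting ``lr_finder.results`` before the suggestion step guarantees every rank derives its new learning rate from identical data. A hypothetical stand-in for ``suggestion()`` makes the invariant concrete: any deterministic function of the now rank-identical results returns the same value on every process (this is not the real algorithm, which examines the gradient of the smoothed loss curve):

```python
from typing import Dict, List


def pick_steepest_lr(results: Dict[str, List[float]]) -> float:
    # hypothetical stand-in for LRFinder.suggestion(): deterministic,
    # so identical inputs on all ranks yield identical outputs
    lrs, losses = results["lr"], results["loss"]
    drops = [losses[i] - losses[i + 1] for i in range(len(losses) - 1)]
    return lrs[drops.index(max(drops))]


print(pick_steepest_lr({"lr": [1e-5, 1e-4, 1e-3], "loss": [1.0, 0.6, 0.9]}))  # 1e-05
```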
@@ -311,6 +313,7 @@ def __lr_finder_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) ->
     loop = trainer.fit_loop
     loop.load_state_dict(deepcopy(params["loop_state_dict"]))
     loop.restarting = False
+    trainer.should_stop = False


 class _LRCallback(Callback):
@@ -380,10 +383,12 @@ class _LRCallback(Callback):
         # Check if we diverging
         if self.early_stop_threshold is not None:
             if current_step > 1 and smoothed_loss > self.early_stop_threshold * self.best_loss:
                 trainer.fit_loop.max_steps = current_step  # stop signal
+                trainer.should_stop = True  # stop signal
                 if self.progress_bar:
                     self.progress_bar.close()

+        trainer.should_stop = trainer.strategy.broadcast(trainer.should_stop)
+
         # Save best loss for diverging checking
         if smoothed_loss < self.best_loss or current_step == 1:
             self.best_loss = smoothed_loss
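
The two added lines close a classic DDP hang: each rank smooths its own local loss, so the divergence flags can disagree, and a rank that stops alone would leave the others blocked in the next collective call. Broadcasting rank 0's decision keeps all processes in lockstep. A self-contained sketch of that synchronization (the function name is ours):

```python
import torch
import torch.distributed as dist


def adopt_rank_zero_decision(local_should_stop: bool) -> bool:
    # Without this sync, one rank could leave the training loop while the
    # others block forever in a collective op it will never join.
    if not (dist.is_available() and dist.is_initialized()):
        return local_should_stop
    flag = torch.tensor([int(local_should_stop)])
    dist.broadcast(flag, src=0)  # everyone adopts rank 0's flag
    return bool(flag.item())
```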
@@ -59,7 +59,7 @@ class Tuner:

         self.trainer.strategy.connect(model)

-        is_tuning = self.trainer.auto_scale_batch_size or self.trainer.auto_lr_find
+        is_tuning = self.trainer.auto_scale_batch_size
         if self.trainer._accelerator_connector.is_distributed and is_tuning:
             raise MisconfigurationException(
                 "`trainer.tune()` is currently not supported with"
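
With `auto_lr_find` dropped from the guard, only automatic batch-size scaling still refuses distributed strategies, so `trainer.tune()` with DDP can now reach the LR finder. A hedged usage sketch (model and datamodule definitions omitted):

```python
import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator="cpu",
    devices=2,
    strategy="ddp",
    auto_lr_find=True,  # no longer rejected by the distributed guard
)
# trainer.tune(model, datamodule=dm)  # runs lr_find on every rank; rank 0 decides
```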
@@ -93,6 +93,7 @@ def test_trainer_reset_correctly(tmpdir):
         "max_steps",
         "fit_loop.max_steps",
         "strategy.setup_optimizers",
+        "should_stop",
     ]
     expected = {ca: getattr_recursive(trainer, ca) for ca in changed_attributes}
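
The dotted entries such as `"fit_loop.max_steps"` imply that `getattr_recursive` resolves nested attribute paths; a plausible reconstruction of that test helper (hypothetical, the real one lives in the test suite):

```python
from typing import Any


def getattr_recursive(obj: Any, attr_path: str) -> Any:
    # walk dotted paths such as "fit_loop.max_steps" one attribute at a time
    for name in attr_path.split("."):
        obj = getattr(obj, name)
    return obj
```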
@@ -438,3 +439,27 @@ def test_if_lr_finder_callback_already_configured():

     with pytest.raises(MisconfigurationException, match="Trainer is already configured with a .* callback"):
         trainer.tune(model)
+
+
+@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
+@RunIf(standalone=True)
+def test_lr_finder_with_ddp(tmpdir):
+    seed_everything(7)
+
+    init_lr = 1e-4
+    dm = ClassifDataModule()
+    model = ClassificationModel(lr=init_lr)
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        strategy="ddp",
+        devices=2,
+        accelerator="cpu",
+    )
+
+    trainer.tuner.lr_find(model, datamodule=dm, update_attr=True, num_training=20)
+    lr = trainer.lightning_module.lr
+    lr = trainer.strategy.broadcast(lr)
+    assert trainer.lightning_module.lr == lr
+    assert lr != init_lr
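
Because it is marked `@RunIf(standalone=True)`, this test only runs in Lightning's standalone multi-process test pass. End to end, it exercises the workflow a user would follow; a condensed, hedged sketch of that same flow outside the test harness (any LightningModule with an `lr` attribute and a matching datamodule would do):

```python
import pytorch_lightning as pl


def tune_lr_under_ddp(model: pl.LightningModule, dm: pl.LightningDataModule) -> float:
    trainer = pl.Trainer(accelerator="cpu", devices=2, strategy="ddp", max_epochs=1)
    lr_finder = trainer.tuner.lr_find(model, datamodule=dm, update_attr=True)
    # after the internal broadcasts, results are identical on every rank,
    # so suggestion() returns the same learning rate everywhere
    return lr_finder.suggestion()
```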