fix progress bar restart with fault-tolerant training enabled (#9310)

* reset progress updates
* update docs
* add test

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Adrian Wälchli 2021-09-06 10:43:59 +02:00 committed by GitHub
parent f9132e8db6
commit 50198d7483
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 41 additions and 7 deletions

View File

@ -297,6 +297,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `move_metrics_to_cpu` moving the loss on cpu while training on device ([#9308](https://github.com/PyTorchLightning/pytorch-lightning/pull/9308))
- Fixed incorrect main progress bar indicator when resuming training mid-epoch ([#9310](https://github.com/PyTorchLightning/pytorch-lightning/pull/9310))

## [1.4.5] - 2021-08-31
- Fixed reduction using `self.log(sync_dict=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142))

View File

@ -154,10 +154,10 @@ class ProgressBarBase(Callback):
self._trainer = trainer
def on_train_start(self, trainer, pl_module):
self._train_batch_idx = trainer.fit_loop.batch_idx
self._train_batch_idx = 0
def on_train_epoch_start(self, trainer, pl_module):
self._train_batch_idx = 0
self._train_batch_idx = trainer.fit_loop.epoch_loop.batch_progress.current.completed
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
self._train_batch_idx += 1

View File

@ -229,7 +229,7 @@ class ProgressBar(ProgressBarBase):
val_checks_per_epoch = total_train_batches // trainer.val_check_batch
total_val_batches = total_val_batches * val_checks_per_epoch
total_batches = total_train_batches + total_val_batches
reset(self.main_progress_bar, total_batches)
reset(self.main_progress_bar, total=total_batches, current=self.train_batch_idx)
self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}")
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
@ -243,11 +243,11 @@ class ProgressBar(ProgressBarBase):
def on_validation_start(self, trainer, pl_module):
super().on_validation_start(trainer, pl_module)
if trainer.sanity_checking:
reset(self.val_progress_bar, sum(trainer.num_sanity_val_batches))
reset(self.val_progress_bar, total=sum(trainer.num_sanity_val_batches), current=self.val_batch_idx)
else:
self._update_bar(self.main_progress_bar) # fill up remaining
self.val_progress_bar = self.init_validation_tqdm()
reset(self.val_progress_bar, self.total_val_batches)
reset(self.val_progress_bar, total=self.total_val_batches, current=self.val_batch_idx)
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
@ -333,7 +333,8 @@ def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
return x
def reset(bar: tqdm, total: Optional[int] = None) -> None:
"""Resets the tqdm bar to 0 progress with a new total, unless it is disabled."""
def reset(bar: tqdm, total: Optional[int] = None, current: int = 0) -> None:
    """Rewind ``bar`` to the given ``current`` position under a (possibly new) ``total``.

    Does nothing when the bar is disabled.
    """
    if bar.disable:
        return
    # tqdm's own reset() zeroes the counter and installs the new total;
    # setting ``n`` afterwards moves the bar to the resumed position.
    bar.reset(total=convert_inf(total))
    bar.n = current

View File

@ -558,3 +558,33 @@ def _test_progress_bar_max_val_check_interval(
total_val_batches = total_val_batches * val_checks_per_epoch
if trainer.is_global_zero:
assert trainer.progress_bar_callback.main_progress_bar.total == total_train_batches + total_val_batches
def test_progress_bar_main_bar_resume():
    """Test that the progress bar can resume its counters based on the Trainer state."""
    pbar = ProgressBar()
    lightning_module = Mock()
    trainer = Mock(
        sanity_checking=False,
        check_val_every_n_epoch=1,
        current_epoch=1,
        num_training_batches=5,
        val_check_batch=5,
        num_val_batches=[3],
    )
    # simulate resuming from a mid-epoch checkpoint: 3 train batches already completed
    trainer.fit_loop.epoch_loop.batch_progress.current.completed = 3

    pbar.on_init_end(trainer)
    pbar.on_train_start(trainer, lightning_module)
    pbar.on_train_epoch_start(trainer, lightning_module)

    assert pbar.main_progress_bar.n == 3
    assert pbar.main_progress_bar.total == 8  # 5 train + 3 val batches

    pbar.on_validation_start(trainer, lightning_module)
    pbar.on_validation_epoch_start(trainer, lightning_module)

    # restarting mid validation epoch is not currently supported
    assert pbar.val_progress_bar.n == 0
    assert pbar.val_progress_bar.total == 3