fix progress bar restart with fault-tolerant training enabled (#9310)
* reset progress updates
* update docs
* add test

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
commit 50198d7483
parent f9132e8db6
CHANGELOG.md
@@ -297,6 +297,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `move_metrics_to_cpu` moving the loss on cpu while training on device ([#9308](https://github.com/PyTorchLightning/pytorch-lightning/pull/9308))
 
+- Fixed incorrect main progress bar indicator when resuming training mid-epoch ([#9310](https://github.com/PyTorchLightning/pytorch-lightning/pull/9310))
+
+
 ## [1.4.5] - 2021-08-31
 
 - Fixed reduction using `self.log(sync_dict=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142))
pytorch_lightning/callbacks/progress.py
@@ -154,10 +154,10 @@ class ProgressBarBase(Callback):
         self._trainer = trainer
 
     def on_train_start(self, trainer, pl_module):
-        self._train_batch_idx = trainer.fit_loop.batch_idx
+        self._train_batch_idx = 0
 
     def on_train_epoch_start(self, trainer, pl_module):
-        self._train_batch_idx = 0
+        self._train_batch_idx = trainer.fit_loop.epoch_loop.batch_progress.current.completed
 
     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         self._train_batch_idx += 1
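For context on the swap above: on a restart, `trainer.fit_loop.epoch_loop.batch_progress.current.completed` already reflects the batches finished before the interruption, so seeding the counter in `on_train_epoch_start` (instead of zeroing it there) is what lets the bar resume mid-epoch, while `on_train_start` now does the zeroing. A minimal sketch of the new behavior, assuming the patched `ProgressBarBase` from this commit:

```python
from unittest.mock import Mock

from pytorch_lightning.callbacks import ProgressBarBase

bar = ProgressBarBase()
trainer = Mock()
# on a fault-tolerant restart, the loop's progress tracker already counts
# the batches completed before the interruption
trainer.fit_loop.epoch_loop.batch_progress.current.completed = 3
bar.on_train_epoch_start(trainer, Mock())
assert bar.train_batch_idx == 3  # the counter resumes at 3 instead of 0
```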
@@ -229,7 +229,7 @@ class ProgressBar(ProgressBarBase):
             val_checks_per_epoch = total_train_batches // trainer.val_check_batch
             total_val_batches = total_val_batches * val_checks_per_epoch
         total_batches = total_train_batches + total_val_batches
-        reset(self.main_progress_bar, total_batches)
+        reset(self.main_progress_bar, total=total_batches, current=self.train_batch_idx)
         self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}")
 
     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
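As a worked example of the totals feeding that `reset` call, using the same numbers as the test added below (5 training batches, 3 validation batches, validation checked every 5 batches):

```python
total_train_batches = 5
total_val_batches = 3
val_check_batch = 5  # trainer.val_check_batch

val_checks_per_epoch = total_train_batches // val_check_batch  # 5 // 5 = 1
total_val_batches = total_val_batches * val_checks_per_epoch   # 3 * 1 = 3
total_batches = total_train_batches + total_val_batches        # 5 + 3 = 8

# after resuming with 3 batches already completed, the main bar reads 3/8
```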
@@ -243,11 +243,11 @@ class ProgressBar(ProgressBarBase):
     def on_validation_start(self, trainer, pl_module):
         super().on_validation_start(trainer, pl_module)
         if trainer.sanity_checking:
-            reset(self.val_progress_bar, sum(trainer.num_sanity_val_batches))
+            reset(self.val_progress_bar, total=sum(trainer.num_sanity_val_batches), current=self.val_batch_idx)
         else:
             self._update_bar(self.main_progress_bar)  # fill up remaining
             self.val_progress_bar = self.init_validation_tqdm()
-            reset(self.val_progress_bar, self.total_val_batches)
+            reset(self.val_progress_bar, total=self.total_val_batches, current=self.val_batch_idx)
 
     def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
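Note that restarting mid validation epoch is not supported by this change (per the comment in the test below), so `self.val_batch_idx` is effectively 0 here; passing it anyway keeps the `reset` calls symmetric with the main progress bar.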
@@ -333,7 +333,8 @@ def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
     return x
 
 
-def reset(bar: tqdm, total: Optional[int] = None) -> None:
-    """Resets the tqdm bar to 0 progress with a new total, unless it is disabled."""
+def reset(bar: tqdm, total: Optional[int] = None, current: int = 0) -> None:
+    """Resets the tqdm bar to the desired position and sets a new total, unless it is disabled."""
     if not bar.disable:
         bar.reset(total=convert_inf(total))
+        bar.n = current
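The new `current` argument leans on plain tqdm behavior: `reset()` zeroes the counter and installs a new total, and `n` can then be moved to any position. A standalone sketch of that mechanism (not part of the diff):

```python
from tqdm import tqdm

bar = tqdm(total=10)
bar.reset(total=8)  # tqdm's reset(): counter back to 0, new total of 8
bar.n = 3           # move the counter to the resumed position
bar.refresh()       # redraw immediately so the display reads 3/8
```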
tests/callbacks/test_progress_bar.py
@@ -558,3 +558,33 @@ def _test_progress_bar_max_val_check_interval(
         total_val_batches = total_val_batches * val_checks_per_epoch
     if trainer.is_global_zero:
         assert trainer.progress_bar_callback.main_progress_bar.total == total_train_batches + total_val_batches
+
+
+def test_progress_bar_main_bar_resume():
+    """Test that the progress bar can resume its counters based on the Trainer state."""
+    bar = ProgressBar()
+    trainer = Mock()
+    model = Mock()
+
+    trainer.sanity_checking = False
+    trainer.check_val_every_n_epoch = 1
+    trainer.current_epoch = 1
+    trainer.num_training_batches = 5
+    trainer.val_check_batch = 5
+    trainer.num_val_batches = [3]
+    trainer.fit_loop.epoch_loop.batch_progress.current.completed = 3
+
+    bar.on_init_end(trainer)
+    bar.on_train_start(trainer, model)
+    bar.on_train_epoch_start(trainer, model)
+
+    assert bar.main_progress_bar.n == 3
+    assert bar.main_progress_bar.total == 8
+
+    # bar.on_train_epoch_end(trainer, model)
+    bar.on_validation_start(trainer, model)
+    bar.on_validation_epoch_start(trainer, model)
+
+    # restarting mid validation epoch is not currently supported
+    assert bar.val_progress_bar.n == 0
+    assert bar.val_progress_bar.total == 3