diff --git a/CHANGELOG.md b/CHANGELOG.md
index 98e615f1d8..deb88330d7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -297,6 +297,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `move_metrics_to_cpu` moving the loss on cpu while training on device ([#9308](https://github.com/PyTorchLightning/pytorch-lightning/pull/9308))
 
+- Fixed incorrect main progress bar indicator when resuming training mid-epoch ([#9310](https://github.com/PyTorchLightning/pytorch-lightning/pull/9310))
+
+
 ## [1.4.5] - 2021-08-31
 
 - Fixed reduction using `self.log(sync_dict=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142))
diff --git a/pytorch_lightning/callbacks/progress/base.py b/pytorch_lightning/callbacks/progress/base.py
index db1de97a22..c1963345fd 100644
--- a/pytorch_lightning/callbacks/progress/base.py
+++ b/pytorch_lightning/callbacks/progress/base.py
@@ -154,10 +154,10 @@ class ProgressBarBase(Callback):
         self._trainer = trainer
 
     def on_train_start(self, trainer, pl_module):
-        self._train_batch_idx = trainer.fit_loop.batch_idx
+        self._train_batch_idx = 0
 
     def on_train_epoch_start(self, trainer, pl_module):
-        self._train_batch_idx = 0
+        self._train_batch_idx = trainer.fit_loop.epoch_loop.batch_progress.current.completed
 
     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         self._train_batch_idx += 1
diff --git a/pytorch_lightning/callbacks/progress/progress.py b/pytorch_lightning/callbacks/progress/progress.py
index adcea3d581..aaea867680 100644
--- a/pytorch_lightning/callbacks/progress/progress.py
+++ b/pytorch_lightning/callbacks/progress/progress.py
@@ -229,7 +229,7 @@ class ProgressBar(ProgressBarBase):
             val_checks_per_epoch = total_train_batches // trainer.val_check_batch
             total_val_batches = total_val_batches * val_checks_per_epoch
         total_batches = total_train_batches + total_val_batches
-        reset(self.main_progress_bar, total_batches)
+        reset(self.main_progress_bar, total=total_batches, current=self.train_batch_idx)
         self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}")
 
     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
@@ -243,11 +243,11 @@ class ProgressBar(ProgressBarBase):
     def on_validation_start(self, trainer, pl_module):
         super().on_validation_start(trainer, pl_module)
         if trainer.sanity_checking:
-            reset(self.val_progress_bar, sum(trainer.num_sanity_val_batches))
+            reset(self.val_progress_bar, total=sum(trainer.num_sanity_val_batches), current=self.val_batch_idx)
         else:
             self._update_bar(self.main_progress_bar)  # fill up remaining
             self.val_progress_bar = self.init_validation_tqdm()
-            reset(self.val_progress_bar, self.total_val_batches)
+            reset(self.val_progress_bar, total=self.total_val_batches, current=self.val_batch_idx)
 
     def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
@@ -333,7 +333,8 @@ def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
     return x
 
 
-def reset(bar: tqdm, total: Optional[int] = None) -> None:
-    """Resets the tqdm bar to 0 progress with a new total, unless it is disabled."""
+def reset(bar: tqdm, total: Optional[int] = None, current: int = 0) -> None:
+    """Resets the tqdm bar to the desired position and sets a new total, unless it is disabled."""
     if not bar.disable:
         bar.reset(total=convert_inf(total))
+        bar.n = current
diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py
index 1c3176f39a..7634e18bed 100644
--- a/tests/callbacks/test_progress_bar.py
+++ b/tests/callbacks/test_progress_bar.py
@@ -558,3 +558,33 @@ def _test_progress_bar_max_val_check_interval(
     total_val_batches = total_val_batches * val_checks_per_epoch
     if trainer.is_global_zero:
         assert trainer.progress_bar_callback.main_progress_bar.total == total_train_batches + total_val_batches
+
+
+def test_progress_bar_main_bar_resume():
+    """Test that the progress bar can resume its counters based on the Trainer state."""
+    bar = ProgressBar()
+    trainer = Mock()
+    model = Mock()
+
+    trainer.sanity_checking = False
+    trainer.check_val_every_n_epoch = 1
+    trainer.current_epoch = 1
+    trainer.num_training_batches = 5
+    trainer.val_check_batch = 5
+    trainer.num_val_batches = [3]
+    trainer.fit_loop.epoch_loop.batch_progress.current.completed = 3
+
+    bar.on_init_end(trainer)
+    bar.on_train_start(trainer, model)
+    bar.on_train_epoch_start(trainer, model)
+
+    assert bar.main_progress_bar.n == 3
+    assert bar.main_progress_bar.total == 8
+
+    # bar.on_train_epoch_end(trainer, model)
+    bar.on_validation_start(trainer, model)
+    bar.on_validation_epoch_start(trainer, model)
+
+    # restarting mid validation epoch is not currently supported
+    assert bar.val_progress_bar.n == 0
+    assert bar.val_progress_bar.total == 3