diff --git a/pytorch_lightning/callbacks/progress/base.py b/pytorch_lightning/callbacks/progress/base.py
index 234d62a68c..42ad957235 100644
--- a/pytorch_lightning/callbacks/progress/base.py
+++ b/pytorch_lightning/callbacks/progress/base.py
@@ -77,6 +77,11 @@ class ProgressBarBase(Callback):
     def predict_description(self) -> str:
         return "Predicting"
 
+    @property
+    def _val_processed(self) -> int:
+        # use total in case validation runs more than once per training epoch
+        return self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.total.processed
+
     @property
     def train_batch_idx(self) -> int:
         """The number of batches processed during training.
diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py
index 4d7c4b7864..cc475634ff 100644
--- a/pytorch_lightning/callbacks/progress/rich_progress.py
+++ b/pytorch_lightning/callbacks/progress/rich_progress.py
@@ -368,15 +368,19 @@ class RichProgressBar(ProgressBarBase):
             f"[{self.theme.description}]{description}", total=total_batches, visible=visible
         )
 
-    def _update(self, progress_bar_id: int, current: int, total: Union[int, float], visible: bool = True) -> None:
-        if self.progress is not None and self._should_update(current, total):
+    def _update(self, progress_bar_id: int, current: int, visible: bool = True) -> None:
+        if self.progress is not None and self.is_enabled:
+            total = self.progress.tasks[progress_bar_id].total
+            if not self._should_update(current, total):
+                return
+
             leftover = current % self.refresh_rate
             advance = leftover if (current == total and leftover != 0) else self.refresh_rate
             self.progress.update(progress_bar_id, advance=advance, visible=visible)
             self.refresh()
 
     def _should_update(self, current: int, total: Union[int, float]) -> bool:
-        return self.is_enabled and (current % self.refresh_rate == 0 or current == total)
+        return current % self.refresh_rate == 0 or current == total
 
     def on_validation_epoch_end(self, trainer, pl_module):
         if self.val_progress_bar_id is not None and trainer.state.fn == "fit":
@@ -419,7 +423,7 @@ class RichProgressBar(ProgressBarBase):
         self.refresh()
 
     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
-        self._update(self.main_progress_bar_id, self.train_batch_idx, self.total_train_batches)
+        self._update(self.main_progress_bar_id, self.train_batch_idx + self._val_processed)
         self._update_metrics(trainer, pl_module)
         self.refresh()
 
@@ -428,23 +432,20 @@ class RichProgressBar(ProgressBarBase):
 
     def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         if trainer.sanity_checking:
-            self._update(self.val_sanity_progress_bar_id, self.val_batch_idx, self.total_val_batches_current_dataloader)
+            self._update(self.val_sanity_progress_bar_id, self.val_batch_idx)
         elif self.val_progress_bar_id is not None:
             # check to see if we should update the main training progress bar
             if self.main_progress_bar_id is not None:
-                # TODO: Use total val_processed here just like TQDM in a follow-up
-                self._update(self.main_progress_bar_id, self.val_batch_idx, self.total_val_batches_current_dataloader)
-            self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches_current_dataloader)
+                self._update(self.main_progress_bar_id, self.train_batch_idx + self._val_processed)
+            self._update(self.val_progress_bar_id, self.val_batch_idx)
         self.refresh()
 
     def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
-        self._update(self.test_progress_bar_id, self.test_batch_idx, self.total_test_batches_current_dataloader)
+        self._update(self.test_progress_bar_id, self.test_batch_idx)
         self.refresh()
 
     def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
-        self._update(
-            self.predict_progress_bar_id, self.predict_batch_idx, self.total_predict_batches_current_dataloader
-        )
+        self._update(self.predict_progress_bar_id, self.predict_batch_idx)
         self.refresh()
 
     def _get_train_description(self, current_epoch: int) -> str:
diff --git a/pytorch_lightning/callbacks/progress/tqdm_progress.py b/pytorch_lightning/callbacks/progress/tqdm_progress.py
index 19090a1efa..a245c20318 100644
--- a/pytorch_lightning/callbacks/progress/tqdm_progress.py
+++ b/pytorch_lightning/callbacks/progress/tqdm_progress.py
@@ -171,11 +171,6 @@ class TQDMProgressBar(ProgressBarBase):
     def is_disabled(self) -> bool:
         return not self.is_enabled
 
-    @property
-    def _val_processed(self) -> int:
-        # use total in case validation runs more than once per training epoch
-        return self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.total.processed
-
     def disable(self) -> None:
         self._enabled = False
 
diff --git a/tests/callbacks/test_rich_progress_bar.py b/tests/callbacks/test_rich_progress_bar.py
index 29ef3aa98f..8fdcb6c99e 100644
--- a/tests/callbacks/test_rich_progress_bar.py
+++ b/tests/callbacks/test_rich_progress_bar.py
@@ -205,14 +205,28 @@ def test_rich_progress_bar_refresh_rate_disabled(progress_update, tmpdir):
 
 
 @RunIf(rich=True)
-@pytest.mark.parametrize(("refresh_rate", "expected_call_count"), ([(3, 7), (4, 7), (7, 4)]))
-def test_rich_progress_bar_with_refresh_rate(tmpdir, refresh_rate, expected_call_count):
+@pytest.mark.parametrize(
+    "refresh_rate,train_batches,val_batches,expected_call_count",
+    [
+        (3, 6, 6, 4 + 3),
+        (4, 6, 6, 3 + 3),
+        (7, 6, 6, 2 + 2),
+        (1, 2, 3, 5 + 4),
+        (1, 0, 0, 0 + 0),
+        (3, 1, 0, 1 + 0),
+        (3, 1, 1, 1 + 2),
+        (3, 5, 0, 2 + 0),
+        (3, 5, 2, 3 + 2),
+        (6, 5, 2, 2 + 2),
+    ],
+)
+def test_rich_progress_bar_with_refresh_rate(tmpdir, refresh_rate, train_batches, val_batches, expected_call_count):
     model = BoringModel()
     trainer = Trainer(
         default_root_dir=tmpdir,
         num_sanity_val_steps=0,
-        limit_train_batches=6,
-        limit_val_batches=6,
+        limit_train_batches=train_batches,
+        limit_val_batches=val_batches,
         max_epochs=1,
         callbacks=RichProgressBar(refresh_rate=refresh_rate),
     )
@@ -224,14 +238,16 @@ def test_rich_progress_bar_with_refresh_rate(tmpdir, refresh_rate, expected_call_count):
     trainer.fit(model)
     assert progress_update.call_count == expected_call_count
 
-    fit_main_bar = trainer.progress_bar_callback.progress.tasks[0]
-    fit_val_bar = trainer.progress_bar_callback.progress.tasks[1]
-    assert fit_main_bar.completed == 12
-    assert fit_main_bar.total == 12
-    assert fit_main_bar.visible
-    assert fit_val_bar.completed == 6
-    assert fit_val_bar.total == 6
-    assert not fit_val_bar.visible
+    if train_batches > 0:
+        fit_main_bar = trainer.progress_bar_callback.progress.tasks[0]
+        assert fit_main_bar.completed == train_batches + val_batches
+        assert fit_main_bar.total == train_batches + val_batches
+        assert fit_main_bar.visible
+    if val_batches > 0:
+        fit_val_bar = trainer.progress_bar_callback.progress.tasks[1]
+        assert fit_val_bar.completed == val_batches
+        assert fit_val_bar.total == val_batches
+        assert not fit_val_bar.visible
 
 
 @RunIf(rich=True)
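A standalone sketch (not part of the patch) of the refresh arithmetic that the reworked _update implements: the bar now reads the total from the Rich task itself, advances by refresh_rate on each qualifying call, and advances by the remainder on the final batch so it always lands exactly on the task total. RefreshTracker and simulate are hypothetical names used only for this illustration and are not part of pytorch_lightning.

class RefreshTracker:
    """Hypothetical helper that mimics how RichProgressBar._update advances a Rich task."""

    def __init__(self, refresh_rate: int, total: int) -> None:
        self.refresh_rate = refresh_rate
        self.total = total
        self.completed = 0       # what Rich would report as task.completed
        self.update_calls = 0    # how many times progress.update() would be called

    def _should_update(self, current: int) -> bool:
        # same predicate as RichProgressBar._should_update after this patch
        return current % self.refresh_rate == 0 or current == self.total

    def update(self, current: int) -> None:
        if not self._should_update(current):
            return
        # advance by the refresh rate, or by the leftover on the last batch
        leftover = current % self.refresh_rate
        advance = leftover if (current == self.total and leftover != 0) else self.refresh_rate
        self.completed += advance
        self.update_calls += 1


def simulate(refresh_rate: int, total: int) -> RefreshTracker:
    tracker = RefreshTracker(refresh_rate, total)
    for batch_idx in range(1, total + 1):  # batch indices are 1-based, like train_batch_idx
        tracker.update(batch_idx)
    return tracker


if __name__ == "__main__":
    t = simulate(refresh_rate=3, total=5)
    # updates fire at current=3 (advance 3) and current=5 (advance 5 % 3 == 2)
    assert (t.update_calls, t.completed) == (2, 5)

With refresh_rate=3 and 5 batches this gives 2 update calls and a completed count of 5, which lines up with the (3, 5, 0, 2 + 0) row of the parametrized test above (2 main-bar updates, 0 validation-bar updates).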