From 137b62d80df9896ccce63bf607a29cfdbf1f06f0 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Fri, 19 Nov 2021 11:29:57 +0530 Subject: [PATCH] Add `refresh_rate` to RichProgressBar (#10497) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: ananthsub Co-authored-by: Adrian Wälchli --- CHANGELOG.md | 3 + .../callbacks/progress/rich_progress.py | 59 +++++++++++-------- tests/callbacks/test_rich_progress_bar.py | 27 ++++++++- 3 files changed, 63 insertions(+), 26 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d2380a1dc6..438e2c9933 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Raise `MisconfigurationException` when `enable_progress_bar=False` and a progress bar instance has been passed in the callback list ([#10520](https://github.com/PyTorchLightning/pytorch-lightning/issues/10520)) +- Renamed `refresh_rate_per_second` parameter to `referesh_rate` for `RichProgressBar` signature ([#10497](https://github.com/PyTorchLightning/pytorch-lightning/pull/10497)) + + - Moved ownership of the `PrecisionPlugin` into `TrainingTypePlugin` and updated all references ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570)) diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index c091223fba..e2a269d659 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -206,7 +206,8 @@ class RichProgressBar(ProgressBarBase): trainer = Trainer(callbacks=RichProgressBar()) Args: - refresh_rate_per_second: the number of updates per second. If refresh_rate is 0, progress bar is disabled. + refresh_rate: Determines at which rate (in number of batches) the progress bars get updated. + Set it to ``0`` to disable the display. leave: Leaves the finished progress bar in the terminal at the end of the epoch. Default: False theme: Contains styles used to stylize the progress bar. @@ -222,7 +223,7 @@ class RichProgressBar(ProgressBarBase): def __init__( self, - refresh_rate_per_second: int = 10, + refresh_rate: int = 1, leave: bool = False, theme: RichProgressBarTheme = RichProgressBarTheme(), ) -> None: @@ -231,7 +232,7 @@ class RichProgressBar(ProgressBarBase): "`RichProgressBar` requires `rich` to be installed. Install it by running `pip install -U rich`." ) super().__init__() - self._refresh_rate_per_second: int = refresh_rate_per_second + self._refresh_rate: int = refresh_rate self._leave: bool = leave self._enabled: bool = True self.progress: Optional[Progress] = None @@ -242,17 +243,12 @@ class RichProgressBar(ProgressBarBase): self.theme = theme @property - def refresh_rate_per_second(self) -> float: - """Refresh rate for Rich Progress. - - Returns: Refresh rate for Progress Bar. - Return 1 if not enabled, as a positive integer is required (ignored by Rich Progress). - """ - return self._refresh_rate_per_second if self._refresh_rate_per_second > 0 else 1 + def refresh_rate(self) -> float: + return self._refresh_rate @property def is_enabled(self) -> bool: - return self._enabled and self._refresh_rate_per_second > 0 + return self._enabled and self.refresh_rate > 0 @property def is_disabled(self) -> bool: @@ -289,7 +285,7 @@ class RichProgressBar(ProgressBarBase): self.progress = CustomProgress( *self.configure_columns(trainer), self._metric_component, - refresh_per_second=self.refresh_rate_per_second, + auto_refresh=False, disable=self.is_disabled, console=self._console, ) @@ -297,6 +293,10 @@ class RichProgressBar(ProgressBarBase): # progress has started self._progress_stopped = False + def refresh(self) -> None: + if self.progress: + self.progress.refresh() + def on_train_start(self, trainer, pl_module): super().on_train_start(trainer, pl_module) self._init_progress(trainer) @@ -328,10 +328,12 @@ class RichProgressBar(ProgressBarBase): super().on_sanity_check_start(trainer, pl_module) self._init_progress(trainer) self.val_sanity_progress_bar_id = self._add_task(trainer.num_sanity_val_steps, self.sanity_check_description) + self.refresh() def on_sanity_check_end(self, trainer, pl_module): super().on_sanity_check_end(trainer, pl_module) self._update(self.val_sanity_progress_bar_id, visible=False) + self.refresh() def on_train_epoch_start(self, trainer, pl_module): super().on_train_epoch_start(trainer, pl_module) @@ -354,6 +356,7 @@ class RichProgressBar(ProgressBarBase): self.progress.reset( self.main_progress_bar_id, total=total_batches, description=train_description, visible=True ) + self.refresh() def on_validation_epoch_start(self, trainer, pl_module): super().on_validation_epoch_start(trainer, pl_module) @@ -364,6 +367,7 @@ class RichProgressBar(ProgressBarBase): val_checks_per_epoch = self.total_train_batches // trainer.val_check_batch total_val_batches = self.total_val_batches * val_checks_per_epoch self.val_progress_bar_id = self._add_task(total_val_batches, self.validation_description, visible=False) + self.refresh() def _add_task(self, total_batches: int, description: str, visible: bool = True) -> Optional[int]: if self.progress is not None: @@ -371,45 +375,54 @@ class RichProgressBar(ProgressBarBase): f"[{self.theme.description}]{description}", total=total_batches, visible=visible ) - def _update(self, progress_bar_id: int, visible: bool = True) -> None: - if self.progress is not None: - self.progress.update(progress_bar_id, advance=1.0, visible=visible) + def _update(self, progress_bar_id: int, current: int, total: int, visible: bool = True) -> None: + if self.progress is not None and self._should_update(current, total): + self.progress.update(progress_bar_id, advance=self.refresh_rate, visible=visible) + self.refresh() + + def _should_update(self, current: int, total: int) -> bool: + return self.is_enabled and (current % self.refresh_rate == 0 or current == total) def on_validation_epoch_end(self, trainer, pl_module): super().on_validation_epoch_end(trainer, pl_module) if self.val_progress_bar_id is not None: - self._update(self.val_progress_bar_id, visible=False) + self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches, visible=False) def on_test_epoch_start(self, trainer, pl_module): - super().on_train_epoch_start(trainer, pl_module) self.test_progress_bar_id = self._add_task(self.total_test_batches, self.test_description) + self.refresh() def on_predict_epoch_start(self, trainer, pl_module): super().on_predict_epoch_start(trainer, pl_module) self.predict_progress_bar_id = self._add_task(self.total_predict_batches, self.predict_description) + self.refresh() def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx) - self._update(self.main_progress_bar_id) + self._update(self.main_progress_bar_id, self.train_batch_idx, self.total_train_batches) self._update_metrics(trainer, pl_module) + self.refresh() def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) if trainer.sanity_checking: - self._update(self.val_sanity_progress_bar_id) + self._update(self.val_sanity_progress_bar_id, self.val_batch_idx, self.total_val_batches) elif self.val_progress_bar_id is not None: # check to see if we should update the main training progress bar if self.main_progress_bar_id is not None: - self._update(self.main_progress_bar_id) - self._update(self.val_progress_bar_id) + self._update(self.main_progress_bar_id, self.val_batch_idx, self.total_val_batches) + self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches) + self.refresh() def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - self._update(self.test_progress_bar_id) + self._update(self.test_progress_bar_id, self.test_batch_idx, self.total_test_batches) + self.refresh() def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - self._update(self.predict_progress_bar_id) + self._update(self.predict_progress_bar_id, self.predict_batch_idx, self.total_predict_batches) + self.refresh() def _get_train_description(self, current_epoch: int) -> str: train_description = f"Epoch {current_epoch}" diff --git a/tests/callbacks/test_rich_progress_bar.py b/tests/callbacks/test_rich_progress_bar.py index 8f3f20630b..8ca7326fd7 100644 --- a/tests/callbacks/test_rich_progress_bar.py +++ b/tests/callbacks/test_rich_progress_bar.py @@ -36,11 +36,11 @@ def test_rich_progress_bar_callback(): @RunIf(rich=True) -def test_rich_progress_bar_refresh_rate(): - progress_bar = RichProgressBar(refresh_rate_per_second=1) +def test_rich_progress_bar_refresh_rate_enabled(): + progress_bar = RichProgressBar(refresh_rate=1) assert progress_bar.is_enabled assert not progress_bar.is_disabled - progress_bar = RichProgressBar(refresh_rate_per_second=0) + progress_bar = RichProgressBar(refresh_rate=0) assert not progress_bar.is_enabled assert progress_bar.is_disabled @@ -180,3 +180,24 @@ def test_rich_progress_bar_leave(tmpdir, leave, reset_call_count): ) trainer.fit(model) assert mock_progress_reset.call_count == reset_call_count + + +@RunIf(rich=True) +@mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update") +@pytest.mark.parametrize(("refresh_rate", "expected_call_count"), ([(0, 0), (3, 7)])) +def test_rich_progress_bar_refresh_rate(progress_update, tmpdir, refresh_rate, expected_call_count): + + model = BoringModel() + + trainer = Trainer( + default_root_dir=tmpdir, + num_sanity_val_steps=0, + limit_train_batches=6, + limit_val_batches=6, + max_epochs=1, + callbacks=RichProgressBar(refresh_rate=refresh_rate), + ) + + trainer.fit(model) + + assert progress_update.call_count == expected_call_count