From fd2e778cbc70f7967b50d78ca8ea060ad22120de Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Wed, 22 Sep 2021 16:11:37 +0100 Subject: [PATCH] Improvements for rich progress (#9579) --- .../callbacks/progress/rich_progress.py | 45 +++++++++---------- tests/callbacks/test_rich_progress_bar.py | 10 +++++ 2 files changed, 32 insertions(+), 23 deletions(-) diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index 6bc2f4f7d8..2c53f9c5f7 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -125,7 +125,7 @@ class RichProgressBar(ProgressBarBase): trainer = Trainer(callbacks=RichProgressBar()) Args: - refresh_rate: the number of updates per second, must be strictly positive + refresh_rate_per_second: the number of updates per second. If refresh_rate is 0, progress bar is disabled. theme: Contains styles used to stylize the progress bar. Raises: @@ -135,7 +135,7 @@ class RichProgressBar(ProgressBarBase): def __init__( self, - refresh_rate: float = 1.0, + refresh_rate_per_second: int = 10, theme: RichProgressBarTheme = RichProgressBarTheme(), ) -> None: if not _RICH_AVAILABLE: @@ -143,7 +143,7 @@ class RichProgressBar(ProgressBarBase): "`RichProgressBar` requires `rich` to be installed. Install it by running `pip install rich`." ) super().__init__() - self._refresh_rate: float = refresh_rate + self._refresh_rate_per_second: int = refresh_rate_per_second self._enabled: bool = True self._total_val_batches: int = 0 self.progress: Progress = None @@ -156,12 +156,17 @@ class RichProgressBar(ProgressBarBase): self.theme = theme @property - def refresh_rate(self) -> int: - return self._refresh_rate + def refresh_rate_per_second(self) -> float: + """Refresh rate for Rich Progress. + + Returns: Refresh rate for Progress Bar. + Return 1 if not enabled, as a positive integer is required (ignored by Rich Progress). + """ + return self._refresh_rate_per_second if self._refresh_rate_per_second > 0 else 1 @property def is_enabled(self) -> bool: - return self._enabled and self.refresh_rate > 0 + return self._enabled and self._refresh_rate_per_second > 0 @property def is_disabled(self) -> bool: @@ -189,7 +194,7 @@ class RichProgressBar(ProgressBarBase): def predict_description(self) -> str: return "Predicting" - def setup(self, trainer, pl_module, stage): + def setup(self, trainer, pl_module, stage: Optional[str] = None): self.progress = Progress( TextColumn("[progress.description]{task.description}"), BarColumn(complete_style=self.theme.progress_bar_complete, finished_style=self.theme.progress_bar_finished), @@ -198,8 +203,10 @@ class RichProgressBar(ProgressBarBase): ProcessingSpeedColumn(style=self.theme.processing_speed), MetricsTextColumn(trainer, pl_module, stage), console=self.console, - refresh_per_second=self.refresh_rate, - ).__enter__() + refresh_per_second=self.refresh_rate_per_second, + disable=self.is_disabled, + ) + self.progress.start() def on_sanity_check_start(self, trainer, pl_module): super().on_sanity_check_start(trainer, pl_module) @@ -259,31 +266,23 @@ class RichProgressBar(ProgressBarBase): def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - if self._should_update(self.train_batch_idx, self.total_train_batches + self.total_val_batches): - self.progress.update(self.main_progress_bar_id, advance=1.0) + self.progress.update(self.main_progress_bar_id, advance=1.0) def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) if trainer.sanity_checking: self.progress.update(self.val_sanity_progress_bar_id, advance=1.0) - elif self.val_progress_bar_id and self._should_update( - self.val_batch_idx, self.total_train_batches + self.total_val_batches - ): + elif self.val_progress_bar_id: self.progress.update(self.main_progress_bar_id, advance=1.0) self.progress.update(self.val_progress_bar_id, advance=1.0) def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - if self._should_update(self.test_batch_idx, self.total_test_batches): - self.progress.update(self.test_progress_bar_id, advance=1.0) + self.progress.update(self.test_progress_bar_id, advance=1.0) def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - if self._should_update(self.predict_batch_idx, self.total_predict_batches): - self.progress.update(self.predict_progress_bar_id, advance=1.0) - - def _should_update(self, current, total) -> bool: - return self.is_enabled and (current % self.refresh_rate == 0 or current == total) + self.progress.update(self.predict_progress_bar_id, advance=1.0) def _get_train_description(self, current_epoch: int) -> str: train_description = f"Epoch {current_epoch}" @@ -296,8 +295,8 @@ class RichProgressBar(ProgressBarBase): train_description += " " return train_description - def teardown(self, trainer, pl_module, stage): - self.progress.__exit__(None, None, None) + def teardown(self, trainer, pl_module, stage: Optional[str] = None) -> None: + self.progress.stop() def on_exception(self, trainer, pl_module, exception: BaseException) -> None: if isinstance(exception, KeyboardInterrupt): diff --git a/tests/callbacks/test_rich_progress_bar.py b/tests/callbacks/test_rich_progress_bar.py index ce25c929b0..edea9c8c5c 100644 --- a/tests/callbacks/test_rich_progress_bar.py +++ b/tests/callbacks/test_rich_progress_bar.py @@ -34,6 +34,16 @@ def test_rich_progress_bar_callback(): assert isinstance(trainer.progress_bar_callback, RichProgressBar) +@RunIf(rich=True) +def test_rich_progress_bar_refresh_rate(): + progress_bar = RichProgressBar(refresh_rate_per_second=1) + assert progress_bar.is_enabled + assert not progress_bar.is_disabled + progress_bar = RichProgressBar(refresh_rate_per_second=0) + assert not progress_bar.is_enabled + assert progress_bar.is_disabled + + @RunIf(rich=True) @mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update") def test_rich_progress_bar(progress_update, tmpdir):