diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index b07b487927..c091223fba 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -129,11 +129,12 @@ if _RICH_AVAILABLE: class MetricsTextColumn(ProgressColumn): """A column containing text.""" - def __init__(self, trainer): + def __init__(self, trainer, style): self._trainer = trainer self._tasks = {} self._current_task_id = 0 self._metrics = {} + self._style = style super().__init__() def update(self, metrics): @@ -158,23 +159,34 @@ if _RICH_AVAILABLE: for k, v in self._metrics.items(): _text += f"{k}: {round(v, 3) if isinstance(v, float) else v} " - return Text(_text, justify="left") + return Text(_text, justify="left", style=self._style) @dataclass class RichProgressBarTheme: """Styles to associate to different base components. + Args: + description: Style for the progress bar description. For eg., Epoch x, Testing, etc. + progress_bar: Style for the bar in progress. + progress_bar_finished: Style for the finished progress bar. + progress_bar_pulse: Style for the progress bar when `IterableDataset` is being processed. + batch_progress: Style for the progress tracker (i.e 10/50 batches completed). + time: Style for the processed time and estimate time remaining. + processing_speed: Style for the speed of the batches being processed. + metrics: Style for the metrics + https://rich.readthedocs.io/en/stable/style.html """ - text_color: str = "white" - progress_bar_complete: Union[str, Style] = "#6206E0" + description: Union[str, Style] = "white" + progress_bar: Union[str, Style] = "#6206E0" progress_bar_finished: Union[str, Style] = "#6206E0" progress_bar_pulse: Union[str, Style] = "#6206E0" - batch_process: str = "white" - time: str = "grey54" - processing_speed: str = "grey70" + batch_progress: Union[str, Style] = "white" + time: Union[str, Style] = "grey54" + processing_speed: Union[str, Style] = "grey70" + metrics: Union[str, Style] = "white" class RichProgressBar(ProgressBarBase): @@ -273,7 +285,7 @@ class RichProgressBar(ProgressBarBase): self._reset_progress_bar_ids() self._console: Console = Console() self._console.clear_live() - self._metric_component = MetricsTextColumn(trainer) + self._metric_component = MetricsTextColumn(trainer, self.theme.metrics) self.progress = CustomProgress( *self.configure_columns(trainer), self._metric_component, @@ -356,7 +368,7 @@ class RichProgressBar(ProgressBarBase): def _add_task(self, total_batches: int, description: str, visible: bool = True) -> Optional[int]: if self.progress is not None: return self.progress.add_task( - f"[{self.theme.text_color}]{description}", total=total_batches, visible=visible + f"[{self.theme.description}]{description}", total=total_batches, visible=visible ) def _update(self, progress_bar_id: int, visible: bool = True) -> None: @@ -453,11 +465,11 @@ class RichProgressBar(ProgressBarBase): return [ TextColumn("[progress.description]{task.description}"), CustomBarColumn( - complete_style=self.theme.progress_bar_complete, + complete_style=self.theme.progress_bar, finished_style=self.theme.progress_bar_finished, pulse_style=self.theme.progress_bar_pulse, ), - BatchesProcessedColumn(style=self.theme.batch_process), + BatchesProcessedColumn(style=self.theme.batch_progress), CustomTimeColumn(style=self.theme.time), ProcessingSpeedColumn(style=self.theme.processing_speed), ] diff --git a/tests/callbacks/test_rich_progress_bar.py b/tests/callbacks/test_rich_progress_bar.py index 3168175442..8f3f20630b 100644 --- a/tests/callbacks/test_rich_progress_bar.py +++ b/tests/callbacks/test_rich_progress_bar.py @@ -106,11 +106,11 @@ def test_rich_progress_bar_custom_theme(tmpdir): assert progress_bar.theme == theme args, kwargs = mocks["CustomBarColumn"].call_args - assert kwargs["complete_style"] == theme.progress_bar_complete + assert kwargs["complete_style"] == theme.progress_bar assert kwargs["finished_style"] == theme.progress_bar_finished args, kwargs = mocks["BatchesProcessedColumn"].call_args - assert kwargs["style"] == theme.batch_process + assert kwargs["style"] == theme.batch_progress args, kwargs = mocks["CustomTimeColumn"].call_args assert kwargs["style"] == theme.time