Change attributes of `RichProgressBarTheme` dataclass (#10454)

Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
This commit is contained in:
Raahul Singh 2021-11-12 01:23:40 +05:30 committed by GitHub
parent 5ba5b72473
commit 09cf167237
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 13 deletions

View File

@ -129,11 +129,12 @@ if _RICH_AVAILABLE:
class MetricsTextColumn(ProgressColumn):
"""A column containing text."""
def __init__(self, trainer):
def __init__(self, trainer, style):
self._trainer = trainer
self._tasks = {}
self._current_task_id = 0
self._metrics = {}
self._style = style
super().__init__()
def update(self, metrics):
@ -158,23 +159,34 @@ if _RICH_AVAILABLE:
for k, v in self._metrics.items():
_text += f"{k}: {round(v, 3) if isinstance(v, float) else v} "
return Text(_text, justify="left")
return Text(_text, justify="left", style=self._style)
@dataclass
class RichProgressBarTheme:
"""Styles to associate to different base components.
Args:
description: Style for the progress bar description. For eg., Epoch x, Testing, etc.
progress_bar: Style for the bar in progress.
progress_bar_finished: Style for the finished progress bar.
progress_bar_pulse: Style for the progress bar when `IterableDataset` is being processed.
batch_progress: Style for the progress tracker (i.e 10/50 batches completed).
time: Style for the processed time and estimate time remaining.
processing_speed: Style for the speed of the batches being processed.
metrics: Style for the metrics
https://rich.readthedocs.io/en/stable/style.html
"""
text_color: str = "white"
progress_bar_complete: Union[str, Style] = "#6206E0"
description: Union[str, Style] = "white"
progress_bar: Union[str, Style] = "#6206E0"
progress_bar_finished: Union[str, Style] = "#6206E0"
progress_bar_pulse: Union[str, Style] = "#6206E0"
batch_process: str = "white"
time: str = "grey54"
processing_speed: str = "grey70"
batch_progress: Union[str, Style] = "white"
time: Union[str, Style] = "grey54"
processing_speed: Union[str, Style] = "grey70"
metrics: Union[str, Style] = "white"
class RichProgressBar(ProgressBarBase):
@ -273,7 +285,7 @@ class RichProgressBar(ProgressBarBase):
self._reset_progress_bar_ids()
self._console: Console = Console()
self._console.clear_live()
self._metric_component = MetricsTextColumn(trainer)
self._metric_component = MetricsTextColumn(trainer, self.theme.metrics)
self.progress = CustomProgress(
*self.configure_columns(trainer),
self._metric_component,
@ -356,7 +368,7 @@ class RichProgressBar(ProgressBarBase):
def _add_task(self, total_batches: int, description: str, visible: bool = True) -> Optional[int]:
if self.progress is not None:
return self.progress.add_task(
f"[{self.theme.text_color}]{description}", total=total_batches, visible=visible
f"[{self.theme.description}]{description}", total=total_batches, visible=visible
)
def _update(self, progress_bar_id: int, visible: bool = True) -> None:
@ -453,11 +465,11 @@ class RichProgressBar(ProgressBarBase):
return [
TextColumn("[progress.description]{task.description}"),
CustomBarColumn(
complete_style=self.theme.progress_bar_complete,
complete_style=self.theme.progress_bar,
finished_style=self.theme.progress_bar_finished,
pulse_style=self.theme.progress_bar_pulse,
),
BatchesProcessedColumn(style=self.theme.batch_process),
BatchesProcessedColumn(style=self.theme.batch_progress),
CustomTimeColumn(style=self.theme.time),
ProcessingSpeedColumn(style=self.theme.processing_speed),
]

View File

@ -106,11 +106,11 @@ def test_rich_progress_bar_custom_theme(tmpdir):
assert progress_bar.theme == theme
args, kwargs = mocks["CustomBarColumn"].call_args
assert kwargs["complete_style"] == theme.progress_bar_complete
assert kwargs["complete_style"] == theme.progress_bar
assert kwargs["finished_style"] == theme.progress_bar_finished
args, kwargs = mocks["BatchesProcessedColumn"].call_args
assert kwargs["style"] == theme.batch_process
assert kwargs["style"] == theme.batch_progress
args, kwargs = mocks["CustomTimeColumn"].call_args
assert kwargs["style"] == theme.time