Change attributes of `RichProgressBarTheme` dataclass (#10454)
Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
This commit is contained in:
parent
5ba5b72473
commit
09cf167237
|
@ -129,11 +129,12 @@ if _RICH_AVAILABLE:
|
|||
class MetricsTextColumn(ProgressColumn):
|
||||
"""A column containing text."""
|
||||
|
||||
def __init__(self, trainer):
|
||||
def __init__(self, trainer, style):
|
||||
self._trainer = trainer
|
||||
self._tasks = {}
|
||||
self._current_task_id = 0
|
||||
self._metrics = {}
|
||||
self._style = style
|
||||
super().__init__()
|
||||
|
||||
def update(self, metrics):
|
||||
|
@ -158,23 +159,34 @@ if _RICH_AVAILABLE:
|
|||
|
||||
for k, v in self._metrics.items():
|
||||
_text += f"{k}: {round(v, 3) if isinstance(v, float) else v} "
|
||||
return Text(_text, justify="left")
|
||||
return Text(_text, justify="left", style=self._style)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RichProgressBarTheme:
|
||||
"""Styles to associate to different base components.
|
||||
|
||||
Args:
|
||||
description: Style for the progress bar description. For eg., Epoch x, Testing, etc.
|
||||
progress_bar: Style for the bar in progress.
|
||||
progress_bar_finished: Style for the finished progress bar.
|
||||
progress_bar_pulse: Style for the progress bar when `IterableDataset` is being processed.
|
||||
batch_progress: Style for the progress tracker (i.e 10/50 batches completed).
|
||||
time: Style for the processed time and estimate time remaining.
|
||||
processing_speed: Style for the speed of the batches being processed.
|
||||
metrics: Style for the metrics
|
||||
|
||||
https://rich.readthedocs.io/en/stable/style.html
|
||||
"""
|
||||
|
||||
text_color: str = "white"
|
||||
progress_bar_complete: Union[str, Style] = "#6206E0"
|
||||
description: Union[str, Style] = "white"
|
||||
progress_bar: Union[str, Style] = "#6206E0"
|
||||
progress_bar_finished: Union[str, Style] = "#6206E0"
|
||||
progress_bar_pulse: Union[str, Style] = "#6206E0"
|
||||
batch_process: str = "white"
|
||||
time: str = "grey54"
|
||||
processing_speed: str = "grey70"
|
||||
batch_progress: Union[str, Style] = "white"
|
||||
time: Union[str, Style] = "grey54"
|
||||
processing_speed: Union[str, Style] = "grey70"
|
||||
metrics: Union[str, Style] = "white"
|
||||
|
||||
|
||||
class RichProgressBar(ProgressBarBase):
|
||||
|
@ -273,7 +285,7 @@ class RichProgressBar(ProgressBarBase):
|
|||
self._reset_progress_bar_ids()
|
||||
self._console: Console = Console()
|
||||
self._console.clear_live()
|
||||
self._metric_component = MetricsTextColumn(trainer)
|
||||
self._metric_component = MetricsTextColumn(trainer, self.theme.metrics)
|
||||
self.progress = CustomProgress(
|
||||
*self.configure_columns(trainer),
|
||||
self._metric_component,
|
||||
|
@ -356,7 +368,7 @@ class RichProgressBar(ProgressBarBase):
|
|||
def _add_task(self, total_batches: int, description: str, visible: bool = True) -> Optional[int]:
|
||||
if self.progress is not None:
|
||||
return self.progress.add_task(
|
||||
f"[{self.theme.text_color}]{description}", total=total_batches, visible=visible
|
||||
f"[{self.theme.description}]{description}", total=total_batches, visible=visible
|
||||
)
|
||||
|
||||
def _update(self, progress_bar_id: int, visible: bool = True) -> None:
|
||||
|
@ -453,11 +465,11 @@ class RichProgressBar(ProgressBarBase):
|
|||
return [
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
CustomBarColumn(
|
||||
complete_style=self.theme.progress_bar_complete,
|
||||
complete_style=self.theme.progress_bar,
|
||||
finished_style=self.theme.progress_bar_finished,
|
||||
pulse_style=self.theme.progress_bar_pulse,
|
||||
),
|
||||
BatchesProcessedColumn(style=self.theme.batch_process),
|
||||
BatchesProcessedColumn(style=self.theme.batch_progress),
|
||||
CustomTimeColumn(style=self.theme.time),
|
||||
ProcessingSpeedColumn(style=self.theme.processing_speed),
|
||||
]
|
||||
|
|
|
@ -106,11 +106,11 @@ def test_rich_progress_bar_custom_theme(tmpdir):
|
|||
|
||||
assert progress_bar.theme == theme
|
||||
args, kwargs = mocks["CustomBarColumn"].call_args
|
||||
assert kwargs["complete_style"] == theme.progress_bar_complete
|
||||
assert kwargs["complete_style"] == theme.progress_bar
|
||||
assert kwargs["finished_style"] == theme.progress_bar_finished
|
||||
|
||||
args, kwargs = mocks["BatchesProcessedColumn"].call_args
|
||||
assert kwargs["style"] == theme.batch_process
|
||||
assert kwargs["style"] == theme.batch_progress
|
||||
|
||||
args, kwargs = mocks["CustomTimeColumn"].call_args
|
||||
assert kwargs["style"] == theme.time
|
||||
|
|
Loading…
Reference in New Issue