Configurable metric formatting for RichProgressBar (#18373)

This commit is contained in:
Quinten Roets 2023-08-29 17:00:31 +02:00 committed by GitHub
parent 13dede4813
commit d5440a0b3f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 22 additions and 4 deletions

View File

@ -93,6 +93,7 @@ Customize the theme for your :class:`~lightning.pytorch.callbacks.RichProgressBa
processing_speed="grey82",
metrics="grey82",
metrics_text_delimiter="\n",
metrics_format=".3e",
)
)

View File

@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
- Added `metrics_format` attribute to `RichProgressBarTheme` class ([#18373](https://github.com/Lightning-AI/lightning/pull/18373))
- Added `CHECKPOINT_EQUALS_CHAR` attribute to `ModelCheckpoint` class ([#17999](https://github.com/Lightning-AI/lightning/pull/17999))
- Added `**summarize_kwargs` to `ModelSummary` and `RichModelSummary` callbacks ([#16788](https://github.com/Lightning-AI/lightning/pull/16788))

View File

@ -136,13 +136,20 @@ if _RICH_AVAILABLE:
class MetricsTextColumn(ProgressColumn):
"""A column containing text."""
def __init__(self, trainer: "pl.Trainer", style: Union[str, "Style"], text_delimiter: str):
def __init__(
self,
trainer: "pl.Trainer",
style: Union[str, "Style"],
text_delimiter: str,
metrics_format: Union[str, None],
):
self._trainer = trainer
self._tasks: Dict[Union[int, TaskID], Any] = {}
self._current_task_id = 0
self._metrics: Dict[Union[str, "Style"], Any] = {}
self._style = style
self._text_delimiter = text_delimiter
self._metrics_format = metrics_format
super().__init__()
def update(self, metrics: Dict[Any, Any]) -> None:
@ -173,8 +180,10 @@ if _RICH_AVAILABLE:
return Text(text, justify="left", style=self._style)
def _generate_metrics_texts(self) -> Generator[str, None, None]:
for k, v in self._metrics.items():
yield f"{k}: {round(v, 3) if isinstance(v, float) else v}"
for name, value in self._metrics.items():
if not isinstance(value, str):
value = round(value, 3) if self._metrics_format is None else f"{value:{self._metrics_format}}"
yield f"{name}: {value}"
else:
Task, Style = Any, Any # type: ignore[assignment, misc]
@ -207,6 +216,7 @@ class RichProgressBarTheme:
processing_speed: Union[str, Style] = "grey70"
metrics: Union[str, Style] = "white"
metrics_text_delimiter: str = " "
metrics_format: Union[str, None] = None
class RichProgressBar(ProgressBar):
@ -328,7 +338,12 @@ class RichProgressBar(ProgressBar):
reconfigure(**self._console_kwargs)
self._console = get_console()
self._console.clear_live()
self._metric_component = MetricsTextColumn(trainer, self.theme.metrics, self.theme.metrics_text_delimiter)
self._metric_component = MetricsTextColumn(
trainer,
self.theme.metrics,
self.theme.metrics_text_delimiter,
self.theme.metrics_format,
)
self.progress = CustomProgress(
*self.configure_columns(trainer),
self._metric_component,