Configurable metric formatting for RichProgressBar (#18373)
This commit is contained in:
parent
13dede4813
commit
d5440a0b3f
|
@ -93,6 +93,7 @@ Customize the theme for your :class:`~lightning.pytorch.callbacks.RichProgressBa
|
|||
processing_speed="grey82",
|
||||
metrics="grey82",
|
||||
metrics_text_delimiter="\n",
|
||||
metrics_format=".3e",
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
### Added
|
||||
|
||||
- Added `metrics_format` attribute to `RichProgressBarTheme` class ([#18373](https://github.com/Lightning-AI/lightning/pull/18373))
|
||||
|
||||
- Added `CHECKPOINT_EQUALS_CHAR` attribute to `ModelCheckpoint` class ([#17999](https://github.com/Lightning-AI/lightning/pull/17999))
|
||||
|
||||
- Added `**summarize_kwargs` to `ModelSummary` and `RichModelSummary` callbacks ([#16788](https://github.com/Lightning-AI/lightning/pull/16788))
|
||||
|
|
|
@ -136,13 +136,20 @@ if _RICH_AVAILABLE:
|
|||
class MetricsTextColumn(ProgressColumn):
|
||||
"""A column containing text."""
|
||||
|
||||
def __init__(self, trainer: "pl.Trainer", style: Union[str, "Style"], text_delimiter: str):
|
||||
def __init__(
|
||||
self,
|
||||
trainer: "pl.Trainer",
|
||||
style: Union[str, "Style"],
|
||||
text_delimiter: str,
|
||||
metrics_format: Union[str, None],
|
||||
):
|
||||
self._trainer = trainer
|
||||
self._tasks: Dict[Union[int, TaskID], Any] = {}
|
||||
self._current_task_id = 0
|
||||
self._metrics: Dict[Union[str, "Style"], Any] = {}
|
||||
self._style = style
|
||||
self._text_delimiter = text_delimiter
|
||||
self._metrics_format = metrics_format
|
||||
super().__init__()
|
||||
|
||||
def update(self, metrics: Dict[Any, Any]) -> None:
|
||||
|
@ -173,8 +180,10 @@ if _RICH_AVAILABLE:
|
|||
return Text(text, justify="left", style=self._style)
|
||||
|
||||
def _generate_metrics_texts(self) -> Generator[str, None, None]:
|
||||
for k, v in self._metrics.items():
|
||||
yield f"{k}: {round(v, 3) if isinstance(v, float) else v}"
|
||||
for name, value in self._metrics.items():
|
||||
if not isinstance(value, str):
|
||||
value = round(value, 3) if self._metrics_format is None else f"{value:{self._metrics_format}}"
|
||||
yield f"{name}: {value}"
|
||||
|
||||
else:
|
||||
Task, Style = Any, Any # type: ignore[assignment, misc]
|
||||
|
@ -207,6 +216,7 @@ class RichProgressBarTheme:
|
|||
processing_speed: Union[str, Style] = "grey70"
|
||||
metrics: Union[str, Style] = "white"
|
||||
metrics_text_delimiter: str = " "
|
||||
metrics_format: Union[str, None] = None
|
||||
|
||||
|
||||
class RichProgressBar(ProgressBar):
|
||||
|
@ -328,7 +338,12 @@ class RichProgressBar(ProgressBar):
|
|||
reconfigure(**self._console_kwargs)
|
||||
self._console = get_console()
|
||||
self._console.clear_live()
|
||||
self._metric_component = MetricsTextColumn(trainer, self.theme.metrics, self.theme.metrics_text_delimiter)
|
||||
self._metric_component = MetricsTextColumn(
|
||||
trainer,
|
||||
self.theme.metrics,
|
||||
self.theme.metrics_text_delimiter,
|
||||
self.theme.metrics_format,
|
||||
)
|
||||
self.progress = CustomProgress(
|
||||
*self.configure_columns(trainer),
|
||||
self._metric_component,
|
||||
|
|
Loading…
Reference in New Issue