diff --git a/docs/source-pytorch/common/progress_bar.rst b/docs/source-pytorch/common/progress_bar.rst index 24c4285cdf..f16846e8c9 100644 --- a/docs/source-pytorch/common/progress_bar.rst +++ b/docs/source-pytorch/common/progress_bar.rst @@ -93,6 +93,7 @@ Customize the theme for your :class:`~lightning.pytorch.callbacks.RichProgressBa processing_speed="grey82", metrics="grey82", metrics_text_delimiter="\n", + metrics_format=".3e", ) ) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 62ba59c596..fb44687142 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added `metrics_format` attribute to `RichProgressBarTheme` class ([#18373](https://github.com/Lightning-AI/lightning/pull/18373)) + - Added `CHECKPOINT_EQUALS_CHAR` attribute to `ModelCheckpoint` class ([#17999](https://github.com/Lightning-AI/lightning/pull/17999)) - Added `**summarize_kwargs` to `ModelSummary` and `RichModelSummary` callbacks ([#16788](https://github.com/Lightning-AI/lightning/pull/16788)) diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 0bdfcda87c..fcc7e9e695 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -136,13 +136,20 @@ if _RICH_AVAILABLE: class MetricsTextColumn(ProgressColumn): """A column containing text.""" - def __init__(self, trainer: "pl.Trainer", style: Union[str, "Style"], text_delimiter: str): + def __init__( + self, + trainer: "pl.Trainer", + style: Union[str, "Style"], + text_delimiter: str, + metrics_format: Union[str, None], + ): self._trainer = trainer self._tasks: Dict[Union[int, TaskID], Any] = {} self._current_task_id = 0 self._metrics: Dict[Union[str, "Style"], Any] = {} self._style = style self._text_delimiter = text_delimiter + self._metrics_format = metrics_format super().__init__() def update(self, metrics: Dict[Any, Any]) -> None: @@ -173,8 +180,10 @@ if _RICH_AVAILABLE: return Text(text, justify="left", style=self._style) def _generate_metrics_texts(self) -> Generator[str, None, None]: - for k, v in self._metrics.items(): - yield f"{k}: {round(v, 3) if isinstance(v, float) else v}" + for name, value in self._metrics.items(): + if not isinstance(value, str): + value = round(value, 3) if self._metrics_format is None else f"{value:{self._metrics_format}}" + yield f"{name}: {value}" else: Task, Style = Any, Any # type: ignore[assignment, misc] @@ -207,6 +216,7 @@ class RichProgressBarTheme: processing_speed: Union[str, Style] = "grey70" metrics: Union[str, Style] = "white" metrics_text_delimiter: str = " " + metrics_format: Union[str, None] = None class RichProgressBar(ProgressBar): @@ -328,7 +338,12 @@ class RichProgressBar(ProgressBar): reconfigure(**self._console_kwargs) self._console = get_console() self._console.clear_live() - self._metric_component = MetricsTextColumn(trainer, self.theme.metrics, self.theme.metrics_text_delimiter) + self._metric_component = MetricsTextColumn( + trainer, + self.theme.metrics, + self.theme.metrics_text_delimiter, + self.theme.metrics_format, + ) self.progress = CustomProgress( *self.configure_columns(trainer), self._metric_component,