diff --git a/CHANGELOG.md b/CHANGELOG.md index 331185ee82..0ad5733118 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -43,6 +43,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a warning that shows when `max_epochs` in the `Trainer` is not set ([#10700](https://github.com/PyTorchLightning/pytorch-lightning/issues/10700)) +- Added `console_kwargs` for `RichProgressBar` to initialize inner Console ([#10875](https://github.com/PyTorchLightning/pytorch-lightning/pull/10875)) + + ### Changed - Raised exception in `init_dist_connection()` when torch distibuted is not available ([#10418](https://github.com/PyTorchLightning/pytorch-lightning/issues/10418)) @@ -111,6 +114,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed duplicated file extension when uploading model checkpoints with `NeptuneLogger` ([#11015](https://github.com/PyTorchLightning/pytorch-lightning/pull/11015)) + ### Deprecated - Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/issues/10103)) diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index fa4d925800..2e24df6da6 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -14,7 +14,7 @@ import math from dataclasses import dataclass from datetime import timedelta -from typing import Any, Optional, Union +from typing import Any, Dict, Optional, Union from pytorch_lightning.callbacks.progress.base import ProgressBarBase from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -211,6 +211,7 @@ class RichProgressBar(ProgressBarBase): Set it to ``0`` to disable the display. leave: Leaves the finished progress bar in the terminal at the end of the epoch. Default: False theme: Contains styles used to stylize the progress bar. + console_kwargs: Args for constructing a `Console` Raises: ModuleNotFoundError: @@ -227,6 +228,7 @@ class RichProgressBar(ProgressBarBase): refresh_rate: int = 1, leave: bool = False, theme: RichProgressBarTheme = RichProgressBarTheme(), + console_kwargs: Optional[Dict[str, Any]] = None, ) -> None: if not _RICH_AVAILABLE: raise MisconfigurationException( @@ -236,6 +238,7 @@ class RichProgressBar(ProgressBarBase): super().__init__() self._refresh_rate: int = refresh_rate self._leave: bool = leave + self._console_kwargs = console_kwargs or {} self._enabled: bool = True self.progress: Optional[Progress] = None self.val_sanity_progress_bar_id: Optional[int] = None @@ -281,7 +284,7 @@ class RichProgressBar(ProgressBarBase): def _init_progress(self, trainer): if self.is_enabled and (self.progress is None or self._progress_stopped): self._reset_progress_bar_ids() - self._console: Console = Console() + self._console = Console(**self._console_kwargs) self._console.clear_live() self._metric_component = MetricsTextColumn(trainer, self.theme.metrics) self.progress = CustomProgress( @@ -324,7 +327,7 @@ class RichProgressBar(ProgressBarBase): def __setstate__(self, state): self.__dict__ = state - state["_console"] = Console() + self._console = Console(**self._console_kwargs) def on_sanity_check_start(self, trainer, pl_module): super().on_sanity_check_start(trainer, pl_module)