make RichProgressBar more flexible with Rich.Console (#10875)

Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Co-authored-by: thomas chaton <thomas@grid.ai>
This commit is contained in:
quancs 2021-12-15 21:26:11 +08:00 committed by GitHub
parent 5eecdcae87
commit fde326d7e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 3 deletions

View File

@ -43,6 +43,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a warning that shows when `max_epochs` in the `Trainer` is not set ([#10700](https://github.com/PyTorchLightning/pytorch-lightning/issues/10700))
- Added `console_kwargs` for `RichProgressBar` to initialize inner Console ([#10875](https://github.com/PyTorchLightning/pytorch-lightning/pull/10875))
### Changed
- Raised exception in `init_dist_connection()` when torch distibuted is not available ([#10418](https://github.com/PyTorchLightning/pytorch-lightning/issues/10418))
@ -111,6 +114,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed duplicated file extension when uploading model checkpoints with `NeptuneLogger` ([#11015](https://github.com/PyTorchLightning/pytorch-lightning/pull/11015))
### Deprecated
- Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/issues/10103))

View File

@ -14,7 +14,7 @@
import math
from dataclasses import dataclass
from datetime import timedelta
from typing import Any, Optional, Union
from typing import Any, Dict, Optional, Union
from pytorch_lightning.callbacks.progress.base import ProgressBarBase
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@ -211,6 +211,7 @@ class RichProgressBar(ProgressBarBase):
Set it to ``0`` to disable the display.
leave: Leaves the finished progress bar in the terminal at the end of the epoch. Default: False
theme: Contains styles used to stylize the progress bar.
console_kwargs: Args for constructing a `Console`
Raises:
ModuleNotFoundError:
@ -227,6 +228,7 @@ class RichProgressBar(ProgressBarBase):
refresh_rate: int = 1,
leave: bool = False,
theme: RichProgressBarTheme = RichProgressBarTheme(),
console_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
if not _RICH_AVAILABLE:
raise MisconfigurationException(
@ -236,6 +238,7 @@ class RichProgressBar(ProgressBarBase):
super().__init__()
self._refresh_rate: int = refresh_rate
self._leave: bool = leave
self._console_kwargs = console_kwargs or {}
self._enabled: bool = True
self.progress: Optional[Progress] = None
self.val_sanity_progress_bar_id: Optional[int] = None
@ -281,7 +284,7 @@ class RichProgressBar(ProgressBarBase):
def _init_progress(self, trainer):
if self.is_enabled and (self.progress is None or self._progress_stopped):
self._reset_progress_bar_ids()
self._console: Console = Console()
self._console = Console(**self._console_kwargs)
self._console.clear_live()
self._metric_component = MetricsTextColumn(trainer, self.theme.metrics)
self.progress = CustomProgress(
@ -324,7 +327,7 @@ class RichProgressBar(ProgressBarBase):
def __setstate__(self, state):
self.__dict__ = state
state["_console"] = Console()
self._console = Console(**self._console_kwargs)
def on_sanity_check_start(self, trainer, pl_module):
super().on_sanity_check_start(trainer, pl_module)