make RichProgressBar more flexible with Rich.Console (#10875)
Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Co-authored-by: thomas chaton <thomas@grid.ai>
This commit is contained in:
parent
5eecdcae87
commit
fde326d7e0
|
@ -43,6 +43,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Added a warning that shows when `max_epochs` in the `Trainer` is not set ([#10700](https://github.com/PyTorchLightning/pytorch-lightning/issues/10700))
|
||||
|
||||
|
||||
- Added `console_kwargs` for `RichProgressBar` to initialize inner Console ([#10875](https://github.com/PyTorchLightning/pytorch-lightning/pull/10875))
|
||||
|
||||
|
||||
### Changed
|
||||
|
||||
- Raised exception in `init_dist_connection()` when torch distibuted is not available ([#10418](https://github.com/PyTorchLightning/pytorch-lightning/issues/10418))
|
||||
|
@ -111,6 +114,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Removed duplicated file extension when uploading model checkpoints with `NeptuneLogger` ([#11015](https://github.com/PyTorchLightning/pytorch-lightning/pull/11015))
|
||||
|
||||
|
||||
|
||||
### Deprecated
|
||||
|
||||
- Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/issues/10103))
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
import math
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from pytorch_lightning.callbacks.progress.base import ProgressBarBase
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
@ -211,6 +211,7 @@ class RichProgressBar(ProgressBarBase):
|
|||
Set it to ``0`` to disable the display.
|
||||
leave: Leaves the finished progress bar in the terminal at the end of the epoch. Default: False
|
||||
theme: Contains styles used to stylize the progress bar.
|
||||
console_kwargs: Args for constructing a `Console`
|
||||
|
||||
Raises:
|
||||
ModuleNotFoundError:
|
||||
|
@ -227,6 +228,7 @@ class RichProgressBar(ProgressBarBase):
|
|||
refresh_rate: int = 1,
|
||||
leave: bool = False,
|
||||
theme: RichProgressBarTheme = RichProgressBarTheme(),
|
||||
console_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
if not _RICH_AVAILABLE:
|
||||
raise MisconfigurationException(
|
||||
|
@ -236,6 +238,7 @@ class RichProgressBar(ProgressBarBase):
|
|||
super().__init__()
|
||||
self._refresh_rate: int = refresh_rate
|
||||
self._leave: bool = leave
|
||||
self._console_kwargs = console_kwargs or {}
|
||||
self._enabled: bool = True
|
||||
self.progress: Optional[Progress] = None
|
||||
self.val_sanity_progress_bar_id: Optional[int] = None
|
||||
|
@ -281,7 +284,7 @@ class RichProgressBar(ProgressBarBase):
|
|||
def _init_progress(self, trainer):
|
||||
if self.is_enabled and (self.progress is None or self._progress_stopped):
|
||||
self._reset_progress_bar_ids()
|
||||
self._console: Console = Console()
|
||||
self._console = Console(**self._console_kwargs)
|
||||
self._console.clear_live()
|
||||
self._metric_component = MetricsTextColumn(trainer, self.theme.metrics)
|
||||
self.progress = CustomProgress(
|
||||
|
@ -324,7 +327,7 @@ class RichProgressBar(ProgressBarBase):
|
|||
|
||||
def __setstate__(self, state):
|
||||
self.__dict__ = state
|
||||
state["_console"] = Console()
|
||||
self._console = Console(**self._console_kwargs)
|
||||
|
||||
def on_sanity_check_start(self, trainer, pl_module):
|
||||
super().on_sanity_check_start(trainer, pl_module)
|
||||
|
|
Loading…
Reference in New Issue