Improvements for rich progress bar (#9559)
This commit is contained in:
parent
3aba9d16a8
commit
45200fc858
|
@ -110,7 +110,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Added a friendly error message when DDP attempts to spawn new distributed processes with rank > 0 ([#9005](https://github.com/PyTorchLightning/pytorch-lightning/pull/9005))
|
||||
|
||||
|
||||
- Added Rich Progress Bar ([#8929](https://github.com/PyTorchLightning/pytorch-lightning/pull/8929))
|
||||
- Added Rich Progress Bar:
|
||||
* Rich Progress Bar ([#8929](https://github.com/PyTorchLightning/pytorch-lightning/pull/8929))
|
||||
* Improvements for rich progress bar ([#9559](https://github.com/PyTorchLightning/pytorch-lightning/pull/9559))
|
||||
|
||||
|
||||
- Added validate logic for precision ([#9080](https://github.com/PyTorchLightning/pytorch-lightning/pull/9080))
|
||||
|
|
|
@ -11,15 +11,18 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from typing import Dict, Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
from pytorch_lightning.callbacks.progress.base import ProgressBarBase
|
||||
from pytorch_lightning.utilities import _RICH_AVAILABLE
|
||||
|
||||
Style = None
|
||||
if _RICH_AVAILABLE:
|
||||
from rich.console import Console, RenderableType
|
||||
from rich.progress import BarColumn, Progress, ProgressColumn, SpinnerColumn, TextColumn
|
||||
from rich.progress import BarColumn, Progress, ProgressColumn, TextColumn
|
||||
from rich.style import Style
|
||||
from rich.text import Text
|
||||
|
||||
class CustomTimeColumn(ProgressColumn):
|
||||
|
@ -27,21 +30,33 @@ if _RICH_AVAILABLE:
|
|||
# Only refresh twice a second to prevent jitter
|
||||
max_refresh = 0.5
|
||||
|
||||
def __init__(self, style: Union[str, Style]) -> None:
|
||||
self.style = style
|
||||
super().__init__()
|
||||
|
||||
def render(self, task) -> Text:
|
||||
elapsed = task.finished_time if task.finished else task.elapsed
|
||||
remaining = task.time_remaining
|
||||
elapsed_delta = "-:--:--" if elapsed is None else str(timedelta(seconds=int(elapsed)))
|
||||
remaining_delta = "-:--:--" if remaining is None else str(timedelta(seconds=int(remaining)))
|
||||
return Text.from_markup(f"[progress.elapsed]{elapsed_delta} < [progress.remaining]{remaining_delta}")
|
||||
return Text(f"{elapsed_delta} • {remaining_delta}", style=self.style)
|
||||
|
||||
class BatchesProcessedColumn(ProgressColumn):
|
||||
def __init__(self, style: Union[str, Style]):
|
||||
self.style = style
|
||||
super().__init__()
|
||||
|
||||
def render(self, task) -> RenderableType:
|
||||
return Text.from_markup(f"[magenta] {int(task.completed)}/{task.total}")
|
||||
return Text(f"{int(task.completed)}/{task.total}", style=self.style)
|
||||
|
||||
class ProcessingSpeedColumn(ProgressColumn):
|
||||
def __init__(self, style: Union[str, Style]):
|
||||
self.style = style
|
||||
super().__init__()
|
||||
|
||||
def render(self, task) -> RenderableType:
|
||||
task_speed = f"{task.speed:>.2f}" if task.speed is not None else "0.00"
|
||||
return Text.from_markup(f"[progress.data.speed] {task_speed}it/s")
|
||||
return Text(f"{task_speed}it/s", style=self.style)
|
||||
|
||||
class MetricsTextColumn(ProgressColumn):
|
||||
"""A column containing text."""
|
||||
|
@ -71,19 +86,26 @@ if _RICH_AVAILABLE:
|
|||
metrics = self._trainer.progress_bar_callback.get_metrics(self._trainer, self._pl_module)
|
||||
else:
|
||||
metrics = self._trainer.progress_bar_metrics
|
||||
|
||||
for k, v in metrics.items():
|
||||
_text += f"{k}: {round(v, 3) if isinstance(v, float) else v} "
|
||||
text = Text.from_markup(_text, style=None, justify="left")
|
||||
return text
|
||||
|
||||
|
||||
STYLES: Dict[str, str] = {
|
||||
"train": "red",
|
||||
"sanity_check": "yellow",
|
||||
"validate": "yellow",
|
||||
"test": "yellow",
|
||||
"predict": "yellow",
|
||||
}
|
||||
@dataclass
|
||||
class RichProgressBarTheme:
|
||||
"""Styles to associate to different base components.
|
||||
|
||||
https://rich.readthedocs.io/en/stable/style.html
|
||||
"""
|
||||
|
||||
text_color: str = "white"
|
||||
progress_bar_complete: Union[str, Style] = "#6206E0"
|
||||
progress_bar_finished: Union[str, Style] = "#6206E0"
|
||||
batch_process: str = "white"
|
||||
time: str = "grey54"
|
||||
processing_speed: str = "grey70"
|
||||
|
||||
|
||||
class RichProgressBar(ProgressBarBase):
|
||||
|
@ -104,13 +126,18 @@ class RichProgressBar(ProgressBarBase):
|
|||
|
||||
Args:
|
||||
refresh_rate: the number of updates per second, must be strictly positive
|
||||
theme: Contains styles used to stylize the progress bar.
|
||||
|
||||
Raises:
|
||||
ImportError:
|
||||
If required `rich` package is not installed on the device.
|
||||
"""
|
||||
|
||||
def __init__(self, refresh_rate: float = 1.0):
|
||||
def __init__(
|
||||
self,
|
||||
refresh_rate: float = 1.0,
|
||||
theme: RichProgressBarTheme = RichProgressBarTheme(),
|
||||
) -> None:
|
||||
if not _RICH_AVAILABLE:
|
||||
raise ImportError(
|
||||
"`RichProgressBar` requires `rich` to be installed. Install it by running `pip install rich`."
|
||||
|
@ -126,6 +153,7 @@ class RichProgressBar(ProgressBarBase):
|
|||
self.test_progress_bar_id: Optional[int] = None
|
||||
self.predict_progress_bar_id: Optional[int] = None
|
||||
self.console = Console(record=True)
|
||||
self.theme = theme
|
||||
|
||||
@property
|
||||
def refresh_rate(self) -> int:
|
||||
|
@ -147,31 +175,28 @@ class RichProgressBar(ProgressBarBase):
|
|||
|
||||
@property
|
||||
def sanity_check_description(self) -> str:
|
||||
return "[Validation Sanity Check]"
|
||||
return "Validation Sanity Check"
|
||||
|
||||
@property
|
||||
def validation_description(self) -> str:
|
||||
return "[Validation]"
|
||||
return "Validation"
|
||||
|
||||
@property
|
||||
def test_description(self) -> str:
|
||||
return "[Testing]"
|
||||
return "Testing"
|
||||
|
||||
@property
|
||||
def predict_description(self) -> str:
|
||||
return "[Predicting]"
|
||||
return "Predicting"
|
||||
|
||||
def setup(self, trainer, pl_module, stage):
|
||||
self.progress = Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
BatchesProcessedColumn(),
|
||||
"[",
|
||||
CustomTimeColumn(),
|
||||
ProcessingSpeedColumn(),
|
||||
BarColumn(complete_style=self.theme.progress_bar_complete, finished_style=self.theme.progress_bar_finished),
|
||||
BatchesProcessedColumn(style=self.theme.batch_process),
|
||||
CustomTimeColumn(style=self.theme.time),
|
||||
ProcessingSpeedColumn(style=self.theme.processing_speed),
|
||||
MetricsTextColumn(trainer, pl_module, stage),
|
||||
"]",
|
||||
console=self.console,
|
||||
refresh_per_second=self.refresh_rate,
|
||||
).__enter__()
|
||||
|
@ -179,7 +204,7 @@ class RichProgressBar(ProgressBarBase):
|
|||
def on_sanity_check_start(self, trainer, pl_module):
|
||||
super().on_sanity_check_start(trainer, pl_module)
|
||||
self.val_sanity_progress_bar_id = self.progress.add_task(
|
||||
f"[{STYLES['sanity_check']}]{self.sanity_check_description}",
|
||||
f"[{self.theme.text_color}]{self.sanity_check_description}",
|
||||
total=trainer.num_sanity_val_steps,
|
||||
)
|
||||
|
||||
|
@ -201,7 +226,7 @@ class RichProgressBar(ProgressBarBase):
|
|||
train_description = self._get_train_description(trainer.current_epoch)
|
||||
|
||||
self.main_progress_bar_id = self.progress.add_task(
|
||||
f"[{STYLES['train']}]{train_description}",
|
||||
f"[{self.theme.text_color}]{train_description}",
|
||||
total=total_batches,
|
||||
)
|
||||
|
||||
|
@ -209,7 +234,7 @@ class RichProgressBar(ProgressBarBase):
|
|||
super().on_validation_epoch_start(trainer, pl_module)
|
||||
if self._total_val_batches > 0:
|
||||
self.val_progress_bar_id = self.progress.add_task(
|
||||
f"[{STYLES['validate']}]{self.validation_description}",
|
||||
f"[{self.theme.text_color}]{self.validation_description}",
|
||||
total=self._total_val_batches,
|
||||
)
|
||||
|
||||
|
@ -221,14 +246,14 @@ class RichProgressBar(ProgressBarBase):
|
|||
def on_test_epoch_start(self, trainer, pl_module):
|
||||
super().on_train_epoch_start(trainer, pl_module)
|
||||
self.test_progress_bar_id = self.progress.add_task(
|
||||
f"[{STYLES['test']}]{self.test_description}",
|
||||
f"[{self.theme.text_color}]{self.test_description}",
|
||||
total=self.total_test_batches,
|
||||
)
|
||||
|
||||
def on_predict_epoch_start(self, trainer, pl_module):
|
||||
super().on_predict_epoch_start(trainer, pl_module)
|
||||
self.predict_progress_bar_id = self.progress.add_task(
|
||||
f"[{STYLES['predict']}]{self.predict_description}",
|
||||
f"[{self.theme.text_color}]{self.predict_description}",
|
||||
total=self.total_predict_batches,
|
||||
)
|
||||
|
||||
|
@ -261,7 +286,7 @@ class RichProgressBar(ProgressBarBase):
|
|||
return self.is_enabled and (current % self.refresh_rate == 0 or current == total)
|
||||
|
||||
def _get_train_description(self, current_epoch: int) -> str:
|
||||
train_description = f"[Epoch {current_epoch}]"
|
||||
train_description = f"Epoch {current_epoch}"
|
||||
if len(self.validation_description) > len(train_description):
|
||||
# Padding is required to avoid flickering due of uneven lengths of "Epoch X"
|
||||
# and "Validation" Bar description
|
||||
|
@ -273,3 +298,7 @@ class RichProgressBar(ProgressBarBase):
|
|||
|
||||
def teardown(self, trainer, pl_module, stage):
|
||||
self.progress.__exit__(None, None, None)
|
||||
|
||||
def on_exception(self, trainer, pl_module, exception: BaseException) -> None:
|
||||
if isinstance(exception, KeyboardInterrupt):
|
||||
self.progress.stop()
|
||||
|
|
|
@ -12,11 +12,13 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from unittest import mock
|
||||
from unittest.mock import DEFAULT
|
||||
|
||||
import pytest
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import ProgressBarBase, RichProgressBar
|
||||
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme
|
||||
from pytorch_lightning.utilities.imports import _RICH_AVAILABLE
|
||||
from tests.helpers.boring_model import BoringModel
|
||||
from tests.helpers.runif import RunIf
|
||||
|
@ -24,7 +26,6 @@ from tests.helpers.runif import RunIf
|
|||
|
||||
@RunIf(rich=True)
|
||||
def test_rich_progress_bar_callback():
|
||||
|
||||
trainer = Trainer(callbacks=RichProgressBar())
|
||||
|
||||
progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBarBase)]
|
||||
|
@ -36,7 +37,6 @@ def test_rich_progress_bar_callback():
|
|||
@RunIf(rich=True)
|
||||
@mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update")
|
||||
def test_rich_progress_bar(progress_update, tmpdir):
|
||||
|
||||
model = BoringModel()
|
||||
|
||||
trainer = Trainer(
|
||||
|
@ -58,7 +58,61 @@ def test_rich_progress_bar(progress_update, tmpdir):
|
|||
|
||||
|
||||
def test_rich_progress_bar_import_error():
|
||||
|
||||
if not _RICH_AVAILABLE:
|
||||
with pytest.raises(ImportError, match="`RichProgressBar` requires `rich` to be installed."):
|
||||
Trainer(callbacks=RichProgressBar())
|
||||
|
||||
|
||||
@RunIf(rich=True)
|
||||
def test_rich_progress_bar_custom_theme(tmpdir):
|
||||
"""Test to ensure that custom theme styles are used."""
|
||||
with mock.patch.multiple(
|
||||
"pytorch_lightning.callbacks.progress.rich_progress",
|
||||
BarColumn=DEFAULT,
|
||||
BatchesProcessedColumn=DEFAULT,
|
||||
CustomTimeColumn=DEFAULT,
|
||||
ProcessingSpeedColumn=DEFAULT,
|
||||
) as mocks:
|
||||
|
||||
theme = RichProgressBarTheme()
|
||||
|
||||
progress_bar = RichProgressBar(theme=theme)
|
||||
progress_bar.setup(Trainer(tmpdir), BoringModel(), stage=None)
|
||||
|
||||
assert progress_bar.theme == theme
|
||||
args, kwargs = mocks["BarColumn"].call_args
|
||||
assert kwargs["complete_style"] == theme.progress_bar_complete
|
||||
assert kwargs["finished_style"] == theme.progress_bar_finished
|
||||
|
||||
args, kwargs = mocks["BatchesProcessedColumn"].call_args
|
||||
assert kwargs["style"] == theme.batch_process
|
||||
|
||||
args, kwargs = mocks["CustomTimeColumn"].call_args
|
||||
assert kwargs["style"] == theme.time
|
||||
|
||||
args, kwargs = mocks["ProcessingSpeedColumn"].call_args
|
||||
assert kwargs["style"] == theme.processing_speed
|
||||
|
||||
|
||||
@RunIf(rich=True)
|
||||
def test_rich_progress_bar_keyboard_interrupt(tmpdir):
|
||||
"""Test to ensure that when the user keyboard interrupts, we close the progress bar."""
|
||||
|
||||
class TestModel(BoringModel):
|
||||
def on_train_start(self) -> None:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
model = TestModel()
|
||||
|
||||
with mock.patch(
|
||||
"pytorch_lightning.callbacks.progress.rich_progress.Progress.stop", autospec=True
|
||||
) as mock_progress_stop:
|
||||
progress_bar = RichProgressBar()
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
fast_dev_run=True,
|
||||
callbacks=progress_bar,
|
||||
)
|
||||
|
||||
trainer.fit(model)
|
||||
mock_progress_stop.assert_called_once()
|
||||
|
|
Loading…
Reference in New Issue