Improvements for rich progress bar (#9559)

This commit is contained in:
Sean Naren 2021-09-16 22:11:59 +01:00 committed by GitHub
parent 3aba9d16a8
commit 45200fc858
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 119 additions and 34 deletions

View File

@ -110,7 +110,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a friendly error message when DDP attempts to spawn new distributed processes with rank > 0 ([#9005](https://github.com/PyTorchLightning/pytorch-lightning/pull/9005))
- Added Rich Progress Bar ([#8929](https://github.com/PyTorchLightning/pytorch-lightning/pull/8929))
- Added Rich Progress Bar:
* Rich Progress Bar ([#8929](https://github.com/PyTorchLightning/pytorch-lightning/pull/8929))
* Improvements for rich progress bar ([#9559](https://github.com/PyTorchLightning/pytorch-lightning/pull/9559))
- Added validate logic for precision ([#9080](https://github.com/PyTorchLightning/pytorch-lightning/pull/9080))

View File

@ -11,15 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from datetime import timedelta
from typing import Dict, Optional
from typing import Optional, Union
from pytorch_lightning.callbacks.progress.base import ProgressBarBase
from pytorch_lightning.utilities import _RICH_AVAILABLE
Style = None
if _RICH_AVAILABLE:
from rich.console import Console, RenderableType
from rich.progress import BarColumn, Progress, ProgressColumn, SpinnerColumn, TextColumn
from rich.progress import BarColumn, Progress, ProgressColumn, TextColumn
from rich.style import Style
from rich.text import Text
class CustomTimeColumn(ProgressColumn):
@ -27,21 +30,33 @@ if _RICH_AVAILABLE:
# Only refresh twice a second to prevent jitter
max_refresh = 0.5
def __init__(self, style: Union[str, Style]) -> None:
self.style = style
super().__init__()
def render(self, task) -> Text:
elapsed = task.finished_time if task.finished else task.elapsed
remaining = task.time_remaining
elapsed_delta = "-:--:--" if elapsed is None else str(timedelta(seconds=int(elapsed)))
remaining_delta = "-:--:--" if remaining is None else str(timedelta(seconds=int(remaining)))
return Text.from_markup(f"[progress.elapsed]{elapsed_delta} < [progress.remaining]{remaining_delta}")
return Text(f"{elapsed_delta}{remaining_delta}", style=self.style)
class BatchesProcessedColumn(ProgressColumn):
def __init__(self, style: Union[str, Style]):
self.style = style
super().__init__()
def render(self, task) -> RenderableType:
return Text.from_markup(f"[magenta] {int(task.completed)}/{task.total}")
return Text(f"{int(task.completed)}/{task.total}", style=self.style)
class ProcessingSpeedColumn(ProgressColumn):
def __init__(self, style: Union[str, Style]):
self.style = style
super().__init__()
def render(self, task) -> RenderableType:
task_speed = f"{task.speed:>.2f}" if task.speed is not None else "0.00"
return Text.from_markup(f"[progress.data.speed] {task_speed}it/s")
return Text(f"{task_speed}it/s", style=self.style)
class MetricsTextColumn(ProgressColumn):
"""A column containing text."""
@ -71,19 +86,26 @@ if _RICH_AVAILABLE:
metrics = self._trainer.progress_bar_callback.get_metrics(self._trainer, self._pl_module)
else:
metrics = self._trainer.progress_bar_metrics
for k, v in metrics.items():
_text += f"{k}: {round(v, 3) if isinstance(v, float) else v} "
text = Text.from_markup(_text, style=None, justify="left")
return text
STYLES: Dict[str, str] = {
"train": "red",
"sanity_check": "yellow",
"validate": "yellow",
"test": "yellow",
"predict": "yellow",
}
@dataclass
class RichProgressBarTheme:
"""Styles to associate to different base components.
https://rich.readthedocs.io/en/stable/style.html
"""
text_color: str = "white"
progress_bar_complete: Union[str, Style] = "#6206E0"
progress_bar_finished: Union[str, Style] = "#6206E0"
batch_process: str = "white"
time: str = "grey54"
processing_speed: str = "grey70"
class RichProgressBar(ProgressBarBase):
@ -104,13 +126,18 @@ class RichProgressBar(ProgressBarBase):
Args:
refresh_rate: the number of updates per second, must be strictly positive
theme: Contains styles used to stylize the progress bar.
Raises:
ImportError:
If required `rich` package is not installed on the device.
"""
def __init__(self, refresh_rate: float = 1.0):
def __init__(
self,
refresh_rate: float = 1.0,
theme: RichProgressBarTheme = RichProgressBarTheme(),
) -> None:
if not _RICH_AVAILABLE:
raise ImportError(
"`RichProgressBar` requires `rich` to be installed. Install it by running `pip install rich`."
@ -126,6 +153,7 @@ class RichProgressBar(ProgressBarBase):
self.test_progress_bar_id: Optional[int] = None
self.predict_progress_bar_id: Optional[int] = None
self.console = Console(record=True)
self.theme = theme
@property
def refresh_rate(self) -> int:
@ -147,31 +175,28 @@ class RichProgressBar(ProgressBarBase):
@property
def sanity_check_description(self) -> str:
return "[Validation Sanity Check]"
return "Validation Sanity Check"
@property
def validation_description(self) -> str:
return "[Validation]"
return "Validation"
@property
def test_description(self) -> str:
return "[Testing]"
return "Testing"
@property
def predict_description(self) -> str:
return "[Predicting]"
return "Predicting"
def setup(self, trainer, pl_module, stage):
self.progress = Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
BatchesProcessedColumn(),
"[",
CustomTimeColumn(),
ProcessingSpeedColumn(),
BarColumn(complete_style=self.theme.progress_bar_complete, finished_style=self.theme.progress_bar_finished),
BatchesProcessedColumn(style=self.theme.batch_process),
CustomTimeColumn(style=self.theme.time),
ProcessingSpeedColumn(style=self.theme.processing_speed),
MetricsTextColumn(trainer, pl_module, stage),
"]",
console=self.console,
refresh_per_second=self.refresh_rate,
).__enter__()
@ -179,7 +204,7 @@ class RichProgressBar(ProgressBarBase):
def on_sanity_check_start(self, trainer, pl_module):
super().on_sanity_check_start(trainer, pl_module)
self.val_sanity_progress_bar_id = self.progress.add_task(
f"[{STYLES['sanity_check']}]{self.sanity_check_description}",
f"[{self.theme.text_color}]{self.sanity_check_description}",
total=trainer.num_sanity_val_steps,
)
@ -201,7 +226,7 @@ class RichProgressBar(ProgressBarBase):
train_description = self._get_train_description(trainer.current_epoch)
self.main_progress_bar_id = self.progress.add_task(
f"[{STYLES['train']}]{train_description}",
f"[{self.theme.text_color}]{train_description}",
total=total_batches,
)
@ -209,7 +234,7 @@ class RichProgressBar(ProgressBarBase):
super().on_validation_epoch_start(trainer, pl_module)
if self._total_val_batches > 0:
self.val_progress_bar_id = self.progress.add_task(
f"[{STYLES['validate']}]{self.validation_description}",
f"[{self.theme.text_color}]{self.validation_description}",
total=self._total_val_batches,
)
@ -221,14 +246,14 @@ class RichProgressBar(ProgressBarBase):
def on_test_epoch_start(self, trainer, pl_module):
super().on_train_epoch_start(trainer, pl_module)
self.test_progress_bar_id = self.progress.add_task(
f"[{STYLES['test']}]{self.test_description}",
f"[{self.theme.text_color}]{self.test_description}",
total=self.total_test_batches,
)
def on_predict_epoch_start(self, trainer, pl_module):
super().on_predict_epoch_start(trainer, pl_module)
self.predict_progress_bar_id = self.progress.add_task(
f"[{STYLES['predict']}]{self.predict_description}",
f"[{self.theme.text_color}]{self.predict_description}",
total=self.total_predict_batches,
)
@ -261,7 +286,7 @@ class RichProgressBar(ProgressBarBase):
return self.is_enabled and (current % self.refresh_rate == 0 or current == total)
def _get_train_description(self, current_epoch: int) -> str:
train_description = f"[Epoch {current_epoch}]"
train_description = f"Epoch {current_epoch}"
if len(self.validation_description) > len(train_description):
# Padding is required to avoid flickering due of uneven lengths of "Epoch X"
# and "Validation" Bar description
@ -273,3 +298,7 @@ class RichProgressBar(ProgressBarBase):
def teardown(self, trainer, pl_module, stage):
self.progress.__exit__(None, None, None)
def on_exception(self, trainer, pl_module, exception: BaseException) -> None:
if isinstance(exception, KeyboardInterrupt):
self.progress.stop()

View File

@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from unittest.mock import DEFAULT
import pytest
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ProgressBarBase, RichProgressBar
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme
from pytorch_lightning.utilities.imports import _RICH_AVAILABLE
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf
@ -24,7 +26,6 @@ from tests.helpers.runif import RunIf
@RunIf(rich=True)
def test_rich_progress_bar_callback():
trainer = Trainer(callbacks=RichProgressBar())
progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBarBase)]
@ -36,7 +37,6 @@ def test_rich_progress_bar_callback():
@RunIf(rich=True)
@mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update")
def test_rich_progress_bar(progress_update, tmpdir):
model = BoringModel()
trainer = Trainer(
@ -58,7 +58,61 @@ def test_rich_progress_bar(progress_update, tmpdir):
def test_rich_progress_bar_import_error():
if not _RICH_AVAILABLE:
with pytest.raises(ImportError, match="`RichProgressBar` requires `rich` to be installed."):
Trainer(callbacks=RichProgressBar())
@RunIf(rich=True)
def test_rich_progress_bar_custom_theme(tmpdir):
"""Test to ensure that custom theme styles are used."""
with mock.patch.multiple(
"pytorch_lightning.callbacks.progress.rich_progress",
BarColumn=DEFAULT,
BatchesProcessedColumn=DEFAULT,
CustomTimeColumn=DEFAULT,
ProcessingSpeedColumn=DEFAULT,
) as mocks:
theme = RichProgressBarTheme()
progress_bar = RichProgressBar(theme=theme)
progress_bar.setup(Trainer(tmpdir), BoringModel(), stage=None)
assert progress_bar.theme == theme
args, kwargs = mocks["BarColumn"].call_args
assert kwargs["complete_style"] == theme.progress_bar_complete
assert kwargs["finished_style"] == theme.progress_bar_finished
args, kwargs = mocks["BatchesProcessedColumn"].call_args
assert kwargs["style"] == theme.batch_process
args, kwargs = mocks["CustomTimeColumn"].call_args
assert kwargs["style"] == theme.time
args, kwargs = mocks["ProcessingSpeedColumn"].call_args
assert kwargs["style"] == theme.processing_speed
@RunIf(rich=True)
def test_rich_progress_bar_keyboard_interrupt(tmpdir):
"""Test to ensure that when the user keyboard interrupts, we close the progress bar."""
class TestModel(BoringModel):
def on_train_start(self) -> None:
raise KeyboardInterrupt
model = TestModel()
with mock.patch(
"pytorch_lightning.callbacks.progress.rich_progress.Progress.stop", autospec=True
) as mock_progress_stop:
progress_bar = RichProgressBar()
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
callbacks=progress_bar,
)
trainer.fit(model)
mock_progress_stop.assert_called_once()