lightning/pytorch_lightning/callbacks/progress/rich_progress.py

270 lines
11 KiB
Python

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datetime import timedelta
from typing import Dict, Optional
from pytorch_lightning.callbacks.progress.base import ProgressBarBase
from pytorch_lightning.utilities import _RICH_AVAILABLE
if _RICH_AVAILABLE:
from rich.console import Console, RenderableType
from rich.progress import BarColumn, Progress, ProgressColumn, SpinnerColumn, TextColumn
from rich.text import Text
class CustomTimeColumn(ProgressColumn):
# Only refresh twice a second to prevent jitter
max_refresh = 0.5
def render(self, task) -> Text:
elapsed = task.finished_time if task.finished else task.elapsed
remaining = task.time_remaining
elapsed_delta = "-:--:--" if elapsed is None else str(timedelta(seconds=int(elapsed)))
remaining_delta = "-:--:--" if remaining is None else str(timedelta(seconds=int(remaining)))
return Text.from_markup(f"[progress.elapsed]{elapsed_delta} < [progress.remaining]{remaining_delta}")
class BatchesProcessedColumn(ProgressColumn):
def render(self, task) -> RenderableType:
return Text.from_markup(f"[magenta] {int(task.completed)}/{task.total}")
class ProcessingSpeedColumn(ProgressColumn):
def render(self, task) -> RenderableType:
task_speed = f"{task.speed:>.2f}" if task.speed is not None else "0.00"
return Text.from_markup(f"[progress.data.speed] {task_speed}it/s")
class MetricsTextColumn(ProgressColumn):
"""A column containing text."""
def __init__(self, trainer, stage):
self._trainer = trainer
self._stage = stage
self._tasks = {}
self._current_task_id = 0
super().__init__()
def render(self, task) -> Text:
if self._stage != "fit" or self._trainer.sanity_checking:
return ""
if self._trainer.training and task.id not in self._tasks:
self._tasks[task.id] = "None"
if self._renderable_cache:
self._tasks[self._current_task_id] = self._renderable_cache[self._current_task_id][1]
self._current_task_id = task.id
if self._trainer.training and task.id != self._current_task_id:
return self._tasks[task.id]
_text = ""
for k, v in self._trainer.progress_bar_dict.items():
_text += f"{k}: {round(v, 3) if isinstance(v, float) else v} "
text = Text.from_markup(_text, style=None, justify="left")
return text
STYLES: Dict[str, str] = {
"train": "red",
"sanity_check": "yellow",
"validate": "yellow",
"test": "yellow",
"predict": "yellow",
}
class RichProgressBar(ProgressBarBase):
"""
Create a progress bar with `rich text formatting <https://github.com/willmcgugan/rich>`_.
Install it with pip:
.. code-block:: bash
pip install rich
.. code-block:: python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import RichProgressBar
trainer = Trainer(callbacks=RichProgressBar())
Args:
refresh_rate: the number of updates per second, must be strictly positive
Raises:
ImportError:
If required `rich` package is not installed on the device.
"""
def __init__(self, refresh_rate: float = 1.0):
if not _RICH_AVAILABLE:
raise ImportError(
"`RichProgressBar` requires `rich` to be installed. Install it by running `pip install rich`."
)
super().__init__()
self._refresh_rate: float = refresh_rate
self._enabled: bool = True
self._total_val_batches: int = 0
self.progress: Progress = None
self.val_sanity_progress_bar_id: Optional[int] = None
self.main_progress_bar_id: Optional[int] = None
self.val_progress_bar_id: Optional[int] = None
self.test_progress_bar_id: Optional[int] = None
self.predict_progress_bar_id: Optional[int] = None
self.console = Console(record=True)
@property
def refresh_rate(self) -> int:
return self._refresh_rate
@property
def is_enabled(self) -> bool:
return self._enabled and self.refresh_rate > 0
@property
def is_disabled(self) -> bool:
return not self.is_enabled
def disable(self) -> None:
self._enabled = False
def enable(self) -> None:
self._enabled = True
@property
def sanity_check_description(self) -> str:
return "[Validation Sanity Check]"
@property
def validation_description(self) -> str:
return "[Validation]"
@property
def test_description(self) -> str:
return "[Testing]"
@property
def predict_description(self) -> str:
return "[Predicting]"
def setup(self, trainer, pl_module, stage):
self.progress = Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
BatchesProcessedColumn(),
"[",
CustomTimeColumn(),
ProcessingSpeedColumn(),
MetricsTextColumn(trainer, stage),
"]",
console=self.console,
refresh_per_second=self.refresh_rate,
).__enter__()
def on_sanity_check_start(self, trainer, pl_module):
super().on_sanity_check_start(trainer, pl_module)
self.val_sanity_progress_bar_id = self.progress.add_task(
f"[{STYLES['sanity_check']}]{self.sanity_check_description}",
total=trainer.num_sanity_val_steps,
)
def on_sanity_check_end(self, trainer, pl_module):
super().on_sanity_check_end(trainer, pl_module)
self.progress.update(self.val_sanity_progress_bar_id, visible=False)
def on_train_epoch_start(self, trainer, pl_module):
super().on_train_epoch_start(trainer, pl_module)
total_train_batches = self.total_train_batches
self._total_val_batches = self.total_val_batches
if total_train_batches != float("inf"):
# val can be checked multiple times per epoch
val_checks_per_epoch = total_train_batches // trainer.val_check_batch
self._total_val_batches = self._total_val_batches * val_checks_per_epoch
total_batches = total_train_batches + self._total_val_batches
train_description = self._get_train_description(trainer.current_epoch)
self.main_progress_bar_id = self.progress.add_task(
f"[{STYLES['train']}]{train_description}",
total=total_batches,
)
def on_validation_epoch_start(self, trainer, pl_module):
super().on_validation_epoch_start(trainer, pl_module)
if self._total_val_batches > 0:
self.val_progress_bar_id = self.progress.add_task(
f"[{STYLES['validate']}]{self.validation_description}",
total=self._total_val_batches,
)
def on_validation_epoch_end(self, trainer, pl_module):
super().on_validation_epoch_end(trainer, pl_module)
if self.val_progress_bar_id is not None:
self.progress.update(self.val_progress_bar_id, visible=False)
def on_test_epoch_start(self, trainer, pl_module):
super().on_train_epoch_start(trainer, pl_module)
self.test_progress_bar_id = self.progress.add_task(
f"[{STYLES['test']}]{self.test_description}",
total=self.total_test_batches,
)
def on_predict_epoch_start(self, trainer, pl_module):
super().on_predict_epoch_start(trainer, pl_module)
self.predict_progress_bar_id = self.progress.add_task(
f"[{STYLES['predict']}]{self.predict_description}",
total=self.total_predict_batches,
)
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
if self._should_update(self.train_batch_idx, self.total_train_batches + self.total_val_batches):
self.progress.update(self.main_progress_bar_id, advance=1.0)
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
if trainer.sanity_checking:
self.progress.update(self.val_sanity_progress_bar_id, advance=1.0)
elif self.val_progress_bar_id and self._should_update(
self.val_batch_idx, self.total_train_batches + self.total_val_batches
):
self.progress.update(self.main_progress_bar_id, advance=1.0)
self.progress.update(self.val_progress_bar_id, advance=1.0)
def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
if self._should_update(self.test_batch_idx, self.total_test_batches):
self.progress.update(self.test_progress_bar_id, advance=1.0)
def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
if self._should_update(self.predict_batch_idx, self.total_predict_batches):
self.progress.update(self.predict_progress_bar_id, advance=1.0)
def _should_update(self, current, total) -> bool:
return self.is_enabled and (current % self.refresh_rate == 0 or current == total)
def _get_train_description(self, current_epoch: int) -> str:
train_description = f"[Epoch {current_epoch}]"
if len(self.validation_description) > len(train_description):
# Padding is required to avoid flickering due of uneven lengths of "Epoch X"
# and "Validation" Bar description
num_digits = len(str(current_epoch))
required_padding = (len(self.validation_description) - len(train_description) + 1) - num_digits
for _ in range(required_padding):
train_description += " "
return train_description
def teardown(self, trainer, pl_module, stage):
self.progress.__exit__(None, None, None)