270 lines
11 KiB
Python
270 lines
11 KiB
Python
|
# Copyright The PyTorch Lightning team.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
from datetime import timedelta
|
||
|
from typing import Dict, Optional
|
||
|
|
||
|
from pytorch_lightning.callbacks.progress.base import ProgressBarBase
|
||
|
from pytorch_lightning.utilities import _RICH_AVAILABLE
|
||
|
|
||
|
if _RICH_AVAILABLE:
|
||
|
from rich.console import Console, RenderableType
|
||
|
from rich.progress import BarColumn, Progress, ProgressColumn, SpinnerColumn, TextColumn
|
||
|
from rich.text import Text
|
||
|
|
||
|
class CustomTimeColumn(ProgressColumn):
    """Progress column rendering a task's elapsed and remaining time side by side."""

    # Only refresh twice a second to prevent jitter
    max_refresh = 0.5

    def render(self, task) -> Text:
        """Return ``<elapsed> < <remaining>`` for ``task``.

        A finished task reports its ``finished_time``; otherwise its running
        ``elapsed`` value is used. Durations that are ``None`` are shown as
        ``-:--:--``; known durations are truncated to whole seconds and
        formatted with ``str(timedelta(...))``.
        """

        def _as_clock(seconds) -> str:
            # Unknown durations get a fixed-width placeholder.
            if seconds is None:
                return "-:--:--"
            return str(timedelta(seconds=int(seconds)))

        if task.finished:
            spent = task.finished_time
        else:
            spent = task.elapsed
        left = task.time_remaining

        spent_text = _as_clock(spent)
        left_text = _as_clock(left)
        return Text.from_markup(f"[progress.elapsed]{spent_text} < [progress.remaining]{left_text}")
|
||
|
|
||
|
class BatchesProcessedColumn(ProgressColumn):
    """Progress column rendering how many steps of a task are done, as ``completed/total``."""

    def render(self, task) -> RenderableType:
        """Return the magenta ``completed/total`` counter for ``task``.

        ``task.completed`` is truncated to an integer; ``task.total`` is shown as-is.
        """
        done = int(task.completed)
        total = task.total
        return Text.from_markup(f"[magenta] {done}/{total}")
|
||
|
|
||
|
class ProcessingSpeedColumn(ProgressColumn):
    """Progress column rendering a task's speed in iterations per second."""

    def render(self, task) -> RenderableType:
        """Return the speed of ``task`` with two decimals, or ``0.00`` while it is ``None``."""
        speed = task.speed
        if speed is None:
            task_speed = "0.00"
        else:
            task_speed = f"{speed:>.2f}"
        return Text.from_markup(f"[progress.data.speed] {task_speed}it/s")
|
||
|
|
||
|
class MetricsTextColumn(ProgressColumn):
    """A column containing the trainer's progress-bar metrics as text.

    The column remembers the last text rendered for earlier training tasks so
    that only the most recently seen training task shows live metrics.
    """

    def __init__(self, trainer, stage):
        """Store the trainer and stage and initialise the per-task bookkeeping.

        Args:
            trainer: object exposing ``sanity_checking``, ``training`` and ``progress_bar_dict``.
            stage: stage name; metrics are only rendered when it equals ``"fit"``.
        """
        self._trainer = trainer
        self._stage = stage
        # task id -> text to keep showing for a task that is no longer the current one
        self._tasks = {}
        self._current_task_id = 0
        super().__init__()

    def render(self, task) -> Text:
        """Return the metrics text for ``task``.

        Returns an empty string outside the ``"fit"`` stage or while the
        trainer is sanity checking. For a training task other than the current
        one, the text stored for it is returned unchanged.
        """
        trainer = self._trainer
        if self._stage != "fit" or trainer.sanity_checking:
            return ""

        first_time_seen = trainer.training and task.id not in self._tasks
        if first_time_seen:
            self._tasks[task.id] = "None"
            # NOTE(review): `_renderable_cache` is not defined in this file; presumably it is the
            # per-task cache kept by rich's `ProgressColumn`, whose entry [1] is the last renderable.
            if self._renderable_cache:
                previous_id = self._current_task_id
                self._tasks[previous_id] = self._renderable_cache[previous_id][1]
            self._current_task_id = task.id

        if trainer.training and task.id != self._current_task_id:
            return self._tasks[task.id]

        # Floats are rounded to 3 decimals; every entry is followed by a single space.
        metrics = "".join(
            f"{name}: {round(value, 3) if isinstance(value, float) else value} "
            for name, value in trainer.progress_bar_dict.items()
        )
        return Text.from_markup(metrics, style=None, justify="left")
|
||
|
|
||
|
|
||
|
# Colour markup applied to the description of each progress bar, keyed by stage.
STYLES: Dict[str, str] = dict(
    train="red",
    sanity_check="yellow",
    validate="yellow",
    test="yellow",
    predict="yellow",
)
|
||
|
|
||
|
|
||
|
class RichProgressBar(ProgressBarBase):
    """
    Create a progress bar with `rich text formatting <https://github.com/willmcgugan/rich>`_.

    Install it with pip:

    .. code-block:: bash

        pip install rich

    .. code-block:: python

        from pytorch_lightning import Trainer
        from pytorch_lightning.callbacks import RichProgressBar

        trainer = Trainer(callbacks=RichProgressBar())

    Args:
        refresh_rate: the number of updates per second, must be strictly positive.
            It is passed to rich as ``refresh_per_second`` and is also used as the
            batch interval between bar updates. A value of ``0`` or less disables the bar.

    Raises:
        ImportError:
            If required `rich` package is not installed on the device.
    """

    def __init__(self, refresh_rate: float = 1.0):
        if not _RICH_AVAILABLE:
            raise ImportError(
                "`RichProgressBar` requires `rich` to be installed. Install it by running `pip install rich`."
            )
        super().__init__()
        self._refresh_rate: float = refresh_rate
        self._enabled: bool = True
        self._total_val_batches: int = 0
        # Created in `setup`; stays `None` until then.
        self.progress: Optional["Progress"] = None
        # Task ids returned by `Progress.add_task`; `None` means the bar was never created.
        self.val_sanity_progress_bar_id: Optional[int] = None
        self.main_progress_bar_id: Optional[int] = None
        self.val_progress_bar_id: Optional[int] = None
        self.test_progress_bar_id: Optional[int] = None
        self.predict_progress_bar_id: Optional[int] = None
        self.console = Console(record=True)

    @property
    def refresh_rate(self) -> float:
        """The refresh rate given at construction."""
        return self._refresh_rate

    @property
    def is_enabled(self) -> bool:
        """Whether the bar is enabled and has a strictly positive refresh rate."""
        return self._enabled and self.refresh_rate > 0

    @property
    def is_disabled(self) -> bool:
        """The negation of :attr:`is_enabled`."""
        return not self.is_enabled

    def disable(self) -> None:
        """Stop the bars from being advanced by the batch-end hooks."""
        self._enabled = False

    def enable(self) -> None:
        """Allow the bars to be advanced again."""
        self._enabled = True

    @property
    def sanity_check_description(self) -> str:
        """Label of the sanity-check bar."""
        return "[Validation Sanity Check]"

    @property
    def validation_description(self) -> str:
        """Label of the validation bar."""
        return "[Validation]"

    @property
    def test_description(self) -> str:
        """Label of the test bar."""
        return "[Testing]"

    @property
    def predict_description(self) -> str:
        """Label of the predict bar."""
        return "[Predicting]"

    def setup(self, trainer, pl_module, stage):
        """Build the rich ``Progress`` display with its columns and start it."""
        self.progress = Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            BatchesProcessedColumn(),
            "[",
            CustomTimeColumn(),
            ProcessingSpeedColumn(),
            MetricsTextColumn(trainer, stage),
            "]",
            console=self.console,
            refresh_per_second=self.refresh_rate,
        )
        # Same effect as entering the context manager, without calling the dunder directly.
        self.progress.start()

    def on_sanity_check_start(self, trainer, pl_module):
        """Add the sanity-check bar sized to ``trainer.num_sanity_val_steps``."""
        super().on_sanity_check_start(trainer, pl_module)
        self.val_sanity_progress_bar_id = self.progress.add_task(
            f"[{STYLES['sanity_check']}]{self.sanity_check_description}",
            total=trainer.num_sanity_val_steps,
        )

    def on_sanity_check_end(self, trainer, pl_module):
        """Hide the sanity-check bar."""
        super().on_sanity_check_end(trainer, pl_module)
        self.progress.update(self.val_sanity_progress_bar_id, visible=False)

    def on_train_epoch_start(self, trainer, pl_module):
        """Add the main bar for the epoch, sized to training plus validation batches."""
        super().on_train_epoch_start(trainer, pl_module)
        total_train_batches = self.total_train_batches
        self._total_val_batches = self.total_val_batches
        if total_train_batches != float("inf"):
            # val can be checked multiple times per epoch
            val_checks_per_epoch = total_train_batches // trainer.val_check_batch
            self._total_val_batches = self._total_val_batches * val_checks_per_epoch

        total_batches = total_train_batches + self._total_val_batches

        train_description = self._get_train_description(trainer.current_epoch)

        self.main_progress_bar_id = self.progress.add_task(
            f"[{STYLES['train']}]{train_description}",
            total=total_batches,
        )

    def on_validation_epoch_start(self, trainer, pl_module):
        """Add the validation bar when there are validation batches to show."""
        super().on_validation_epoch_start(trainer, pl_module)
        if self._total_val_batches > 0:
            self.val_progress_bar_id = self.progress.add_task(
                f"[{STYLES['validate']}]{self.validation_description}",
                total=self._total_val_batches,
            )

    def on_validation_epoch_end(self, trainer, pl_module):
        """Hide the validation bar if one was created."""
        super().on_validation_epoch_end(trainer, pl_module)
        if self.val_progress_bar_id is not None:
            self.progress.update(self.val_progress_bar_id, visible=False)

    def on_test_epoch_start(self, trainer, pl_module):
        """Add the test bar sized to ``total_test_batches``."""
        # Chain to the matching test hook; the original called the *train* epoch hook here.
        super().on_test_epoch_start(trainer, pl_module)
        self.test_progress_bar_id = self.progress.add_task(
            f"[{STYLES['test']}]{self.test_description}",
            total=self.total_test_batches,
        )

    def on_predict_epoch_start(self, trainer, pl_module):
        """Add the predict bar sized to ``total_predict_batches``."""
        super().on_predict_epoch_start(trainer, pl_module)
        self.predict_progress_bar_id = self.progress.add_task(
            f"[{STYLES['predict']}]{self.predict_description}",
            total=self.total_predict_batches,
        )

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        """Advance the main bar by one step when an update is due."""
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        # NOTE(review): the bar advances by exactly one per update, so with a refresh rate above 1
        # the batches skipped by `_should_update` are never counted — confirm this is intended.
        if self._should_update(self.train_batch_idx, self.total_train_batches + self.total_val_batches):
            self.progress.update(self.main_progress_bar_id, advance=1.0)

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        """Advance the sanity-check bar, or both the main and validation bars."""
        super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        if trainer.sanity_checking:
            self.progress.update(self.val_sanity_progress_bar_id, advance=1.0)
        # Compare against None explicitly: a task id of 0 is valid but falsy.
        elif self.val_progress_bar_id is not None and self._should_update(
            self.val_batch_idx, self.total_train_batches + self.total_val_batches
        ):
            self.progress.update(self.main_progress_bar_id, advance=1.0)
            self.progress.update(self.val_progress_bar_id, advance=1.0)

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        """Advance the test bar by one step when an update is due."""
        super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        if self._should_update(self.test_batch_idx, self.total_test_batches):
            self.progress.update(self.test_progress_bar_id, advance=1.0)

    def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        """Advance the predict bar by one step when an update is due."""
        super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        if self._should_update(self.predict_batch_idx, self.total_predict_batches):
            self.progress.update(self.predict_progress_bar_id, advance=1.0)

    def _should_update(self, current, total) -> bool:
        """Return True when the bar is enabled and ``current`` is a refresh step or the last one."""
        return self.is_enabled and (current % self.refresh_rate == 0 or current == total)

    def _get_train_description(self, current_epoch: int) -> str:
        """Return the ``[Epoch N]`` label, space-padded when it is shorter than the validation label."""
        train_description = f"[Epoch {current_epoch}]"
        if len(self.validation_description) > len(train_description):
            # Padding is required to avoid flickering due to uneven lengths of "Epoch X"
            # and "Validation" Bar description
            num_digits = len(str(current_epoch))
            required_padding = (len(self.validation_description) - len(train_description) + 1) - num_digits
            # A non-positive count yields an empty string, matching the original range() loop.
            train_description += " " * required_padding
        return train_description

    def teardown(self, trainer, pl_module, stage):
        """Stop the rich display if one was started."""
        if self.progress is not None:
            self.progress.stop()
|