Improvements for rich progress (#9579)
This commit is contained in:
parent
3f7872d93a
commit
fd2e778cbc
|
@ -125,7 +125,7 @@ class RichProgressBar(ProgressBarBase):
|
|||
trainer = Trainer(callbacks=RichProgressBar())
|
||||
|
||||
Args:
|
||||
refresh_rate: the number of updates per second, must be strictly positive
|
||||
refresh_rate_per_second: the number of updates per second. If refresh_rate is 0, progress bar is disabled.
|
||||
theme: Contains styles used to stylize the progress bar.
|
||||
|
||||
Raises:
|
||||
|
@ -135,7 +135,7 @@ class RichProgressBar(ProgressBarBase):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
refresh_rate: float = 1.0,
|
||||
refresh_rate_per_second: int = 10,
|
||||
theme: RichProgressBarTheme = RichProgressBarTheme(),
|
||||
) -> None:
|
||||
if not _RICH_AVAILABLE:
|
||||
|
@ -143,7 +143,7 @@ class RichProgressBar(ProgressBarBase):
|
|||
"`RichProgressBar` requires `rich` to be installed. Install it by running `pip install rich`."
|
||||
)
|
||||
super().__init__()
|
||||
self._refresh_rate: float = refresh_rate
|
||||
self._refresh_rate_per_second: int = refresh_rate_per_second
|
||||
self._enabled: bool = True
|
||||
self._total_val_batches: int = 0
|
||||
self.progress: Progress = None
|
||||
|
@ -156,12 +156,17 @@ class RichProgressBar(ProgressBarBase):
|
|||
self.theme = theme
|
||||
|
||||
@property
|
||||
def refresh_rate(self) -> int:
|
||||
return self._refresh_rate
|
||||
def refresh_rate_per_second(self) -> float:
|
||||
"""Refresh rate for Rich Progress.
|
||||
|
||||
Returns: Refresh rate for Progress Bar.
|
||||
Return 1 if not enabled, as a positive integer is required (ignored by Rich Progress).
|
||||
"""
|
||||
return self._refresh_rate_per_second if self._refresh_rate_per_second > 0 else 1
|
||||
|
||||
@property
|
||||
def is_enabled(self) -> bool:
|
||||
return self._enabled and self.refresh_rate > 0
|
||||
return self._enabled and self._refresh_rate_per_second > 0
|
||||
|
||||
@property
|
||||
def is_disabled(self) -> bool:
|
||||
|
@ -189,7 +194,7 @@ class RichProgressBar(ProgressBarBase):
|
|||
def predict_description(self) -> str:
|
||||
return "Predicting"
|
||||
|
||||
def setup(self, trainer, pl_module, stage):
|
||||
def setup(self, trainer, pl_module, stage: Optional[str] = None):
|
||||
self.progress = Progress(
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(complete_style=self.theme.progress_bar_complete, finished_style=self.theme.progress_bar_finished),
|
||||
|
@ -198,8 +203,10 @@ class RichProgressBar(ProgressBarBase):
|
|||
ProcessingSpeedColumn(style=self.theme.processing_speed),
|
||||
MetricsTextColumn(trainer, pl_module, stage),
|
||||
console=self.console,
|
||||
refresh_per_second=self.refresh_rate,
|
||||
).__enter__()
|
||||
refresh_per_second=self.refresh_rate_per_second,
|
||||
disable=self.is_disabled,
|
||||
)
|
||||
self.progress.start()
|
||||
|
||||
def on_sanity_check_start(self, trainer, pl_module):
|
||||
super().on_sanity_check_start(trainer, pl_module)
|
||||
|
@ -259,31 +266,23 @@ class RichProgressBar(ProgressBarBase):
|
|||
|
||||
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
||||
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
|
||||
if self._should_update(self.train_batch_idx, self.total_train_batches + self.total_val_batches):
|
||||
self.progress.update(self.main_progress_bar_id, advance=1.0)
|
||||
self.progress.update(self.main_progress_bar_id, advance=1.0)
|
||||
|
||||
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
||||
super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
|
||||
if trainer.sanity_checking:
|
||||
self.progress.update(self.val_sanity_progress_bar_id, advance=1.0)
|
||||
elif self.val_progress_bar_id and self._should_update(
|
||||
self.val_batch_idx, self.total_train_batches + self.total_val_batches
|
||||
):
|
||||
elif self.val_progress_bar_id:
|
||||
self.progress.update(self.main_progress_bar_id, advance=1.0)
|
||||
self.progress.update(self.val_progress_bar_id, advance=1.0)
|
||||
|
||||
def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
||||
super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
|
||||
if self._should_update(self.test_batch_idx, self.total_test_batches):
|
||||
self.progress.update(self.test_progress_bar_id, advance=1.0)
|
||||
self.progress.update(self.test_progress_bar_id, advance=1.0)
|
||||
|
||||
def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
||||
super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
|
||||
if self._should_update(self.predict_batch_idx, self.total_predict_batches):
|
||||
self.progress.update(self.predict_progress_bar_id, advance=1.0)
|
||||
|
||||
def _should_update(self, current, total) -> bool:
|
||||
return self.is_enabled and (current % self.refresh_rate == 0 or current == total)
|
||||
self.progress.update(self.predict_progress_bar_id, advance=1.0)
|
||||
|
||||
def _get_train_description(self, current_epoch: int) -> str:
|
||||
train_description = f"Epoch {current_epoch}"
|
||||
|
@ -296,8 +295,8 @@ class RichProgressBar(ProgressBarBase):
|
|||
train_description += " "
|
||||
return train_description
|
||||
|
||||
def teardown(self, trainer, pl_module, stage):
|
||||
self.progress.__exit__(None, None, None)
|
||||
def teardown(self, trainer, pl_module, stage: Optional[str] = None) -> None:
|
||||
self.progress.stop()
|
||||
|
||||
def on_exception(self, trainer, pl_module, exception: BaseException) -> None:
|
||||
if isinstance(exception, KeyboardInterrupt):
|
||||
|
|
|
@ -34,6 +34,16 @@ def test_rich_progress_bar_callback():
|
|||
assert isinstance(trainer.progress_bar_callback, RichProgressBar)
|
||||
|
||||
|
||||
@RunIf(rich=True)
|
||||
def test_rich_progress_bar_refresh_rate():
|
||||
progress_bar = RichProgressBar(refresh_rate_per_second=1)
|
||||
assert progress_bar.is_enabled
|
||||
assert not progress_bar.is_disabled
|
||||
progress_bar = RichProgressBar(refresh_rate_per_second=0)
|
||||
assert not progress_bar.is_enabled
|
||||
assert progress_bar.is_disabled
|
||||
|
||||
|
||||
@RunIf(rich=True)
|
||||
@mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update")
|
||||
def test_rich_progress_bar(progress_update, tmpdir):
|
||||
|
|
Loading…
Reference in New Issue