Add `refresh_rate` to RichProgressBar (#10497)

Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
Kaushik B 2021-11-19 11:29:57 +05:30 committed by GitHub
parent 7d3ad5b76e
commit 137b62d80d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 63 additions and 26 deletions

View File

@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Raise `MisconfigurationException` when `enable_progress_bar=False` and a progress bar instance has been passed in the callback list ([#10520](https://github.com/PyTorchLightning/pytorch-lightning/issues/10520))
- Renamed `refresh_rate_per_second` parameter to `referesh_rate` for `RichProgressBar` signature ([#10497](https://github.com/PyTorchLightning/pytorch-lightning/pull/10497))
- Moved ownership of the `PrecisionPlugin` into `TrainingTypePlugin` and updated all references ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570))

View File

@ -206,7 +206,8 @@ class RichProgressBar(ProgressBarBase):
trainer = Trainer(callbacks=RichProgressBar())
Args:
refresh_rate_per_second: the number of updates per second. If refresh_rate is 0, progress bar is disabled.
refresh_rate: Determines at which rate (in number of batches) the progress bars get updated.
Set it to ``0`` to disable the display.
leave: Leaves the finished progress bar in the terminal at the end of the epoch. Default: False
theme: Contains styles used to stylize the progress bar.
@ -222,7 +223,7 @@ class RichProgressBar(ProgressBarBase):
def __init__(
self,
refresh_rate_per_second: int = 10,
refresh_rate: int = 1,
leave: bool = False,
theme: RichProgressBarTheme = RichProgressBarTheme(),
) -> None:
@ -231,7 +232,7 @@ class RichProgressBar(ProgressBarBase):
"`RichProgressBar` requires `rich` to be installed. Install it by running `pip install -U rich`."
)
super().__init__()
self._refresh_rate_per_second: int = refresh_rate_per_second
self._refresh_rate: int = refresh_rate
self._leave: bool = leave
self._enabled: bool = True
self.progress: Optional[Progress] = None
@ -242,17 +243,12 @@ class RichProgressBar(ProgressBarBase):
self.theme = theme
@property
def refresh_rate_per_second(self) -> float:
"""Refresh rate for Rich Progress.
Returns: Refresh rate for Progress Bar.
Return 1 if not enabled, as a positive integer is required (ignored by Rich Progress).
"""
return self._refresh_rate_per_second if self._refresh_rate_per_second > 0 else 1
def refresh_rate(self) -> float:
return self._refresh_rate
@property
def is_enabled(self) -> bool:
return self._enabled and self._refresh_rate_per_second > 0
return self._enabled and self.refresh_rate > 0
@property
def is_disabled(self) -> bool:
@ -289,7 +285,7 @@ class RichProgressBar(ProgressBarBase):
self.progress = CustomProgress(
*self.configure_columns(trainer),
self._metric_component,
refresh_per_second=self.refresh_rate_per_second,
auto_refresh=False,
disable=self.is_disabled,
console=self._console,
)
@ -297,6 +293,10 @@ class RichProgressBar(ProgressBarBase):
# progress has started
self._progress_stopped = False
def refresh(self) -> None:
if self.progress:
self.progress.refresh()
def on_train_start(self, trainer, pl_module):
super().on_train_start(trainer, pl_module)
self._init_progress(trainer)
@ -328,10 +328,12 @@ class RichProgressBar(ProgressBarBase):
super().on_sanity_check_start(trainer, pl_module)
self._init_progress(trainer)
self.val_sanity_progress_bar_id = self._add_task(trainer.num_sanity_val_steps, self.sanity_check_description)
self.refresh()
def on_sanity_check_end(self, trainer, pl_module):
super().on_sanity_check_end(trainer, pl_module)
self._update(self.val_sanity_progress_bar_id, visible=False)
self.refresh()
def on_train_epoch_start(self, trainer, pl_module):
super().on_train_epoch_start(trainer, pl_module)
@ -354,6 +356,7 @@ class RichProgressBar(ProgressBarBase):
self.progress.reset(
self.main_progress_bar_id, total=total_batches, description=train_description, visible=True
)
self.refresh()
def on_validation_epoch_start(self, trainer, pl_module):
super().on_validation_epoch_start(trainer, pl_module)
@ -364,6 +367,7 @@ class RichProgressBar(ProgressBarBase):
val_checks_per_epoch = self.total_train_batches // trainer.val_check_batch
total_val_batches = self.total_val_batches * val_checks_per_epoch
self.val_progress_bar_id = self._add_task(total_val_batches, self.validation_description, visible=False)
self.refresh()
def _add_task(self, total_batches: int, description: str, visible: bool = True) -> Optional[int]:
if self.progress is not None:
@ -371,45 +375,54 @@ class RichProgressBar(ProgressBarBase):
f"[{self.theme.description}]{description}", total=total_batches, visible=visible
)
def _update(self, progress_bar_id: int, visible: bool = True) -> None:
if self.progress is not None:
self.progress.update(progress_bar_id, advance=1.0, visible=visible)
def _update(self, progress_bar_id: int, current: int, total: int, visible: bool = True) -> None:
if self.progress is not None and self._should_update(current, total):
self.progress.update(progress_bar_id, advance=self.refresh_rate, visible=visible)
self.refresh()
def _should_update(self, current: int, total: int) -> bool:
return self.is_enabled and (current % self.refresh_rate == 0 or current == total)
def on_validation_epoch_end(self, trainer, pl_module):
super().on_validation_epoch_end(trainer, pl_module)
if self.val_progress_bar_id is not None:
self._update(self.val_progress_bar_id, visible=False)
self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches, visible=False)
def on_test_epoch_start(self, trainer, pl_module):
super().on_train_epoch_start(trainer, pl_module)
self.test_progress_bar_id = self._add_task(self.total_test_batches, self.test_description)
self.refresh()
def on_predict_epoch_start(self, trainer, pl_module):
super().on_predict_epoch_start(trainer, pl_module)
self.predict_progress_bar_id = self._add_task(self.total_predict_batches, self.predict_description)
self.refresh()
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
self._update(self.main_progress_bar_id)
self._update(self.main_progress_bar_id, self.train_batch_idx, self.total_train_batches)
self._update_metrics(trainer, pl_module)
self.refresh()
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
if trainer.sanity_checking:
self._update(self.val_sanity_progress_bar_id)
self._update(self.val_sanity_progress_bar_id, self.val_batch_idx, self.total_val_batches)
elif self.val_progress_bar_id is not None:
# check to see if we should update the main training progress bar
if self.main_progress_bar_id is not None:
self._update(self.main_progress_bar_id)
self._update(self.val_progress_bar_id)
self._update(self.main_progress_bar_id, self.val_batch_idx, self.total_val_batches)
self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches)
self.refresh()
def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
self._update(self.test_progress_bar_id)
self._update(self.test_progress_bar_id, self.test_batch_idx, self.total_test_batches)
self.refresh()
def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
self._update(self.predict_progress_bar_id)
self._update(self.predict_progress_bar_id, self.predict_batch_idx, self.total_predict_batches)
self.refresh()
def _get_train_description(self, current_epoch: int) -> str:
train_description = f"Epoch {current_epoch}"

View File

@ -36,11 +36,11 @@ def test_rich_progress_bar_callback():
@RunIf(rich=True)
def test_rich_progress_bar_refresh_rate():
progress_bar = RichProgressBar(refresh_rate_per_second=1)
def test_rich_progress_bar_refresh_rate_enabled():
progress_bar = RichProgressBar(refresh_rate=1)
assert progress_bar.is_enabled
assert not progress_bar.is_disabled
progress_bar = RichProgressBar(refresh_rate_per_second=0)
progress_bar = RichProgressBar(refresh_rate=0)
assert not progress_bar.is_enabled
assert progress_bar.is_disabled
@ -180,3 +180,24 @@ def test_rich_progress_bar_leave(tmpdir, leave, reset_call_count):
)
trainer.fit(model)
assert mock_progress_reset.call_count == reset_call_count
@RunIf(rich=True)
@mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update")
@pytest.mark.parametrize(("refresh_rate", "expected_call_count"), ([(0, 0), (3, 7)]))
def test_rich_progress_bar_refresh_rate(progress_update, tmpdir, refresh_rate, expected_call_count):
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
num_sanity_val_steps=0,
limit_train_batches=6,
limit_val_batches=6,
max_epochs=1,
callbacks=RichProgressBar(refresh_rate=refresh_rate),
)
trainer.fit(model)
assert progress_update.call_count == expected_call_count