Add `refresh_rate` to RichProgressBar (#10497)
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
parent 7d3ad5b76e
commit 137b62d80d
@@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Raise `MisconfigurationException` when `enable_progress_bar=False` and a progress bar instance has been passed in the callback list ([#10520](https://github.com/PyTorchLightning/pytorch-lightning/issues/10520))
 
 
+- Renamed `refresh_rate_per_second` parameter to `refresh_rate` for `RichProgressBar` signature ([#10497](https://github.com/PyTorchLightning/pytorch-lightning/pull/10497))
+
+
 - Moved ownership of the `PrecisionPlugin` into `TrainingTypePlugin` and updated all references ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570))
 
 
@@ -206,7 +206,8 @@ class RichProgressBar(ProgressBarBase):
         trainer = Trainer(callbacks=RichProgressBar())
 
     Args:
-        refresh_rate_per_second: the number of updates per second. If refresh_rate is 0, progress bar is disabled.
+        refresh_rate: Determines at which rate (in number of batches) the progress bars get updated.
+            Set it to ``0`` to disable the display.
         leave: Leaves the finished progress bar in the terminal at the end of the epoch. Default: False
         theme: Contains styles used to stylize the progress bar.
 
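Usage sketch (illustration only, not part of this commit; assumes the post-rename signature shown in the hunk above):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import RichProgressBar

    # Redraw the progress bars every 10 batches; refresh_rate=0 would disable them entirely.
    trainer = Trainer(callbacks=RichProgressBar(refresh_rate=10, leave=False))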
@@ -222,7 +223,7 @@ class RichProgressBar(ProgressBarBase):
 
     def __init__(
         self,
-        refresh_rate_per_second: int = 10,
+        refresh_rate: int = 1,
         leave: bool = False,
        theme: RichProgressBarTheme = RichProgressBarTheme(),
     ) -> None:
@@ -231,7 +232,7 @@ class RichProgressBar(ProgressBarBase):
                 "`RichProgressBar` requires `rich` to be installed. Install it by running `pip install -U rich`."
             )
         super().__init__()
-        self._refresh_rate_per_second: int = refresh_rate_per_second
+        self._refresh_rate: int = refresh_rate
         self._leave: bool = leave
         self._enabled: bool = True
         self.progress: Optional[Progress] = None
@@ -242,17 +243,12 @@ class RichProgressBar(ProgressBarBase):
         self.theme = theme
 
     @property
-    def refresh_rate_per_second(self) -> float:
-        """Refresh rate for Rich Progress.
-
-        Returns: Refresh rate for Progress Bar.
-        Return 1 if not enabled, as a positive integer is required (ignored by Rich Progress).
-        """
-        return self._refresh_rate_per_second if self._refresh_rate_per_second > 0 else 1
+    def refresh_rate(self) -> float:
+        return self._refresh_rate
 
     @property
     def is_enabled(self) -> bool:
-        return self._enabled and self._refresh_rate_per_second > 0
+        return self._enabled and self.refresh_rate > 0
 
     @property
     def is_disabled(self) -> bool:
@@ -289,7 +285,7 @@ class RichProgressBar(ProgressBarBase):
             self.progress = CustomProgress(
                 *self.configure_columns(trainer),
                 self._metric_component,
-                refresh_per_second=self.refresh_rate_per_second,
+                auto_refresh=False,
                 disable=self.is_disabled,
                 console=self._console,
             )
@@ -297,6 +293,10 @@ class RichProgressBar(ProgressBarBase):
             # progress has started
             self._progress_stopped = False
 
+    def refresh(self) -> None:
+        if self.progress:
+            self.progress.refresh()
+
     def on_train_start(self, trainer, pl_module):
         super().on_train_start(trainer, pl_module)
         self._init_progress(trainer)
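Design note (observation on the two hunks above, not stated in the commit message): the Rich `Progress` instance is now created with `auto_refresh=False` and redrawn explicitly through the new `refresh()` helper at batch boundaries, which is what allows `refresh_rate` to be expressed in batches processed rather than in wall-clock updates per second.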
@@ -328,10 +328,12 @@ class RichProgressBar(ProgressBarBase):
         super().on_sanity_check_start(trainer, pl_module)
         self._init_progress(trainer)
         self.val_sanity_progress_bar_id = self._add_task(trainer.num_sanity_val_steps, self.sanity_check_description)
+        self.refresh()
 
     def on_sanity_check_end(self, trainer, pl_module):
         super().on_sanity_check_end(trainer, pl_module)
         self._update(self.val_sanity_progress_bar_id, visible=False)
+        self.refresh()
 
     def on_train_epoch_start(self, trainer, pl_module):
         super().on_train_epoch_start(trainer, pl_module)
@@ -354,6 +356,7 @@ class RichProgressBar(ProgressBarBase):
             self.progress.reset(
                 self.main_progress_bar_id, total=total_batches, description=train_description, visible=True
             )
+        self.refresh()
 
     def on_validation_epoch_start(self, trainer, pl_module):
         super().on_validation_epoch_start(trainer, pl_module)
@@ -364,6 +367,7 @@ class RichProgressBar(ProgressBarBase):
                 val_checks_per_epoch = self.total_train_batches // trainer.val_check_batch
                 total_val_batches = self.total_val_batches * val_checks_per_epoch
             self.val_progress_bar_id = self._add_task(total_val_batches, self.validation_description, visible=False)
+            self.refresh()
 
     def _add_task(self, total_batches: int, description: str, visible: bool = True) -> Optional[int]:
         if self.progress is not None:
@@ -371,45 +375,54 @@ class RichProgressBar(ProgressBarBase):
                 f"[{self.theme.description}]{description}", total=total_batches, visible=visible
             )
 
-    def _update(self, progress_bar_id: int, visible: bool = True) -> None:
-        if self.progress is not None:
-            self.progress.update(progress_bar_id, advance=1.0, visible=visible)
+    def _update(self, progress_bar_id: int, current: int, total: int, visible: bool = True) -> None:
+        if self.progress is not None and self._should_update(current, total):
+            self.progress.update(progress_bar_id, advance=self.refresh_rate, visible=visible)
+            self.refresh()
+
+    def _should_update(self, current: int, total: int) -> bool:
+        return self.is_enabled and (current % self.refresh_rate == 0 or current == total)
 
     def on_validation_epoch_end(self, trainer, pl_module):
         super().on_validation_epoch_end(trainer, pl_module)
         if self.val_progress_bar_id is not None:
-            self._update(self.val_progress_bar_id, visible=False)
+            self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches, visible=False)
 
     def on_test_epoch_start(self, trainer, pl_module):
         super().on_train_epoch_start(trainer, pl_module)
         self.test_progress_bar_id = self._add_task(self.total_test_batches, self.test_description)
+        self.refresh()
 
     def on_predict_epoch_start(self, trainer, pl_module):
         super().on_predict_epoch_start(trainer, pl_module)
         self.predict_progress_bar_id = self._add_task(self.total_predict_batches, self.predict_description)
+        self.refresh()
 
     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
         super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
-        self._update(self.main_progress_bar_id)
+        self._update(self.main_progress_bar_id, self.train_batch_idx, self.total_train_batches)
         self._update_metrics(trainer, pl_module)
+        self.refresh()
 
     def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
         if trainer.sanity_checking:
-            self._update(self.val_sanity_progress_bar_id)
+            self._update(self.val_sanity_progress_bar_id, self.val_batch_idx, self.total_val_batches)
         elif self.val_progress_bar_id is not None:
             # check to see if we should update the main training progress bar
             if self.main_progress_bar_id is not None:
-                self._update(self.main_progress_bar_id)
-            self._update(self.val_progress_bar_id)
+                self._update(self.main_progress_bar_id, self.val_batch_idx, self.total_val_batches)
+            self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches)
+        self.refresh()
 
     def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
-        self._update(self.test_progress_bar_id)
+        self._update(self.test_progress_bar_id, self.test_batch_idx, self.total_test_batches)
+        self.refresh()
 
     def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
         super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
-        self._update(self.predict_progress_bar_id)
+        self._update(self.predict_progress_bar_id, self.predict_batch_idx, self.total_predict_batches)
+        self.refresh()
 
     def _get_train_description(self, current_epoch: int) -> str:
         train_description = f"Epoch {current_epoch}"
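Cadence sketch (illustration only, not part of this commit): the standalone `should_update` below mirrors the `_should_update` logic added above and is not a library API.

    def should_update(current: int, total: int, refresh_rate: int) -> bool:
        # Redraw on every `refresh_rate`-th batch and always on the final batch.
        return refresh_rate > 0 and (current % refresh_rate == 0 or current == total)

    # With refresh_rate=3 and 6 batches, only batches 3 and 6 trigger a redraw.
    assert [b for b in range(1, 7) if should_update(b, 6, 3)] == [3, 6]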
@@ -36,11 +36,11 @@ def test_rich_progress_bar_callback():
 
 
 @RunIf(rich=True)
-def test_rich_progress_bar_refresh_rate():
-    progress_bar = RichProgressBar(refresh_rate_per_second=1)
+def test_rich_progress_bar_refresh_rate_enabled():
+    progress_bar = RichProgressBar(refresh_rate=1)
     assert progress_bar.is_enabled
     assert not progress_bar.is_disabled
-    progress_bar = RichProgressBar(refresh_rate_per_second=0)
+    progress_bar = RichProgressBar(refresh_rate=0)
     assert not progress_bar.is_enabled
     assert progress_bar.is_disabled
 
@@ -180,3 +180,24 @@ def test_rich_progress_bar_leave(tmpdir, leave, reset_call_count):
     )
     trainer.fit(model)
     assert mock_progress_reset.call_count == reset_call_count
+
+
+@RunIf(rich=True)
+@mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update")
+@pytest.mark.parametrize(("refresh_rate", "expected_call_count"), ([(0, 0), (3, 7)]))
+def test_rich_progress_bar_refresh_rate(progress_update, tmpdir, refresh_rate, expected_call_count):
+
+    model = BoringModel()
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        num_sanity_val_steps=0,
+        limit_train_batches=6,
+        limit_val_batches=6,
+        max_epochs=1,
+        callbacks=RichProgressBar(refresh_rate=refresh_rate),
+    )
+
+    trainer.fit(model)
+
+    assert progress_update.call_count == expected_call_count
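Where the expected count of 7 comes from for `refresh_rate=3` (derived from the hooks changed in this diff, not stated in the commit): with 6 train and 6 val batches, `Progress.update` fires on train batches 3 and 6 (2 calls), on val batches 3 and 6 for both the main bar and the validation bar (4 calls), and once more from `on_validation_epoch_end` (1 call), i.e. 2 + 4 + 1 = 7. With `refresh_rate=0` the bar is disabled, so no update is ever issued.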