Fix rich main progress bar update (#12618)

This commit is contained in:
Rohit Gupta 2022-04-06 20:12:32 +05:30 committed by lexierule
parent f87cff2123
commit 959174703b
4 changed files with 46 additions and 29 deletions

View File

@@ -77,6 +77,11 @@ class ProgressBarBase(Callback):
def predict_description(self) -> str:
return "Predicting"
@property
def _val_processed(self) -> int:
# use total in case validation runs more than once per training epoch
return self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.total.processed
@property
def train_batch_idx(self) -> int:
"""The number of batches processed during training.

View File

@@ -368,15 +368,19 @@ class RichProgressBar(ProgressBarBase):
f"[{self.theme.description}]{description}", total=total_batches, visible=visible
)
def _update(self, progress_bar_id: int, current: int, total: Union[int, float], visible: bool = True) -> None:
if self.progress is not None and self._should_update(current, total):
def _update(self, progress_bar_id: int, current: int, visible: bool = True) -> None:
if self.progress is not None and self.is_enabled:
total = self.progress.tasks[progress_bar_id].total
if not self._should_update(current, total):
return
leftover = current % self.refresh_rate
advance = leftover if (current == total and leftover != 0) else self.refresh_rate
self.progress.update(progress_bar_id, advance=advance, visible=visible)
self.refresh()
def _should_update(self, current: int, total: Union[int, float]) -> bool:
return self.is_enabled and (current % self.refresh_rate == 0 or current == total)
return current % self.refresh_rate == 0 or current == total
def on_validation_epoch_end(self, trainer, pl_module):
if self.val_progress_bar_id is not None and trainer.state.fn == "fit":
@@ -419,7 +423,7 @@ class RichProgressBar(ProgressBarBase):
self.refresh()
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
self._update(self.main_progress_bar_id, self.train_batch_idx, self.total_train_batches)
self._update(self.main_progress_bar_id, self.train_batch_idx + self._val_processed)
self._update_metrics(trainer, pl_module)
self.refresh()
@@ -428,23 +432,20 @@ class RichProgressBar(ProgressBarBase):
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
if trainer.sanity_checking:
self._update(self.val_sanity_progress_bar_id, self.val_batch_idx, self.total_val_batches_current_dataloader)
self._update(self.val_sanity_progress_bar_id, self.val_batch_idx)
elif self.val_progress_bar_id is not None:
# check to see if we should update the main training progress bar
if self.main_progress_bar_id is not None:
# TODO: Use total val_processed here just like TQDM in a follow-up
self._update(self.main_progress_bar_id, self.val_batch_idx, self.total_val_batches_current_dataloader)
self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches_current_dataloader)
self._update(self.main_progress_bar_id, self.train_batch_idx + self._val_processed)
self._update(self.val_progress_bar_id, self.val_batch_idx)
self.refresh()
def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
self._update(self.test_progress_bar_id, self.test_batch_idx, self.total_test_batches_current_dataloader)
self._update(self.test_progress_bar_id, self.test_batch_idx)
self.refresh()
def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
self._update(
self.predict_progress_bar_id, self.predict_batch_idx, self.total_predict_batches_current_dataloader
)
self._update(self.predict_progress_bar_id, self.predict_batch_idx)
self.refresh()
def _get_train_description(self, current_epoch: int) -> str:

View File

@@ -171,11 +171,6 @@ class TQDMProgressBar(ProgressBarBase):
def is_disabled(self) -> bool:
return not self.is_enabled
@property
def _val_processed(self) -> int:
# use total in case validation runs more than once per training epoch
return self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.total.processed
def disable(self) -> None:
self._enabled = False

View File

@@ -205,14 +205,28 @@ def test_rich_progress_bar_refresh_rate_disabled(progress_update, tmpdir):
@RunIf(rich=True)
@pytest.mark.parametrize(("refresh_rate", "expected_call_count"), ([(3, 7), (4, 7), (7, 4)]))
def test_rich_progress_bar_with_refresh_rate(tmpdir, refresh_rate, expected_call_count):
@pytest.mark.parametrize(
"refresh_rate,train_batches,val_batches,expected_call_count",
[
(3, 6, 6, 4 + 3),
(4, 6, 6, 3 + 3),
(7, 6, 6, 2 + 2),
(1, 2, 3, 5 + 4),
(1, 0, 0, 0 + 0),
(3, 1, 0, 1 + 0),
(3, 1, 1, 1 + 2),
(3, 5, 0, 2 + 0),
(3, 5, 2, 3 + 2),
(6, 5, 2, 2 + 2),
],
)
def test_rich_progress_bar_with_refresh_rate(tmpdir, refresh_rate, train_batches, val_batches, expected_call_count):
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
num_sanity_val_steps=0,
limit_train_batches=6,
limit_val_batches=6,
limit_train_batches=train_batches,
limit_val_batches=val_batches,
max_epochs=1,
callbacks=RichProgressBar(refresh_rate=refresh_rate),
)
@@ -224,14 +238,16 @@ def test_rich_progress_bar_with_refresh_rate(tmpdir, refresh_rate, expected_call
trainer.fit(model)
assert progress_update.call_count == expected_call_count
fit_main_bar = trainer.progress_bar_callback.progress.tasks[0]
fit_val_bar = trainer.progress_bar_callback.progress.tasks[1]
assert fit_main_bar.completed == 12
assert fit_main_bar.total == 12
assert fit_main_bar.visible
assert fit_val_bar.completed == 6
assert fit_val_bar.total == 6
assert not fit_val_bar.visible
if train_batches > 0:
fit_main_bar = trainer.progress_bar_callback.progress.tasks[0]
assert fit_main_bar.completed == train_batches + val_batches
assert fit_main_bar.total == train_batches + val_batches
assert fit_main_bar.visible
if val_batches > 0:
fit_val_bar = trainer.progress_bar_callback.progress.tasks[1]
assert fit_val_bar.completed == val_batches
assert fit_val_bar.total == val_batches
assert not fit_val_bar.visible
@RunIf(rich=True)