Fix rich main progress bar update (#12618)
This commit is contained in:
parent
f87cff2123
commit
959174703b
|
@ -77,6 +77,11 @@ class ProgressBarBase(Callback):
|
|||
def predict_description(self) -> str:
|
||||
return "Predicting"
|
||||
|
||||
@property
|
||||
def _val_processed(self) -> int:
|
||||
# use total in case validation runs more than once per training epoch
|
||||
return self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.total.processed
|
||||
|
||||
@property
|
||||
def train_batch_idx(self) -> int:
|
||||
"""The number of batches processed during training.
|
||||
|
|
|
@ -368,15 +368,19 @@ class RichProgressBar(ProgressBarBase):
|
|||
f"[{self.theme.description}]{description}", total=total_batches, visible=visible
|
||||
)
|
||||
|
||||
def _update(self, progress_bar_id: int, current: int, total: Union[int, float], visible: bool = True) -> None:
|
||||
if self.progress is not None and self._should_update(current, total):
|
||||
def _update(self, progress_bar_id: int, current: int, visible: bool = True) -> None:
|
||||
if self.progress is not None and self.is_enabled:
|
||||
total = self.progress.tasks[progress_bar_id].total
|
||||
if not self._should_update(current, total):
|
||||
return
|
||||
|
||||
leftover = current % self.refresh_rate
|
||||
advance = leftover if (current == total and leftover != 0) else self.refresh_rate
|
||||
self.progress.update(progress_bar_id, advance=advance, visible=visible)
|
||||
self.refresh()
|
||||
|
||||
def _should_update(self, current: int, total: Union[int, float]) -> bool:
|
||||
return self.is_enabled and (current % self.refresh_rate == 0 or current == total)
|
||||
return current % self.refresh_rate == 0 or current == total
|
||||
|
||||
def on_validation_epoch_end(self, trainer, pl_module):
|
||||
if self.val_progress_bar_id is not None and trainer.state.fn == "fit":
|
||||
|
@ -419,7 +423,7 @@ class RichProgressBar(ProgressBarBase):
|
|||
self.refresh()
|
||||
|
||||
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
||||
self._update(self.main_progress_bar_id, self.train_batch_idx, self.total_train_batches)
|
||||
self._update(self.main_progress_bar_id, self.train_batch_idx + self._val_processed)
|
||||
self._update_metrics(trainer, pl_module)
|
||||
self.refresh()
|
||||
|
||||
|
@ -428,23 +432,20 @@ class RichProgressBar(ProgressBarBase):
|
|||
|
||||
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
||||
if trainer.sanity_checking:
|
||||
self._update(self.val_sanity_progress_bar_id, self.val_batch_idx, self.total_val_batches_current_dataloader)
|
||||
self._update(self.val_sanity_progress_bar_id, self.val_batch_idx)
|
||||
elif self.val_progress_bar_id is not None:
|
||||
# check to see if we should update the main training progress bar
|
||||
if self.main_progress_bar_id is not None:
|
||||
# TODO: Use total val_processed here just like TQDM in a follow-up
|
||||
self._update(self.main_progress_bar_id, self.val_batch_idx, self.total_val_batches_current_dataloader)
|
||||
self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches_current_dataloader)
|
||||
self._update(self.main_progress_bar_id, self.train_batch_idx + self._val_processed)
|
||||
self._update(self.val_progress_bar_id, self.val_batch_idx)
|
||||
self.refresh()
|
||||
|
||||
def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
||||
self._update(self.test_progress_bar_id, self.test_batch_idx, self.total_test_batches_current_dataloader)
|
||||
self._update(self.test_progress_bar_id, self.test_batch_idx)
|
||||
self.refresh()
|
||||
|
||||
def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
||||
self._update(
|
||||
self.predict_progress_bar_id, self.predict_batch_idx, self.total_predict_batches_current_dataloader
|
||||
)
|
||||
self._update(self.predict_progress_bar_id, self.predict_batch_idx)
|
||||
self.refresh()
|
||||
|
||||
def _get_train_description(self, current_epoch: int) -> str:
|
||||
|
|
|
@ -171,11 +171,6 @@ class TQDMProgressBar(ProgressBarBase):
|
|||
def is_disabled(self) -> bool:
|
||||
return not self.is_enabled
|
||||
|
||||
@property
|
||||
def _val_processed(self) -> int:
|
||||
# use total in case validation runs more than once per training epoch
|
||||
return self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.total.processed
|
||||
|
||||
def disable(self) -> None:
|
||||
self._enabled = False
|
||||
|
||||
|
|
|
@ -205,14 +205,28 @@ def test_rich_progress_bar_refresh_rate_disabled(progress_update, tmpdir):
|
|||
|
||||
|
||||
@RunIf(rich=True)
|
||||
@pytest.mark.parametrize(("refresh_rate", "expected_call_count"), ([(3, 7), (4, 7), (7, 4)]))
|
||||
def test_rich_progress_bar_with_refresh_rate(tmpdir, refresh_rate, expected_call_count):
|
||||
@pytest.mark.parametrize(
|
||||
"refresh_rate,train_batches,val_batches,expected_call_count",
|
||||
[
|
||||
(3, 6, 6, 4 + 3),
|
||||
(4, 6, 6, 3 + 3),
|
||||
(7, 6, 6, 2 + 2),
|
||||
(1, 2, 3, 5 + 4),
|
||||
(1, 0, 0, 0 + 0),
|
||||
(3, 1, 0, 1 + 0),
|
||||
(3, 1, 1, 1 + 2),
|
||||
(3, 5, 0, 2 + 0),
|
||||
(3, 5, 2, 3 + 2),
|
||||
(6, 5, 2, 2 + 2),
|
||||
],
|
||||
)
|
||||
def test_rich_progress_bar_with_refresh_rate(tmpdir, refresh_rate, train_batches, val_batches, expected_call_count):
|
||||
model = BoringModel()
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
num_sanity_val_steps=0,
|
||||
limit_train_batches=6,
|
||||
limit_val_batches=6,
|
||||
limit_train_batches=train_batches,
|
||||
limit_val_batches=val_batches,
|
||||
max_epochs=1,
|
||||
callbacks=RichProgressBar(refresh_rate=refresh_rate),
|
||||
)
|
||||
|
@ -224,14 +238,16 @@ def test_rich_progress_bar_with_refresh_rate(tmpdir, refresh_rate, expected_call
|
|||
trainer.fit(model)
|
||||
assert progress_update.call_count == expected_call_count
|
||||
|
||||
fit_main_bar = trainer.progress_bar_callback.progress.tasks[0]
|
||||
fit_val_bar = trainer.progress_bar_callback.progress.tasks[1]
|
||||
assert fit_main_bar.completed == 12
|
||||
assert fit_main_bar.total == 12
|
||||
assert fit_main_bar.visible
|
||||
assert fit_val_bar.completed == 6
|
||||
assert fit_val_bar.total == 6
|
||||
assert not fit_val_bar.visible
|
||||
if train_batches > 0:
|
||||
fit_main_bar = trainer.progress_bar_callback.progress.tasks[0]
|
||||
assert fit_main_bar.completed == train_batches + val_batches
|
||||
assert fit_main_bar.total == train_batches + val_batches
|
||||
assert fit_main_bar.visible
|
||||
if val_batches > 0:
|
||||
fit_val_bar = trainer.progress_bar_callback.progress.tasks[1]
|
||||
assert fit_val_bar.completed == val_batches
|
||||
assert fit_val_bar.total == val_batches
|
||||
assert not fit_val_bar.visible
|
||||
|
||||
|
||||
@RunIf(rich=True)
|
||||
|
|
Loading…
Reference in New Issue