Fix rich with uneven refresh rate tracking (#11668)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

parent 7948ed703d
commit e9065e9d42

CHANGELOG.md (12 lines changed)
@@ -509,12 +509,18 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed an issue in `RichProgressbar` to display the metrics logged only on main progress bar ([#11690](https://github.com/PyTorchLightning/pytorch-lightning/pull/11690))

+- Fixed `RichProgressBar` progress when refresh rate does not evenly divide the total counter ([#11668](https://github.com/PyTorchLightning/pytorch-lightning/pull/11668))
+
+- Fixed `RichProgressBar` progress validation bar total when using multiple validation runs within a single training epoch ([#11668](https://github.com/PyTorchLightning/pytorch-lightning/pull/11668))
+
 - The `RichProgressBar` now correctly shows the `on_epoch` logged values on train epoch end ([#11689](https://github.com/PyTorchLightning/pytorch-lightning/pull/11689))

 - Fixed check for available modules ([#11526](https://github.com/PyTorchLightning/pytorch-lightning/pull/11526))

 - The Rich progress bar now correctly shows the `on_epoch` logged values on train epoch end ([#11689](https://github.com/PyTorchLightning/pytorch-lightning/pull/11689))

 - Fixed an issue to avoid validation loop run on restart ([#11552](https://github.com/PyTorchLightning/pytorch-lightning/pull/11552))
pytorch_lightning/callbacks/progress/base.py

@@ -105,6 +105,9 @@ class ProgressBarBase(Callback):
         Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the validation
         dataloader is of infinite size.
         """
+        if self.trainer.sanity_checking:
+            return sum(self.trainer.num_sanity_val_batches)
+
         total_val_batches = 0
         if self.trainer.enable_validation:
             is_val_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
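This early return is the base-class half of the sanity-check fix: while sanity checking, the bar total comes from the clamped per-dataloader sanity counts rather than the regular validation counts. A minimal sketch of how the property resolves, with made-up batch counts standing in for a real `Trainer` (the clamped list itself is built in the `Trainer` hunk further down):

# Minimal sketch (plain values, not the Trainer API) of the new early return.
num_sanity_val_steps = 2
num_val_batches = [5, 3]  # one entry per validation dataloader

# mirrors the list built in the Trainer hunk below
num_sanity_val_batches = [min(num_sanity_val_steps, b) for b in num_val_batches]

sanity_checking = True
total = sum(num_sanity_val_batches) if sanity_checking else sum(num_val_batches)
assert total == 4  # min(2, 5) + min(2, 3), instead of the full 8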
pytorch_lightning/callbacks/progress/rich_progress.py

@@ -318,8 +318,6 @@ class RichProgressBar(ProgressBarBase):

     def on_sanity_check_start(self, trainer, pl_module):
         self._init_progress(trainer)
-        self.val_sanity_progress_bar_id = self._add_task(trainer.num_sanity_val_steps, self.sanity_check_description)
-        self.refresh()

     def on_sanity_check_end(self, trainer, pl_module):
         if self.progress is not None:
@@ -349,14 +347,13 @@ class RichProgressBar(ProgressBarBase):
         self.refresh()

     def on_validation_epoch_start(self, trainer, pl_module):
         if self.total_val_batches > 0:
             total_val_batches = self.total_val_batches
             if self.total_train_batches != float("inf") and hasattr(trainer, "val_check_batch"):
                 # val can be checked multiple times per epoch
                 val_checks_per_epoch = self.total_train_batches // trainer.val_check_batch
                 total_val_batches = self.total_val_batches * val_checks_per_epoch
-            self.val_progress_bar_id = self._add_task(total_val_batches, self.validation_description, visible=False)
-            self.refresh()
+            if trainer.sanity_checking:
+                self.val_sanity_progress_bar_id = self._add_task(self.total_val_batches, self.sanity_check_description)
+            else:
+                self.val_progress_bar_id = self._add_task(
+                    self.total_val_batches, self.validation_description, visible=False
+                )
+            self.refresh()

     def _add_task(self, total_batches: int, description: str, visible: bool = True) -> Optional[int]:
         if self.progress is not None:
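Two totals are in play above: the multiplied local `total_val_batches` accounts for validation running several times within one training epoch (the main-bar totals asserted in the new test at the bottom of this diff reflect it), while the validation bar itself is now sized with the per-run `self.total_val_batches`. The removed line used the multiplied value, which is the bar-total bug named in the changelog. A rough sketch of the bookkeeping, with assumed numbers that appear nowhere in the diff:

# Assumed numbers, not taken from the diff: 10 train batches, a validation
# check every 4 train batches, 3 validation batches per run.
total_train_batches = 10
val_check_batch = 4
per_run_val_batches = 3

# val can be checked multiple times per epoch
val_checks_per_epoch = total_train_batches // val_check_batch  # 10 // 4 == 2
epoch_val_batches = per_run_val_batches * val_checks_per_epoch  # 6 steps absorbed by the main bar

val_bar_total = per_run_val_batches  # the fix: the val bar is sized per run
assert (epoch_val_batches, val_bar_total) == (6, 3)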
@@ -364,17 +361,20 @@ class RichProgressBar(ProgressBarBase):
                 f"[{self.theme.description}]{description}", total=total_batches, visible=visible
             )

-    def _update(self, progress_bar_id: int, current: int, total: int, visible: bool = True) -> None:
+    def _update(self, progress_bar_id: int, current: int, total: Union[int, float], visible: bool = True) -> None:
         if self.progress is not None and self._should_update(current, total):
-            self.progress.update(progress_bar_id, advance=self.refresh_rate, visible=visible)
+            leftover = current % self.refresh_rate
+            advance = leftover if (current == total and leftover != 0) else self.refresh_rate
+            self.progress.update(progress_bar_id, advance=advance, visible=visible)
             self.refresh()

-    def _should_update(self, current: int, total: int) -> bool:
+    def _should_update(self, current: int, total: Union[int, float]) -> bool:
         return self.is_enabled and (current % self.refresh_rate == 0 or current == total)

     def on_validation_epoch_end(self, trainer, pl_module):
-        if self.val_progress_bar_id is not None:
-            self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches, visible=False)
+        if self.val_progress_bar_id is not None and trainer.state.fn == "fit":
+            self.progress.update(self.val_progress_bar_id, advance=0, visible=False)
+            self.refresh()

     def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if trainer.state.fn == "fit":
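This `_update` change is the fix behind the PR title: the bar used to advance by a full `refresh_rate` on every refresh, so the completed count overshot the total whenever the refresh rate left a remainder. Now the final update advances only by the leftover. A standalone sketch of the arithmetic (an assumed helper written for illustration, not part of the callback):

def advances(total, refresh_rate):
    """Collect the `advance` passed to Progress.update for each refreshed batch."""
    out = []
    for current in range(1, total + 1):
        if current % refresh_rate == 0 or current == total:  # mirrors _should_update
            leftover = current % refresh_rate
            advance = leftover if (current == total and leftover != 0) else refresh_rate
            out.append(advance)
    return out

# an evenly dividing rate behaves as before: four updates of 3 reach 12 exactly
assert advances(12, 3) == [3, 3, 3, 3]
# an uneven rate now finishes on the leftover instead of overshooting to 15
assert advances(12, 5) == [5, 5, 2]
assert sum(advances(12, 5)) == 12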
pytorch_lightning/loops/dataloader/evaluation_loop.py

@@ -178,9 +178,6 @@ class EvaluationLoop(DataLoaderLoop):
             max_batches = self.trainer.num_test_batches
         else:
             if self.trainer.sanity_checking:
-                self.trainer.num_sanity_val_batches = [
-                    min(self.trainer.num_sanity_val_steps, val_batches) for val_batches in self.trainer.num_val_batches
-                ]
                 max_batches = self.trainer.num_sanity_val_batches
             else:
                 max_batches = self.trainer.num_val_batches
pytorch_lightning/trainer/trainer.py

@@ -1344,6 +1344,9 @@ class Trainer(

         # reload dataloaders
         val_loop._reload_evaluation_dataloaders()
+        self.num_sanity_val_batches = [
+            min(self.num_sanity_val_steps, val_batches) for val_batches in self.num_val_batches
+        ]

         # run eval step
         with torch.no_grad():
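Moving this list comprehension out of `EvaluationLoop` (previous hunk) and into the `Trainer`'s sanity-checking path means the clamped per-dataloader counts exist as soon as the evaluation dataloaders are reloaded, so the progress-bar code that now reads `trainer.num_sanity_val_batches` (see the `ProgressBarBase.total_val_batches` hunk above) sees a populated list instead of depending on the evaluation loop having run first; the loop itself now only reads the precomputed value.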
tests/callbacks/test_rich_progress_bar.py

@@ -46,9 +46,8 @@ def test_rich_progress_bar_refresh_rate_enabled():


 @RunIf(rich=True)
-@mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update")
 @pytest.mark.parametrize("dataset", [RandomDataset(32, 64), RandomIterableDataset(32, 64)])
-def test_rich_progress_bar(progress_update, tmpdir, dataset):
+def test_rich_progress_bar(tmpdir, dataset):
     class TestModel(BoringModel):
         def train_dataloader(self):
             return DataLoader(dataset=dataset)
@@ -62,8 +61,6 @@ def test_rich_progress_bar(progress_update, tmpdir, dataset):
         def predict_dataloader(self):
             return DataLoader(dataset=dataset)

-    model = TestModel()
-
     trainer = Trainer(
         default_root_dir=tmpdir,
         num_sanity_val_steps=0,
@@ -71,16 +68,27 @@ def test_rich_progress_bar(progress_update, tmpdir, dataset):
         limit_val_batches=1,
         limit_test_batches=1,
         limit_predict_batches=1,
         max_steps=1,
         max_epochs=1,
         callbacks=RichProgressBar(),
     )
+    model = TestModel()

-    trainer.fit(model)
-    trainer.validate(model)
-    trainer.test(model)
-    trainer.predict(model)
+    with mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update") as mocked:
+        trainer.fit(model)
+    # 3 for main progress bar and 1 for val progress bar
+    assert mocked.call_count == 4

-    assert progress_update.call_count == 8
+    with mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update") as mocked:
+        trainer.validate(model)
+    assert mocked.call_count == 1
+
+    with mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update") as mocked:
+        trainer.test(model)
+    assert mocked.call_count == 1
+
+    with mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update") as mocked:
+        trainer.predict(model)
+    assert mocked.call_count == 1


 def test_rich_progress_bar_import_error(monkeypatch):
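A note on the expected counts: with `num_sanity_val_steps=0` and the `limit_*_batches=1`/`max_steps=1` settings, every stage runs a single batch, so `validate`, `test` and `predict` each refresh their bar exactly once, while `fit` produces the four calls the inline comment breaks down (three on the main bar, one on the hidden validation bar). Scoping a fresh `mock.patch` around each `trainer.*` call keeps the per-stage counts independent.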
@@ -186,11 +194,20 @@ def test_rich_progress_bar_leave(tmpdir, leave, reset_call_count):

 @RunIf(rich=True)
 @mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update")
-@pytest.mark.parametrize(("refresh_rate", "expected_call_count"), ([(0, 0), (3, 7)]))
-def test_rich_progress_bar_refresh_rate(progress_update, tmpdir, refresh_rate, expected_call_count):
+def test_rich_progress_bar_refresh_rate_disabled(progress_update, tmpdir):
     trainer = Trainer(
         default_root_dir=tmpdir,
         fast_dev_run=4,
         callbacks=RichProgressBar(refresh_rate=0),
     )
     trainer.fit(BoringModel())
     assert progress_update.call_count == 0
+
+
+@RunIf(rich=True)
+@pytest.mark.parametrize(("refresh_rate", "expected_call_count"), ([(3, 7), (4, 7), (7, 4)]))
+def test_rich_progress_bar_with_refresh_rate(tmpdir, refresh_rate, expected_call_count):
+    model = BoringModel()
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        num_sanity_val_steps=0,
@@ -200,14 +217,26 @@ def test_rich_progress_bar_refresh_rate(progress_update, tmpdir, refresh_rate, expected_call_count):
         callbacks=RichProgressBar(refresh_rate=refresh_rate),
     )

-    trainer.fit(model)
+    trainer.progress_bar_callback.on_train_start(trainer, model)
+    with mock.patch.object(
+        trainer.progress_bar_callback.progress, "update", wraps=trainer.progress_bar_callback.progress.update
+    ) as progress_update:
+        trainer.fit(model)
+        assert progress_update.call_count == expected_call_count

-    assert progress_update.call_count == expected_call_count
+    fit_main_bar = trainer.progress_bar_callback.progress.tasks[0]
+    fit_val_bar = trainer.progress_bar_callback.progress.tasks[1]
+    assert fit_main_bar.completed == 12
+    assert fit_main_bar.total == 12
+    assert fit_main_bar.visible
+    assert fit_val_bar.completed == 6
+    assert fit_val_bar.total == 6
+    assert not fit_val_bar.visible


 @RunIf(rich=True)
 @pytest.mark.parametrize("limit_val_batches", (1, 5))
-def test_rich_progress_bar_num_sanity_val_steps(tmpdir, limit_val_batches: int):
+def test_rich_progress_bar_num_sanity_val_steps(tmpdir, limit_val_batches):
     model = BoringModel()

     progress_bar = RichProgressBar()
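The `wraps=` argument is what lets the rewritten test both count the `Progress.update` calls and keep the real updates flowing, so the `completed`/`total` assertions after the `with` block observe genuine task state. A small self-contained illustration of the pattern (a toy class, not Rich):

from unittest import mock

class Bar:
    def __init__(self):
        self.completed = 0

    def update(self, advance):
        self.completed += advance

bar = Bar()
with mock.patch.object(bar, "update", wraps=bar.update) as spy:
    bar.update(3)
    bar.update(2)

assert spy.call_count == 2  # the calls were recorded...
assert bar.completed == 5   # ...and still executed, thanks to wraps=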
@@ -224,6 +253,36 @@ def test_rich_progress_bar_num_sanity_val_steps(tmpdir, limit_val_batches: int):

     trainer.fit(model)
     assert progress_bar.progress.tasks[0].completed == min(num_sanity_val_steps, limit_val_batches)
+    assert progress_bar.progress.tasks[0].total == min(num_sanity_val_steps, limit_val_batches)
+
+
+@RunIf(rich=True)
+def test_rich_progress_bar_counter_with_val_check_interval(tmpdir):
+    """Test the completed and total counter for rich progress bar when using val_check_interval."""
+    progress_bar = RichProgressBar()
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        val_check_interval=2,
+        max_epochs=1,
+        limit_train_batches=7,
+        limit_val_batches=4,
+        callbacks=[progress_bar],
+    )
+    trainer.fit(model)
+
+    fit_main_progress_bar = progress_bar.progress.tasks[1]
+    assert fit_main_progress_bar.completed == 7 + 3 * 4
+    assert fit_main_progress_bar.total == 7 + 3 * 4
+
+    fit_val_bar = progress_bar.progress.tasks[2]
+    assert fit_val_bar.completed == 4
+    assert fit_val_bar.total == 4
+
+    trainer.validate(model)
+    val_bar = progress_bar.progress.tasks[0]
+    assert val_bar.completed == 4
+    assert val_bar.total == 4
+
+
 @RunIf(rich=True)
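The expected counters in this new test follow directly from the integer arithmetic in `on_validation_epoch_start` above: validation triggers after train batches 2, 4 and 6, so the main bar absorbs three validation runs while each validation bar is sized per run. The arithmetic, spelled out:

# Decoding the expected counters of the test above (plain arithmetic).
limit_train_batches, val_check_interval, limit_val_batches = 7, 2, 4
val_runs = limit_train_batches // val_check_interval  # runs after batches 2, 4, 6
assert val_runs == 3
assert limit_train_batches + val_runs * limit_val_batches == 19  # the 7 + 3 * 4 main-bar total
assert limit_val_batches == 4  # per-run validation bar total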