From e9065e9d429f0313df9f68eebfd712af671b7303 Mon Sep 17 00:00:00 2001
From: Rohit Gupta
Date: Thu, 3 Feb 2022 15:57:05 +0530
Subject: [PATCH] Fix rich with uneven refresh rate tracking (#11668)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí
---
 CHANGELOG.md                                  | 12 ++-
 pytorch_lightning/callbacks/progress/base.py  |  3 +
 .../callbacks/progress/rich_progress.py       | 30 +++---
 .../loops/dataloader/evaluation_loop.py       |  3 -
 pytorch_lightning/trainer/trainer.py          |  3 +
 tests/callbacks/test_rich_progress_bar.py     | 91 +++++++++++++++----
 6 files changed, 105 insertions(+), 37 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index d9c2e189c1..3c0501843f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -509,12 +509,18 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue in `RichProgressbar` to display the metrics logged only on main progress bar ([#11690](https://github.com/PyTorchLightning/pytorch-lightning/pull/11690))
 
 
+- Fixed `RichProgressBar` progress when refresh rate does not evenly divide the total counter ([#11668](https://github.com/PyTorchLightning/pytorch-lightning/pull/11668))
+
+
+- Fixed `RichProgressBar` progress validation bar total when using multiple validation runs within a single training epoch ([#11668](https://github.com/PyTorchLightning/pytorch-lightning/pull/11668))
+
+
+- The `RichProgressBar` now correctly shows the `on_epoch` logged values on train epoch end ([#11689](https://github.com/PyTorchLightning/pytorch-lightning/pull/11689))
+
+
 - Fixed check for available modules ([#11526](https://github.com/PyTorchLightning/pytorch-lightning/pull/11526))
 
 
-- The Rich progress bar now correctly shows the `on_epoch` logged values on train epoch end ([#11689](https://github.com/PyTorchLightning/pytorch-lightning/pull/11689))
-
-
 - Fixed an issue to avoid validation loop run on restart ([#11552](https://github.com/PyTorchLightning/pytorch-lightning/pull/11552))
 
diff --git a/pytorch_lightning/callbacks/progress/base.py b/pytorch_lightning/callbacks/progress/base.py
index 5c35bf122b..77340f9f28 100644
--- a/pytorch_lightning/callbacks/progress/base.py
+++ b/pytorch_lightning/callbacks/progress/base.py
@@ -105,6 +105,9 @@ class ProgressBarBase(Callback):
         Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the
         validation dataloader is of infinite size.
         """
+        if self.trainer.sanity_checking:
+            return sum(self.trainer.num_sanity_val_batches)
+
         total_val_batches = 0
         if self.trainer.enable_validation:
             is_val_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
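Reviewer note (not part of the patch): with the `base.py` hunk above, `ProgressBarBase.total_val_batches` short-circuits during the sanity check and returns the sum of the per-dataloader sanity batch counts instead of the regular validation total. A minimal sketch of the resulting control flow; only the sanity branch is verbatim from the diff, the tail of the property is paraphrased from the surrounding context:

```python
# Sketch of ProgressBarBase.total_val_batches after this patch (simplified;
# `trainer` stands in for the real Trainer the callback is attached to).
def total_val_batches(trainer) -> int:
    if trainer.sanity_checking:
        # one entry per validation dataloader, each capped at num_sanity_val_steps
        return sum(trainer.num_sanity_val_batches)
    total_val_batches = 0
    if trainer.enable_validation:
        is_val_epoch = (trainer.current_epoch + 1) % trainer.check_val_every_n_epoch == 0
        total_val_batches = sum(trainer.num_val_batches) if is_val_epoch else 0
    return total_val_batches
```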
""" + if self.trainer.sanity_checking: + return sum(self.trainer.num_sanity_val_batches) + total_val_batches = 0 if self.trainer.enable_validation: is_val_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0 diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index cd944c5eb1..8f8ba6c93f 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -318,8 +318,6 @@ class RichProgressBar(ProgressBarBase): def on_sanity_check_start(self, trainer, pl_module): self._init_progress(trainer) - self.val_sanity_progress_bar_id = self._add_task(trainer.num_sanity_val_steps, self.sanity_check_description) - self.refresh() def on_sanity_check_end(self, trainer, pl_module): if self.progress is not None: @@ -349,14 +347,13 @@ class RichProgressBar(ProgressBarBase): self.refresh() def on_validation_epoch_start(self, trainer, pl_module): - if self.total_val_batches > 0: - total_val_batches = self.total_val_batches - if self.total_train_batches != float("inf") and hasattr(trainer, "val_check_batch"): - # val can be checked multiple times per epoch - val_checks_per_epoch = self.total_train_batches // trainer.val_check_batch - total_val_batches = self.total_val_batches * val_checks_per_epoch - self.val_progress_bar_id = self._add_task(total_val_batches, self.validation_description, visible=False) - self.refresh() + if trainer.sanity_checking: + self.val_sanity_progress_bar_id = self._add_task(self.total_val_batches, self.sanity_check_description) + else: + self.val_progress_bar_id = self._add_task( + self.total_val_batches, self.validation_description, visible=False + ) + self.refresh() def _add_task(self, total_batches: int, description: str, visible: bool = True) -> Optional[int]: if self.progress is not None: @@ -364,17 +361,20 @@ class RichProgressBar(ProgressBarBase): f"[{self.theme.description}]{description}", total=total_batches, visible=visible ) - def _update(self, progress_bar_id: int, current: int, total: int, visible: bool = True) -> None: + def _update(self, progress_bar_id: int, current: int, total: Union[int, float], visible: bool = True) -> None: if self.progress is not None and self._should_update(current, total): - self.progress.update(progress_bar_id, advance=self.refresh_rate, visible=visible) + leftover = current % self.refresh_rate + advance = leftover if (current == total and leftover != 0) else self.refresh_rate + self.progress.update(progress_bar_id, advance=advance, visible=visible) self.refresh() - def _should_update(self, current: int, total: int) -> bool: + def _should_update(self, current: int, total: Union[int, float]) -> bool: return self.is_enabled and (current % self.refresh_rate == 0 or current == total) def on_validation_epoch_end(self, trainer, pl_module): - if self.val_progress_bar_id is not None: - self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches, visible=False) + if self.val_progress_bar_id is not None and trainer.state.fn == "fit": + self.progress.update(self.val_progress_bar_id, advance=0, visible=False) + self.refresh() def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if trainer.state.fn == "fit": diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 1e0b30cab0..076fe1b52c 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ 
diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py
index 1e0b30cab0..076fe1b52c 100644
--- a/pytorch_lightning/loops/dataloader/evaluation_loop.py
+++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -178,9 +178,6 @@ class EvaluationLoop(DataLoaderLoop):
             max_batches = self.trainer.num_test_batches
         else:
             if self.trainer.sanity_checking:
-                self.trainer.num_sanity_val_batches = [
-                    min(self.trainer.num_sanity_val_steps, val_batches) for val_batches in self.trainer.num_val_batches
-                ]
                 max_batches = self.trainer.num_sanity_val_batches
             else:
                 max_batches = self.trainer.num_val_batches
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 4a49723bf0..08e813375d 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1344,6 +1344,9 @@ class Trainer(
 
         # reload dataloaders
         val_loop._reload_evaluation_dataloaders()
+        self.num_sanity_val_batches = [
+            min(self.num_sanity_val_steps, val_batches) for val_batches in self.num_val_batches
+        ]
 
         # run eval step
         with torch.no_grad():
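Reviewer note (not part of the patch): the two hunks above relocate the `num_sanity_val_batches` computation from the evaluation loop into the `Trainer`, immediately after the validation dataloaders are reloaded. That ordering matters: the progress bar's `total_val_batches` property (see the `base.py` hunk) reads `trainer.num_sanity_val_batches` when the sanity-check task is created, so the list must be populated before the callback hooks run. The clamping itself, with made-up values:

```python
# Hypothetical sizes for three validation dataloaders.
num_sanity_val_steps = 2
num_val_batches = [5, 1, 3]

# Each dataloader contributes at most num_sanity_val_steps batches.
num_sanity_val_batches = [min(num_sanity_val_steps, n) for n in num_val_batches]
assert num_sanity_val_batches == [2, 1, 2]
assert sum(num_sanity_val_batches) == 5  # the sanity-check progress bar total
```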
diff --git a/tests/callbacks/test_rich_progress_bar.py b/tests/callbacks/test_rich_progress_bar.py
index 768632da77..9f7f8e4541 100644
--- a/tests/callbacks/test_rich_progress_bar.py
+++ b/tests/callbacks/test_rich_progress_bar.py
@@ -46,9 +46,8 @@ def test_rich_progress_bar_refresh_rate_enabled():
 
 
 @RunIf(rich=True)
-@mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update")
 @pytest.mark.parametrize("dataset", [RandomDataset(32, 64), RandomIterableDataset(32, 64)])
-def test_rich_progress_bar(progress_update, tmpdir, dataset):
+def test_rich_progress_bar(tmpdir, dataset):
     class TestModel(BoringModel):
         def train_dataloader(self):
             return DataLoader(dataset=dataset)
@@ -62,8 +61,6 @@ def test_rich_progress_bar(progress_update, tmpdir, dataset):
         def predict_dataloader(self):
             return DataLoader(dataset=dataset)
 
-    model = TestModel()
-
     trainer = Trainer(
         default_root_dir=tmpdir,
         num_sanity_val_steps=0,
@@ -71,16 +68,27 @@ def test_rich_progress_bar(progress_update, tmpdir, dataset):
         limit_val_batches=1,
         limit_test_batches=1,
         limit_predict_batches=1,
-        max_steps=1,
+        max_epochs=1,
         callbacks=RichProgressBar(),
     )
+    model = TestModel()
 
-    trainer.fit(model)
-    trainer.validate(model)
-    trainer.test(model)
-    trainer.predict(model)
+    with mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update") as mocked:
+        trainer.fit(model)
+    # 3 for main progress bar and 1 for val progress bar
+    assert mocked.call_count == 4
 
-    assert progress_update.call_count == 8
+    with mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update") as mocked:
+        trainer.validate(model)
+    assert mocked.call_count == 1
+
+    with mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update") as mocked:
+        trainer.test(model)
+    assert mocked.call_count == 1
+
+    with mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update") as mocked:
+        trainer.predict(model)
+    assert mocked.call_count == 1
 
 
 def test_rich_progress_bar_import_error(monkeypatch):
@@ -186,11 +194,20 @@ def test_rich_progress_bar_leave(tmpdir, leave, reset_call_count):
 
 @RunIf(rich=True)
 @mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update")
-@pytest.mark.parametrize(("refresh_rate", "expected_call_count"), ([(0, 0), (3, 7)]))
-def test_rich_progress_bar_refresh_rate(progress_update, tmpdir, refresh_rate, expected_call_count):
+def test_rich_progress_bar_refresh_rate_disabled(progress_update, tmpdir):
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=4,
+        callbacks=RichProgressBar(refresh_rate=0),
+    )
+    trainer.fit(BoringModel())
+    assert progress_update.call_count == 0
+
+
+@RunIf(rich=True)
+@pytest.mark.parametrize(("refresh_rate", "expected_call_count"), ([(3, 7), (4, 7), (7, 4)]))
+def test_rich_progress_bar_with_refresh_rate(tmpdir, refresh_rate, expected_call_count):
     model = BoringModel()
-
     trainer = Trainer(
         default_root_dir=tmpdir,
         num_sanity_val_steps=0,
@@ -200,14 +217,26 @@ def test_rich_progress_bar_refresh_rate(progress_update, tmpdir, refresh_rate, expected_call_count):
         callbacks=RichProgressBar(refresh_rate=refresh_rate),
     )
 
-    trainer.fit(model)
+    trainer.progress_bar_callback.on_train_start(trainer, model)
+    with mock.patch.object(
+        trainer.progress_bar_callback.progress, "update", wraps=trainer.progress_bar_callback.progress.update
+    ) as progress_update:
+        trainer.fit(model)
+        assert progress_update.call_count == expected_call_count
 
-    assert progress_update.call_count == expected_call_count
+    fit_main_bar = trainer.progress_bar_callback.progress.tasks[0]
+    fit_val_bar = trainer.progress_bar_callback.progress.tasks[1]
+    assert fit_main_bar.completed == 12
+    assert fit_main_bar.total == 12
+    assert fit_main_bar.visible
+    assert fit_val_bar.completed == 6
+    assert fit_val_bar.total == 6
+    assert not fit_val_bar.visible
 
 
 @RunIf(rich=True)
 @pytest.mark.parametrize("limit_val_batches", (1, 5))
-def test_rich_progress_bar_num_sanity_val_steps(tmpdir, limit_val_batches: int):
+def test_rich_progress_bar_num_sanity_val_steps(tmpdir, limit_val_batches):
     model = BoringModel()
 
     progress_bar = RichProgressBar()
@@ -224,6 +253,36 @@ def test_rich_progress_bar_num_sanity_val_steps(tmpdir, limit_val_batches: int):
 
     trainer.fit(model)
     assert progress_bar.progress.tasks[0].completed == min(num_sanity_val_steps, limit_val_batches)
+    assert progress_bar.progress.tasks[0].total == min(num_sanity_val_steps, limit_val_batches)
+
+
+@RunIf(rich=True)
+def test_rich_progress_bar_counter_with_val_check_interval(tmpdir):
+    """Test the completed and total counter for rich progress bar when using val_check_interval."""
+    progress_bar = RichProgressBar()
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        val_check_interval=2,
+        max_epochs=1,
+        limit_train_batches=7,
+        limit_val_batches=4,
+        callbacks=[progress_bar],
+    )
+    trainer.fit(model)
+
+    fit_main_progress_bar = progress_bar.progress.tasks[1]
+    assert fit_main_progress_bar.completed == 7 + 3 * 4
+    assert fit_main_progress_bar.total == 7 + 3 * 4
+
+    fit_val_bar = progress_bar.progress.tasks[2]
+    assert fit_val_bar.completed == 4
+    assert fit_val_bar.total == 4
+
+    trainer.validate(model)
+    val_bar = progress_bar.progress.tasks[0]
+    assert val_bar.completed == 4
+    assert val_bar.total == 4
 
 
 @RunIf(rich=True)
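Reviewer note (not part of the patch): the `7 + 3 * 4` expectation in `test_rich_progress_bar_counter_with_val_check_interval` follows from the main fit bar counting both train and val batches. With 7 train batches and `val_check_interval=2`, validation runs after batches 2, 4 and 6. Spelled out with the test's settings (variable names are illustrative):

```python
limit_train_batches, val_check_interval, limit_val_batches = 7, 2, 4

# validation is triggered every val_check_interval train batches
val_checks_per_epoch = limit_train_batches // val_check_interval
assert val_checks_per_epoch == 3

# the fit main bar advances once per train batch and once per val batch
main_bar_total = limit_train_batches + val_checks_per_epoch * limit_val_batches
assert main_bar_total == 7 + 3 * 4 == 19
```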