diff --git a/CHANGELOG.md b/CHANGELOG.md index 1d50ae3540..6026a149ac 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `fuse_modules` to be qat-aware for `torch>=1.11` ([#12891](https://github.com/PyTorchLightning/pytorch-lightning/pull/12891)) - Enforced eval shuffle warning only for default samplers in DataLoader ([#12653](https://github.com/PyTorchLightning/pytorch-lightning/pull/12653)) - Enable mixed precision in `DDPFullyShardedStrategy` when `precision=16` ([#12965](https://github.com/PyTorchLightning/pytorch-lightning/pull/12965)) +- Fixed `TQDMProgressBar` reset and update to show correct time estimation ([#12889](https://github.com/PyTorchLightning/pytorch-lightning/pull/12889)) ## [1.6.2] - 2022-04-27 diff --git a/pytorch_lightning/callbacks/progress/tqdm_progress.py b/pytorch_lightning/callbacks/progress/tqdm_progress.py index 2d4da1c15e..f6467e4606 100644 --- a/pytorch_lightning/callbacks/progress/tqdm_progress.py +++ b/pytorch_lightning/callbacks/progress/tqdm_progress.py @@ -262,13 +262,13 @@ class TQDMProgressBar(ProgressBarBase): val_checks_per_epoch = total_train_batches // trainer.val_check_batch total_val_batches = total_val_batches * val_checks_per_epoch total_batches = total_train_batches + total_val_batches - self.main_progress_bar.total = convert_inf(total_batches) + self.main_progress_bar.reset(convert_inf(total_batches)) self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}") def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", *_: Any) -> None: current = self.train_batch_idx + self._val_processed if self._should_update(current, self.main_progress_bar.total): - _update_n(self.main_progress_bar, current) + _update_n(self.main_progress_bar, current, self.refresh_rate) self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: @@ -288,17 +288,17 @@ class TQDMProgressBar(ProgressBarBase): if not self.has_dataloader_changed(dataloader_idx): return - self.val_progress_bar.total = convert_inf(self.total_val_batches_current_dataloader) + self.val_progress_bar.reset(convert_inf(self.total_val_batches_current_dataloader)) desc = self.sanity_check_description if trainer.sanity_checking else self.validation_description self.val_progress_bar.set_description(f"{desc} DataLoader {dataloader_idx}") def on_validation_batch_end(self, trainer: "pl.Trainer", *_: Any) -> None: if self._should_update(self.val_batch_idx, self.val_progress_bar.total): - _update_n(self.val_progress_bar, self.val_batch_idx) + _update_n(self.val_progress_bar, self.val_batch_idx, self.refresh_rate) current = self.train_batch_idx + self._val_processed if trainer.state.fn == "fit" and self._should_update(current, self.main_progress_bar.total): - _update_n(self.main_progress_bar, current) + _update_n(self.main_progress_bar, current, self.refresh_rate) def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if self._main_progress_bar is not None and trainer.state.fn == "fit": @@ -315,12 +315,12 @@ class TQDMProgressBar(ProgressBarBase): if not self.has_dataloader_changed(dataloader_idx): return - self.test_progress_bar.total = convert_inf(self.total_test_batches_current_dataloader) + self.test_progress_bar.reset(convert_inf(self.total_test_batches_current_dataloader)) self.test_progress_bar.set_description(f"{self.test_description} DataLoader {dataloader_idx}") def on_test_batch_end(self, *_: Any) -> None: if self._should_update(self.test_batch_idx, self.test_progress_bar.total): - _update_n(self.test_progress_bar, self.test_batch_idx) + _update_n(self.test_progress_bar, self.test_batch_idx, self.refresh_rate) def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self.test_progress_bar.close() @@ -335,12 +335,12 @@ class TQDMProgressBar(ProgressBarBase): if not self.has_dataloader_changed(dataloader_idx): return - self.predict_progress_bar.total = convert_inf(self.total_predict_batches_current_dataloader) + self.predict_progress_bar.reset(convert_inf(self.total_predict_batches_current_dataloader)) self.predict_progress_bar.set_description(f"{self.predict_description} DataLoader {dataloader_idx}") def on_predict_batch_end(self, *_: Any) -> None: if self._should_update(self.predict_batch_idx, self.predict_progress_bar.total): - _update_n(self.predict_progress_bar, self.predict_batch_idx) + _update_n(self.predict_progress_bar, self.predict_batch_idx, self.refresh_rate) def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self.predict_progress_bar.close() @@ -384,7 +384,10 @@ def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]: return x -def _update_n(bar: _tqdm, value: int) -> None: +def _update_n(bar: _tqdm, current: int, refresh_rate: int) -> None: if not bar.disable: - bar.n = value + total = bar.total + leftover = current % refresh_rate + advance = leftover if (current == total and leftover != 0) else refresh_rate + bar.update(advance) bar.refresh() diff --git a/requirements.txt b/requirements.txt index 6aa080fc7e..39f0d586ba 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ numpy>=1.17.2 torch>=1.8.* -tqdm>=4.41.0 +tqdm>=4.57.0 PyYAML>=5.4 fsspec[http]>=2021.05.0, !=2021.06.0 tensorboard>=2.2.0 diff --git a/tests/callbacks/test_tqdm_progress_bar.py b/tests/callbacks/test_tqdm_progress_bar.py index f36f9d3353..f46ea267f9 100644 --- a/tests/callbacks/test_tqdm_progress_bar.py +++ b/tests/callbacks/test_tqdm_progress_bar.py @@ -53,6 +53,7 @@ class MockTqdm(Tqdm): @n.setter def n(self, value): self.__n = value + # track the changes in the `n` value if not len(self.n_values) or value != self.n_values[-1]: self.n_values.append(value) @@ -158,7 +159,7 @@ def test_tqdm_progress_bar_totals(tmpdir, num_dl): assert not pbar.val_progress_bar.leave assert trainer.num_sanity_val_batches == expected_sanity_steps assert pbar.val_progress_bar.total_values == expected_sanity_steps - assert pbar.val_progress_bar.n_values == list(range(1, num_sanity_val_steps + 1)) * num_dl + assert pbar.val_progress_bar.n_values == list(range(num_sanity_val_steps + 1)) * num_dl assert pbar.val_progress_bar.descriptions == [f"Sanity Checking DataLoader {i}: " for i in range(num_dl)] # fit @@ -177,7 +178,7 @@ def test_tqdm_progress_bar_totals(tmpdir, num_dl): # check val progress bar total assert pbar.val_progress_bar.total_values == m - assert pbar.val_progress_bar.n_values == list(range(1, m[0] + 1)) * num_dl + assert pbar.val_progress_bar.n_values == list(range(m[0] + 1)) * num_dl assert pbar.val_progress_bar.descriptions == [f"Validation DataLoader {i}: " for i in range(num_dl)] assert not pbar.val_progress_bar.leave @@ -186,7 +187,7 @@ def test_tqdm_progress_bar_totals(tmpdir, num_dl): trainer.validate(model) assert trainer.num_val_batches == m assert pbar.val_progress_bar.total_values == m - assert pbar.val_progress_bar.n_values == list(range(1, m[0] + 1)) * num_dl + assert pbar.val_progress_bar.n_values == list(range(m[0] + 1)) * num_dl assert pbar.val_progress_bar.descriptions == [f"Validation DataLoader {i}: " for i in range(num_dl)] # test @@ -195,7 +196,7 @@ def test_tqdm_progress_bar_totals(tmpdir, num_dl): assert pbar.test_progress_bar.leave k = trainer.num_test_batches assert pbar.test_progress_bar.total_values == k - assert pbar.test_progress_bar.n_values == list(range(1, k[0] + 1)) * num_dl + assert pbar.test_progress_bar.n_values == list(range(k[0] + 1)) * num_dl assert pbar.test_progress_bar.descriptions == [f"Testing DataLoader {i}: " for i in range(num_dl)] assert pbar.test_progress_bar.leave @@ -205,7 +206,7 @@ def test_tqdm_progress_bar_totals(tmpdir, num_dl): assert pbar.predict_progress_bar.leave k = trainer.num_predict_batches assert pbar.predict_progress_bar.total_values == k - assert pbar.predict_progress_bar.n_values == list(range(1, k[0] + 1)) * num_dl + assert pbar.predict_progress_bar.n_values == list(range(k[0] + 1)) * num_dl assert pbar.predict_progress_bar.descriptions == [f"Predicting DataLoader {i}: " for i in range(num_dl)] assert pbar.predict_progress_bar.leave @@ -359,13 +360,13 @@ def test_tqdm_progress_bar_value_on_colab(tmpdir): @pytest.mark.parametrize( "train_batches,val_batches,refresh_rate,train_updates,val_updates", [ - [2, 3, 1, [1, 2, 3, 4, 5], [1, 2, 3]], + [2, 3, 1, [0, 1, 2, 3, 4, 5], [0, 1, 2, 3]], [0, 0, 3, None, None], - [1, 0, 3, [1], None], - [1, 1, 3, [2], [1]], - [5, 0, 3, [3, 5], None], - [5, 2, 3, [3, 6, 7], [2]], - [5, 2, 6, [6, 7], [2]], + [1, 0, 3, [0, 1], None], + [1, 1, 3, [0, 2], [0, 1]], + [5, 0, 3, [0, 3, 5], None], + [5, 2, 3, [0, 3, 6, 7], [0, 2]], + [5, 2, 6, [0, 6, 7], [0, 2]], ], ) def test_main_progress_bar_update_amount( @@ -395,7 +396,7 @@ def test_main_progress_bar_update_amount( assert progress_bar.val_progress_bar.n_values == val_updates -@pytest.mark.parametrize("test_batches,refresh_rate,updates", [[1, 3, [1]], [3, 1, [1, 2, 3]], [5, 3, [3, 5]]]) +@pytest.mark.parametrize("test_batches,refresh_rate,updates", [(1, 3, [0, 1]), (3, 1, [0, 1, 2, 3]), (5, 3, [0, 3, 5])]) def test_test_progress_bar_update_amount(tmpdir, test_batches: int, refresh_rate: int, updates: list): """Test that test progress updates with the correct amount.""" model = BoringModel() @@ -566,7 +567,7 @@ def test_tqdm_progress_bar_can_be_pickled(): @pytest.mark.parametrize( ["val_check_interval", "main_progress_bar_updates", "val_progress_bar_updates"], - [(4, [3, 6, 9, 12, 14], [3, 6, 7]), (0.5, [3, 6, 9, 12, 15, 18, 21], [3, 6, 7])], + [(4, [0, 3, 6, 9, 12, 14], [0, 3, 6, 7]), (0.5, [0, 3, 6, 9, 12, 15, 18, 21], [0, 3, 6, 7])], ) def test_progress_bar_max_val_check_interval( tmpdir, val_check_interval, main_progress_bar_updates, val_progress_bar_updates