Fix `TQDMProgressBar` reset and update to show correct time estimation (#12889)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
Rohit Gupta 2022-05-03 21:51:59 +05:30 committed by lexierule
parent ab7ad37b82
commit 55f5e2d170
4 changed files with 30 additions and 25 deletions

View File

@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `fuse_modules` to be qat-aware for `torch>=1.11` ([#12891](https://github.com/PyTorchLightning/pytorch-lightning/pull/12891)) - Fixed `fuse_modules` to be qat-aware for `torch>=1.11` ([#12891](https://github.com/PyTorchLightning/pytorch-lightning/pull/12891))
- Enforced eval shuffle warning only for default samplers in DataLoader ([#12653](https://github.com/PyTorchLightning/pytorch-lightning/pull/12653)) - Enforced eval shuffle warning only for default samplers in DataLoader ([#12653](https://github.com/PyTorchLightning/pytorch-lightning/pull/12653))
- Enable mixed precision in `DDPFullyShardedStrategy` when `precision=16` ([#12965](https://github.com/PyTorchLightning/pytorch-lightning/pull/12965)) - Enable mixed precision in `DDPFullyShardedStrategy` when `precision=16` ([#12965](https://github.com/PyTorchLightning/pytorch-lightning/pull/12965))
- Fixed `TQDMProgressBar` reset and update to show correct time estimation ([#12889](https://github.com/PyTorchLightning/pytorch-lightning/pull/12889))
## [1.6.2] - 2022-04-27 ## [1.6.2] - 2022-04-27

View File

@ -262,13 +262,13 @@ class TQDMProgressBar(ProgressBarBase):
val_checks_per_epoch = total_train_batches // trainer.val_check_batch val_checks_per_epoch = total_train_batches // trainer.val_check_batch
total_val_batches = total_val_batches * val_checks_per_epoch total_val_batches = total_val_batches * val_checks_per_epoch
total_batches = total_train_batches + total_val_batches total_batches = total_train_batches + total_val_batches
self.main_progress_bar.total = convert_inf(total_batches) self.main_progress_bar.reset(convert_inf(total_batches))
self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}") self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}")
def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", *_: Any) -> None: def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", *_: Any) -> None:
current = self.train_batch_idx + self._val_processed current = self.train_batch_idx + self._val_processed
if self._should_update(current, self.main_progress_bar.total): if self._should_update(current, self.main_progress_bar.total):
_update_n(self.main_progress_bar, current) _update_n(self.main_progress_bar, current, self.refresh_rate)
self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
@ -288,17 +288,17 @@ class TQDMProgressBar(ProgressBarBase):
if not self.has_dataloader_changed(dataloader_idx): if not self.has_dataloader_changed(dataloader_idx):
return return
self.val_progress_bar.total = convert_inf(self.total_val_batches_current_dataloader) self.val_progress_bar.reset(convert_inf(self.total_val_batches_current_dataloader))
desc = self.sanity_check_description if trainer.sanity_checking else self.validation_description desc = self.sanity_check_description if trainer.sanity_checking else self.validation_description
self.val_progress_bar.set_description(f"{desc} DataLoader {dataloader_idx}") self.val_progress_bar.set_description(f"{desc} DataLoader {dataloader_idx}")
def on_validation_batch_end(self, trainer: "pl.Trainer", *_: Any) -> None: def on_validation_batch_end(self, trainer: "pl.Trainer", *_: Any) -> None:
if self._should_update(self.val_batch_idx, self.val_progress_bar.total): if self._should_update(self.val_batch_idx, self.val_progress_bar.total):
_update_n(self.val_progress_bar, self.val_batch_idx) _update_n(self.val_progress_bar, self.val_batch_idx, self.refresh_rate)
current = self.train_batch_idx + self._val_processed current = self.train_batch_idx + self._val_processed
if trainer.state.fn == "fit" and self._should_update(current, self.main_progress_bar.total): if trainer.state.fn == "fit" and self._should_update(current, self.main_progress_bar.total):
_update_n(self.main_progress_bar, current) _update_n(self.main_progress_bar, current, self.refresh_rate)
def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self._main_progress_bar is not None and trainer.state.fn == "fit": if self._main_progress_bar is not None and trainer.state.fn == "fit":
@ -315,12 +315,12 @@ class TQDMProgressBar(ProgressBarBase):
if not self.has_dataloader_changed(dataloader_idx): if not self.has_dataloader_changed(dataloader_idx):
return return
self.test_progress_bar.total = convert_inf(self.total_test_batches_current_dataloader) self.test_progress_bar.reset(convert_inf(self.total_test_batches_current_dataloader))
self.test_progress_bar.set_description(f"{self.test_description} DataLoader {dataloader_idx}") self.test_progress_bar.set_description(f"{self.test_description} DataLoader {dataloader_idx}")
def on_test_batch_end(self, *_: Any) -> None: def on_test_batch_end(self, *_: Any) -> None:
if self._should_update(self.test_batch_idx, self.test_progress_bar.total): if self._should_update(self.test_batch_idx, self.test_progress_bar.total):
_update_n(self.test_progress_bar, self.test_batch_idx) _update_n(self.test_progress_bar, self.test_batch_idx, self.refresh_rate)
def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self.test_progress_bar.close() self.test_progress_bar.close()
@ -335,12 +335,12 @@ class TQDMProgressBar(ProgressBarBase):
if not self.has_dataloader_changed(dataloader_idx): if not self.has_dataloader_changed(dataloader_idx):
return return
self.predict_progress_bar.total = convert_inf(self.total_predict_batches_current_dataloader) self.predict_progress_bar.reset(convert_inf(self.total_predict_batches_current_dataloader))
self.predict_progress_bar.set_description(f"{self.predict_description} DataLoader {dataloader_idx}") self.predict_progress_bar.set_description(f"{self.predict_description} DataLoader {dataloader_idx}")
def on_predict_batch_end(self, *_: Any) -> None: def on_predict_batch_end(self, *_: Any) -> None:
if self._should_update(self.predict_batch_idx, self.predict_progress_bar.total): if self._should_update(self.predict_batch_idx, self.predict_progress_bar.total):
_update_n(self.predict_progress_bar, self.predict_batch_idx) _update_n(self.predict_progress_bar, self.predict_batch_idx, self.refresh_rate)
def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self.predict_progress_bar.close() self.predict_progress_bar.close()
@ -384,7 +384,10 @@ def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
return x return x
def _update_n(bar: _tqdm, value: int) -> None: def _update_n(bar: _tqdm, current: int, refresh_rate: int) -> None:
if not bar.disable: if not bar.disable:
bar.n = value total = bar.total
leftover = current % refresh_rate
advance = leftover if (current == total and leftover != 0) else refresh_rate
bar.update(advance)
bar.refresh() bar.refresh()

View File

@ -2,7 +2,7 @@
numpy>=1.17.2 numpy>=1.17.2
torch>=1.8.* torch>=1.8.*
tqdm>=4.41.0 tqdm>=4.57.0
PyYAML>=5.4 PyYAML>=5.4
fsspec[http]>=2021.05.0, !=2021.06.0 fsspec[http]>=2021.05.0, !=2021.06.0
tensorboard>=2.2.0 tensorboard>=2.2.0

View File

@ -53,6 +53,7 @@ class MockTqdm(Tqdm):
@n.setter @n.setter
def n(self, value): def n(self, value):
self.__n = value self.__n = value
# track the changes in the `n` value # track the changes in the `n` value
if not len(self.n_values) or value != self.n_values[-1]: if not len(self.n_values) or value != self.n_values[-1]:
self.n_values.append(value) self.n_values.append(value)
@ -158,7 +159,7 @@ def test_tqdm_progress_bar_totals(tmpdir, num_dl):
assert not pbar.val_progress_bar.leave assert not pbar.val_progress_bar.leave
assert trainer.num_sanity_val_batches == expected_sanity_steps assert trainer.num_sanity_val_batches == expected_sanity_steps
assert pbar.val_progress_bar.total_values == expected_sanity_steps assert pbar.val_progress_bar.total_values == expected_sanity_steps
assert pbar.val_progress_bar.n_values == list(range(1, num_sanity_val_steps + 1)) * num_dl assert pbar.val_progress_bar.n_values == list(range(num_sanity_val_steps + 1)) * num_dl
assert pbar.val_progress_bar.descriptions == [f"Sanity Checking DataLoader {i}: " for i in range(num_dl)] assert pbar.val_progress_bar.descriptions == [f"Sanity Checking DataLoader {i}: " for i in range(num_dl)]
# fit # fit
@ -177,7 +178,7 @@ def test_tqdm_progress_bar_totals(tmpdir, num_dl):
# check val progress bar total # check val progress bar total
assert pbar.val_progress_bar.total_values == m assert pbar.val_progress_bar.total_values == m
assert pbar.val_progress_bar.n_values == list(range(1, m[0] + 1)) * num_dl assert pbar.val_progress_bar.n_values == list(range(m[0] + 1)) * num_dl
assert pbar.val_progress_bar.descriptions == [f"Validation DataLoader {i}: " for i in range(num_dl)] assert pbar.val_progress_bar.descriptions == [f"Validation DataLoader {i}: " for i in range(num_dl)]
assert not pbar.val_progress_bar.leave assert not pbar.val_progress_bar.leave
@ -186,7 +187,7 @@ def test_tqdm_progress_bar_totals(tmpdir, num_dl):
trainer.validate(model) trainer.validate(model)
assert trainer.num_val_batches == m assert trainer.num_val_batches == m
assert pbar.val_progress_bar.total_values == m assert pbar.val_progress_bar.total_values == m
assert pbar.val_progress_bar.n_values == list(range(1, m[0] + 1)) * num_dl assert pbar.val_progress_bar.n_values == list(range(m[0] + 1)) * num_dl
assert pbar.val_progress_bar.descriptions == [f"Validation DataLoader {i}: " for i in range(num_dl)] assert pbar.val_progress_bar.descriptions == [f"Validation DataLoader {i}: " for i in range(num_dl)]
# test # test
@ -195,7 +196,7 @@ def test_tqdm_progress_bar_totals(tmpdir, num_dl):
assert pbar.test_progress_bar.leave assert pbar.test_progress_bar.leave
k = trainer.num_test_batches k = trainer.num_test_batches
assert pbar.test_progress_bar.total_values == k assert pbar.test_progress_bar.total_values == k
assert pbar.test_progress_bar.n_values == list(range(1, k[0] + 1)) * num_dl assert pbar.test_progress_bar.n_values == list(range(k[0] + 1)) * num_dl
assert pbar.test_progress_bar.descriptions == [f"Testing DataLoader {i}: " for i in range(num_dl)] assert pbar.test_progress_bar.descriptions == [f"Testing DataLoader {i}: " for i in range(num_dl)]
assert pbar.test_progress_bar.leave assert pbar.test_progress_bar.leave
@ -205,7 +206,7 @@ def test_tqdm_progress_bar_totals(tmpdir, num_dl):
assert pbar.predict_progress_bar.leave assert pbar.predict_progress_bar.leave
k = trainer.num_predict_batches k = trainer.num_predict_batches
assert pbar.predict_progress_bar.total_values == k assert pbar.predict_progress_bar.total_values == k
assert pbar.predict_progress_bar.n_values == list(range(1, k[0] + 1)) * num_dl assert pbar.predict_progress_bar.n_values == list(range(k[0] + 1)) * num_dl
assert pbar.predict_progress_bar.descriptions == [f"Predicting DataLoader {i}: " for i in range(num_dl)] assert pbar.predict_progress_bar.descriptions == [f"Predicting DataLoader {i}: " for i in range(num_dl)]
assert pbar.predict_progress_bar.leave assert pbar.predict_progress_bar.leave
@ -359,13 +360,13 @@ def test_tqdm_progress_bar_value_on_colab(tmpdir):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"train_batches,val_batches,refresh_rate,train_updates,val_updates", "train_batches,val_batches,refresh_rate,train_updates,val_updates",
[ [
[2, 3, 1, [1, 2, 3, 4, 5], [1, 2, 3]], [2, 3, 1, [0, 1, 2, 3, 4, 5], [0, 1, 2, 3]],
[0, 0, 3, None, None], [0, 0, 3, None, None],
[1, 0, 3, [1], None], [1, 0, 3, [0, 1], None],
[1, 1, 3, [2], [1]], [1, 1, 3, [0, 2], [0, 1]],
[5, 0, 3, [3, 5], None], [5, 0, 3, [0, 3, 5], None],
[5, 2, 3, [3, 6, 7], [2]], [5, 2, 3, [0, 3, 6, 7], [0, 2]],
[5, 2, 6, [6, 7], [2]], [5, 2, 6, [0, 6, 7], [0, 2]],
], ],
) )
def test_main_progress_bar_update_amount( def test_main_progress_bar_update_amount(
@ -395,7 +396,7 @@ def test_main_progress_bar_update_amount(
assert progress_bar.val_progress_bar.n_values == val_updates assert progress_bar.val_progress_bar.n_values == val_updates
@pytest.mark.parametrize("test_batches,refresh_rate,updates", [[1, 3, [1]], [3, 1, [1, 2, 3]], [5, 3, [3, 5]]]) @pytest.mark.parametrize("test_batches,refresh_rate,updates", [(1, 3, [0, 1]), (3, 1, [0, 1, 2, 3]), (5, 3, [0, 3, 5])])
def test_test_progress_bar_update_amount(tmpdir, test_batches: int, refresh_rate: int, updates: list): def test_test_progress_bar_update_amount(tmpdir, test_batches: int, refresh_rate: int, updates: list):
"""Test that test progress updates with the correct amount.""" """Test that test progress updates with the correct amount."""
model = BoringModel() model = BoringModel()
@ -566,7 +567,7 @@ def test_tqdm_progress_bar_can_be_pickled():
@pytest.mark.parametrize( @pytest.mark.parametrize(
["val_check_interval", "main_progress_bar_updates", "val_progress_bar_updates"], ["val_check_interval", "main_progress_bar_updates", "val_progress_bar_updates"],
[(4, [3, 6, 9, 12, 14], [3, 6, 7]), (0.5, [3, 6, 9, 12, 15, 18, 21], [3, 6, 7])], [(4, [0, 3, 6, 9, 12, 14], [0, 3, 6, 7]), (0.5, [0, 3, 6, 9, 12, 15, 18, 21], [0, 3, 6, 7])],
) )
def test_progress_bar_max_val_check_interval( def test_progress_bar_max_val_check_interval(
tmpdir, val_check_interval, main_progress_bar_updates, val_progress_bar_updates tmpdir, val_check_interval, main_progress_bar_updates, val_progress_bar_updates