Fix `TQDMProgressBar` reset and update to show correct time estimation (#12889)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
parent
ab7ad37b82
commit
55f5e2d170
|
@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed `fuse_modules` to be qat-aware for `torch>=1.11` ([#12891](https://github.com/PyTorchLightning/pytorch-lightning/pull/12891))
|
||||
- Enforced eval shuffle warning only for default samplers in DataLoader ([#12653](https://github.com/PyTorchLightning/pytorch-lightning/pull/12653))
|
||||
- Enable mixed precision in `DDPFullyShardedStrategy` when `precision=16` ([#12965](https://github.com/PyTorchLightning/pytorch-lightning/pull/12965))
|
||||
- Fixed `TQDMProgressBar` reset and update to show correct time estimation ([#12889](https://github.com/PyTorchLightning/pytorch-lightning/pull/12889))
|
||||
|
||||
|
||||
## [1.6.2] - 2022-04-27
|
||||
|
|
|
@ -262,13 +262,13 @@ class TQDMProgressBar(ProgressBarBase):
|
|||
val_checks_per_epoch = total_train_batches // trainer.val_check_batch
|
||||
total_val_batches = total_val_batches * val_checks_per_epoch
|
||||
total_batches = total_train_batches + total_val_batches
|
||||
self.main_progress_bar.total = convert_inf(total_batches)
|
||||
self.main_progress_bar.reset(convert_inf(total_batches))
|
||||
self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}")
|
||||
|
||||
def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", *_: Any) -> None:
|
||||
current = self.train_batch_idx + self._val_processed
|
||||
if self._should_update(current, self.main_progress_bar.total):
|
||||
_update_n(self.main_progress_bar, current)
|
||||
_update_n(self.main_progress_bar, current, self.refresh_rate)
|
||||
self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
|
||||
|
||||
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
|
@ -288,17 +288,17 @@ class TQDMProgressBar(ProgressBarBase):
|
|||
if not self.has_dataloader_changed(dataloader_idx):
|
||||
return
|
||||
|
||||
self.val_progress_bar.total = convert_inf(self.total_val_batches_current_dataloader)
|
||||
self.val_progress_bar.reset(convert_inf(self.total_val_batches_current_dataloader))
|
||||
desc = self.sanity_check_description if trainer.sanity_checking else self.validation_description
|
||||
self.val_progress_bar.set_description(f"{desc} DataLoader {dataloader_idx}")
|
||||
|
||||
def on_validation_batch_end(self, trainer: "pl.Trainer", *_: Any) -> None:
|
||||
if self._should_update(self.val_batch_idx, self.val_progress_bar.total):
|
||||
_update_n(self.val_progress_bar, self.val_batch_idx)
|
||||
_update_n(self.val_progress_bar, self.val_batch_idx, self.refresh_rate)
|
||||
|
||||
current = self.train_batch_idx + self._val_processed
|
||||
if trainer.state.fn == "fit" and self._should_update(current, self.main_progress_bar.total):
|
||||
_update_n(self.main_progress_bar, current)
|
||||
_update_n(self.main_progress_bar, current, self.refresh_rate)
|
||||
|
||||
def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
if self._main_progress_bar is not None and trainer.state.fn == "fit":
|
||||
|
@ -315,12 +315,12 @@ class TQDMProgressBar(ProgressBarBase):
|
|||
if not self.has_dataloader_changed(dataloader_idx):
|
||||
return
|
||||
|
||||
self.test_progress_bar.total = convert_inf(self.total_test_batches_current_dataloader)
|
||||
self.test_progress_bar.reset(convert_inf(self.total_test_batches_current_dataloader))
|
||||
self.test_progress_bar.set_description(f"{self.test_description} DataLoader {dataloader_idx}")
|
||||
|
||||
def on_test_batch_end(self, *_: Any) -> None:
|
||||
if self._should_update(self.test_batch_idx, self.test_progress_bar.total):
|
||||
_update_n(self.test_progress_bar, self.test_batch_idx)
|
||||
_update_n(self.test_progress_bar, self.test_batch_idx, self.refresh_rate)
|
||||
|
||||
def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
self.test_progress_bar.close()
|
||||
|
@ -335,12 +335,12 @@ class TQDMProgressBar(ProgressBarBase):
|
|||
if not self.has_dataloader_changed(dataloader_idx):
|
||||
return
|
||||
|
||||
self.predict_progress_bar.total = convert_inf(self.total_predict_batches_current_dataloader)
|
||||
self.predict_progress_bar.reset(convert_inf(self.total_predict_batches_current_dataloader))
|
||||
self.predict_progress_bar.set_description(f"{self.predict_description} DataLoader {dataloader_idx}")
|
||||
|
||||
def on_predict_batch_end(self, *_: Any) -> None:
|
||||
if self._should_update(self.predict_batch_idx, self.predict_progress_bar.total):
|
||||
_update_n(self.predict_progress_bar, self.predict_batch_idx)
|
||||
_update_n(self.predict_progress_bar, self.predict_batch_idx, self.refresh_rate)
|
||||
|
||||
def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
||||
self.predict_progress_bar.close()
|
||||
|
@ -384,7 +384,10 @@ def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
|
|||
return x
|
||||
|
||||
|
||||
def _update_n(bar: _tqdm, value: int) -> None:
|
||||
def _update_n(bar: _tqdm, current: int, refresh_rate: int) -> None:
|
||||
if not bar.disable:
|
||||
bar.n = value
|
||||
total = bar.total
|
||||
leftover = current % refresh_rate
|
||||
advance = leftover if (current == total and leftover != 0) else refresh_rate
|
||||
bar.update(advance)
|
||||
bar.refresh()
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
numpy>=1.17.2
|
||||
torch>=1.8.*
|
||||
tqdm>=4.41.0
|
||||
tqdm>=4.57.0
|
||||
PyYAML>=5.4
|
||||
fsspec[http]>=2021.05.0, !=2021.06.0
|
||||
tensorboard>=2.2.0
|
||||
|
|
|
@ -53,6 +53,7 @@ class MockTqdm(Tqdm):
|
|||
@n.setter
|
||||
def n(self, value):
|
||||
self.__n = value
|
||||
|
||||
# track the changes in the `n` value
|
||||
if not len(self.n_values) or value != self.n_values[-1]:
|
||||
self.n_values.append(value)
|
||||
|
@ -158,7 +159,7 @@ def test_tqdm_progress_bar_totals(tmpdir, num_dl):
|
|||
assert not pbar.val_progress_bar.leave
|
||||
assert trainer.num_sanity_val_batches == expected_sanity_steps
|
||||
assert pbar.val_progress_bar.total_values == expected_sanity_steps
|
||||
assert pbar.val_progress_bar.n_values == list(range(1, num_sanity_val_steps + 1)) * num_dl
|
||||
assert pbar.val_progress_bar.n_values == list(range(num_sanity_val_steps + 1)) * num_dl
|
||||
assert pbar.val_progress_bar.descriptions == [f"Sanity Checking DataLoader {i}: " for i in range(num_dl)]
|
||||
|
||||
# fit
|
||||
|
@ -177,7 +178,7 @@ def test_tqdm_progress_bar_totals(tmpdir, num_dl):
|
|||
|
||||
# check val progress bar total
|
||||
assert pbar.val_progress_bar.total_values == m
|
||||
assert pbar.val_progress_bar.n_values == list(range(1, m[0] + 1)) * num_dl
|
||||
assert pbar.val_progress_bar.n_values == list(range(m[0] + 1)) * num_dl
|
||||
assert pbar.val_progress_bar.descriptions == [f"Validation DataLoader {i}: " for i in range(num_dl)]
|
||||
assert not pbar.val_progress_bar.leave
|
||||
|
||||
|
@ -186,7 +187,7 @@ def test_tqdm_progress_bar_totals(tmpdir, num_dl):
|
|||
trainer.validate(model)
|
||||
assert trainer.num_val_batches == m
|
||||
assert pbar.val_progress_bar.total_values == m
|
||||
assert pbar.val_progress_bar.n_values == list(range(1, m[0] + 1)) * num_dl
|
||||
assert pbar.val_progress_bar.n_values == list(range(m[0] + 1)) * num_dl
|
||||
assert pbar.val_progress_bar.descriptions == [f"Validation DataLoader {i}: " for i in range(num_dl)]
|
||||
|
||||
# test
|
||||
|
@ -195,7 +196,7 @@ def test_tqdm_progress_bar_totals(tmpdir, num_dl):
|
|||
assert pbar.test_progress_bar.leave
|
||||
k = trainer.num_test_batches
|
||||
assert pbar.test_progress_bar.total_values == k
|
||||
assert pbar.test_progress_bar.n_values == list(range(1, k[0] + 1)) * num_dl
|
||||
assert pbar.test_progress_bar.n_values == list(range(k[0] + 1)) * num_dl
|
||||
assert pbar.test_progress_bar.descriptions == [f"Testing DataLoader {i}: " for i in range(num_dl)]
|
||||
assert pbar.test_progress_bar.leave
|
||||
|
||||
|
@ -205,7 +206,7 @@ def test_tqdm_progress_bar_totals(tmpdir, num_dl):
|
|||
assert pbar.predict_progress_bar.leave
|
||||
k = trainer.num_predict_batches
|
||||
assert pbar.predict_progress_bar.total_values == k
|
||||
assert pbar.predict_progress_bar.n_values == list(range(1, k[0] + 1)) * num_dl
|
||||
assert pbar.predict_progress_bar.n_values == list(range(k[0] + 1)) * num_dl
|
||||
assert pbar.predict_progress_bar.descriptions == [f"Predicting DataLoader {i}: " for i in range(num_dl)]
|
||||
assert pbar.predict_progress_bar.leave
|
||||
|
||||
|
@ -359,13 +360,13 @@ def test_tqdm_progress_bar_value_on_colab(tmpdir):
|
|||
@pytest.mark.parametrize(
|
||||
"train_batches,val_batches,refresh_rate,train_updates,val_updates",
|
||||
[
|
||||
[2, 3, 1, [1, 2, 3, 4, 5], [1, 2, 3]],
|
||||
[2, 3, 1, [0, 1, 2, 3, 4, 5], [0, 1, 2, 3]],
|
||||
[0, 0, 3, None, None],
|
||||
[1, 0, 3, [1], None],
|
||||
[1, 1, 3, [2], [1]],
|
||||
[5, 0, 3, [3, 5], None],
|
||||
[5, 2, 3, [3, 6, 7], [2]],
|
||||
[5, 2, 6, [6, 7], [2]],
|
||||
[1, 0, 3, [0, 1], None],
|
||||
[1, 1, 3, [0, 2], [0, 1]],
|
||||
[5, 0, 3, [0, 3, 5], None],
|
||||
[5, 2, 3, [0, 3, 6, 7], [0, 2]],
|
||||
[5, 2, 6, [0, 6, 7], [0, 2]],
|
||||
],
|
||||
)
|
||||
def test_main_progress_bar_update_amount(
|
||||
|
@ -395,7 +396,7 @@ def test_main_progress_bar_update_amount(
|
|||
assert progress_bar.val_progress_bar.n_values == val_updates
|
||||
|
||||
|
||||
@pytest.mark.parametrize("test_batches,refresh_rate,updates", [[1, 3, [1]], [3, 1, [1, 2, 3]], [5, 3, [3, 5]]])
|
||||
@pytest.mark.parametrize("test_batches,refresh_rate,updates", [(1, 3, [0, 1]), (3, 1, [0, 1, 2, 3]), (5, 3, [0, 3, 5])])
|
||||
def test_test_progress_bar_update_amount(tmpdir, test_batches: int, refresh_rate: int, updates: list):
|
||||
"""Test that test progress updates with the correct amount."""
|
||||
model = BoringModel()
|
||||
|
@ -566,7 +567,7 @@ def test_tqdm_progress_bar_can_be_pickled():
|
|||
|
||||
@pytest.mark.parametrize(
|
||||
["val_check_interval", "main_progress_bar_updates", "val_progress_bar_updates"],
|
||||
[(4, [3, 6, 9, 12, 14], [3, 6, 7]), (0.5, [3, 6, 9, 12, 15, 18, 21], [3, 6, 7])],
|
||||
[(4, [0, 3, 6, 9, 12, 14], [0, 3, 6, 7]), (0.5, [0, 3, 6, 9, 12, 15, 18, 21], [0, 3, 6, 7])],
|
||||
)
|
||||
def test_progress_bar_max_val_check_interval(
|
||||
tmpdir, val_check_interval, main_progress_bar_updates, val_progress_bar_updates
|
||||
|
|
Loading…
Reference in New Issue