Fix visual progress bar bug / properly reset progress bar (#4579)

* reset

* fix reset

* changelog

* update chlog

* typing

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
Adrian Wälchli 2021-01-14 03:25:12 +01:00 committed by Jirka Borovec
parent a9a376019f
commit 7b42494d0e
2 changed files with 33 additions and 6 deletions

View File

@ -169,7 +169,27 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed loading yaml ([#5619](https://github.com/PyTorchLightning/pytorch-lightning/pull/5619))
## [1.1.4] - YYYY-MM-DD
## [unreleased.Bugfixes] - YYYY-MM-DD
### Added
### Changed
### Deprecated
### Removed
### Fixed
- Fixed a visual bug in the progress bar display initialization ([#4579](https://github.com/PyTorchLightning/pytorch-lightning/pull/4579))
## [1.1.4] - 2021-01-12
### Added

View File

@ -24,6 +24,8 @@ import sys
# check if ipywidgets is installed before importing tqdm.auto
# to ensure it won't fail and a progress bar is displayed
from typing import Optional, Union
if importlib.util.find_spec('ipywidgets') is not None:
from tqdm.auto import tqdm
else:
@ -308,7 +310,7 @@ class ProgressBar(ProgressBarBase):
def on_sanity_check_start(self, trainer, pl_module):
super().on_sanity_check_start(trainer, pl_module)
self.val_progress_bar = self.init_sanity_tqdm()
self.val_progress_bar.total = convert_inf(sum(trainer.num_sanity_val_batches))
reset(self.val_progress_bar, sum(trainer.num_sanity_val_batches))
self.main_progress_bar = tqdm(disable=True) # dummy progress bar
def on_sanity_check_end(self, trainer, pl_module):
@ -329,8 +331,7 @@ class ProgressBar(ProgressBarBase):
val_checks_per_epoch = total_train_batches // trainer.val_check_batch
total_val_batches = total_val_batches * val_checks_per_epoch
total_batches = total_train_batches + total_val_batches
if not self.main_progress_bar.disable:
self.main_progress_bar.reset(convert_inf(total_batches))
reset(self.main_progress_bar, total_batches)
self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch}')
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
@ -344,7 +345,7 @@ class ProgressBar(ProgressBarBase):
if not trainer.running_sanity_check:
self._update_bar(self.main_progress_bar) # fill up remaining
self.val_progress_bar = self.init_validation_tqdm()
self.val_progress_bar.total = convert_inf(self.total_val_batches)
reset(self.val_progress_bar, self.total_val_batches)
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
@ -389,8 +390,14 @@ class ProgressBar(ProgressBarBase):
bar.update(delta)
def convert_inf(x):
def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
""" The tqdm doesn't support inf values. We have to convert it to None. """
if x == float('inf'):
return None
return x
def reset(bar: tqdm, total: Optional[int] = None) -> None:
""" Resets the tqdm bar to 0 progress with a new total, unless it is disabled. """
if not bar.disable:
bar.reset(total=convert_inf(total))