Fix visual progress bar bug / properly reset progress bar (#4579)
* reset * fix reset * changelog * update chlog * typing Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
parent
a9a376019f
commit
7b42494d0e
22
CHANGELOG.md
22
CHANGELOG.md
|
@ -169,7 +169,27 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed loading yaml ([#5619](https://github.com/PyTorchLightning/pytorch-lightning/pull/5619))
|
||||
|
||||
|
||||
## [1.1.4] - YYYY-MM-DD
|
||||
## [unreleased.Bugfixes] - YYYY-MM-DD
|
||||
|
||||
### Added
|
||||
|
||||
|
||||
### Changed
|
||||
|
||||
|
||||
### Deprecated
|
||||
|
||||
|
||||
### Removed
|
||||
|
||||
|
||||
### Fixed
|
||||
|
||||
|
||||
- Fixed a visual bug in the progress bar display initialization ([#4579](https://github.com/PyTorchLightning/pytorch-lightning/pull/4579))
|
||||
|
||||
|
||||
## [1.1.4] - 2021-01-12
|
||||
|
||||
### Added
|
||||
|
||||
|
|
|
@ -24,6 +24,8 @@ import sys
|
|||
|
||||
# check if ipywidgets is installed before importing tqdm.auto
|
||||
# to ensure it won't fail and a progress bar is displayed
|
||||
from typing import Optional, Union
|
||||
|
||||
if importlib.util.find_spec('ipywidgets') is not None:
|
||||
from tqdm.auto import tqdm
|
||||
else:
|
||||
|
@ -308,7 +310,7 @@ class ProgressBar(ProgressBarBase):
|
|||
def on_sanity_check_start(self, trainer, pl_module):
|
||||
super().on_sanity_check_start(trainer, pl_module)
|
||||
self.val_progress_bar = self.init_sanity_tqdm()
|
||||
self.val_progress_bar.total = convert_inf(sum(trainer.num_sanity_val_batches))
|
||||
reset(self.val_progress_bar, sum(trainer.num_sanity_val_batches))
|
||||
self.main_progress_bar = tqdm(disable=True) # dummy progress bar
|
||||
|
||||
def on_sanity_check_end(self, trainer, pl_module):
|
||||
|
@ -329,8 +331,7 @@ class ProgressBar(ProgressBarBase):
|
|||
val_checks_per_epoch = total_train_batches // trainer.val_check_batch
|
||||
total_val_batches = total_val_batches * val_checks_per_epoch
|
||||
total_batches = total_train_batches + total_val_batches
|
||||
if not self.main_progress_bar.disable:
|
||||
self.main_progress_bar.reset(convert_inf(total_batches))
|
||||
reset(self.main_progress_bar, total_batches)
|
||||
self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch}')
|
||||
|
||||
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
||||
|
@ -344,7 +345,7 @@ class ProgressBar(ProgressBarBase):
|
|||
if not trainer.running_sanity_check:
|
||||
self._update_bar(self.main_progress_bar) # fill up remaining
|
||||
self.val_progress_bar = self.init_validation_tqdm()
|
||||
self.val_progress_bar.total = convert_inf(self.total_val_batches)
|
||||
reset(self.val_progress_bar, self.total_val_batches)
|
||||
|
||||
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
||||
super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
|
||||
|
@ -389,8 +390,14 @@ class ProgressBar(ProgressBarBase):
|
|||
bar.update(delta)
|
||||
|
||||
|
||||
def convert_inf(x):
|
||||
def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
|
||||
""" The tqdm doesn't support inf values. We have to convert it to None. """
|
||||
if x == float('inf'):
|
||||
return None
|
||||
return x
|
||||
|
||||
|
||||
def reset(bar: tqdm, total: Optional[int] = None) -> None:
|
||||
""" Resets the tqdm bar to 0 progress with a new total, unless it is disabled. """
|
||||
if not bar.disable:
|
||||
bar.reset(total=convert_inf(total))
|
||||
|
|
Loading…
Reference in New Issue