Prevent flickering progress bar (#6009)

* add padding

* fix

* fix

* Update pytorch_lightning/callbacks/progress.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* updated based on suggestion

* changelog

* add test

* fix pep8

* resolve test

* fix code format

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: tchaton <thomas@grid.ai>
Nicki Skafte 2021-02-17 20:01:51 +01:00 committed by GitHub
parent ad36c7b9ce
commit 68fd3086f1
3 changed files with 40 additions and 2 deletions

CHANGELOG.md

@@ -297,6 +297,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed missing `process_dataloader` call for `TPUSpawn` when in distributed mode ([#6015](https://github.com/PyTorchLightning/pytorch-lightning/pull/6015))
+- Fixed progress bar flickering by appending 0 to floats/strings ([#6009](https://github.com/PyTorchLightning/pytorch-lightning/pull/6009))
 - Fixed synchronization issues with TPU training ([#6027](https://github.com/PyTorchLightning/pytorch-lightning/pull/6027))

pytorch_lightning/callbacks/progress.py

@@ -26,12 +26,37 @@ import sys
 from typing import Optional, Union
 
 if importlib.util.find_spec('ipywidgets') is not None:
-    from tqdm.auto import tqdm
+    from tqdm.auto import tqdm as _tqdm
 else:
-    from tqdm import tqdm
+    from tqdm import tqdm as _tqdm
 
 from pytorch_lightning.callbacks import Callback
 
+_PAD_SIZE = 5
+
+
+class tqdm(_tqdm):
+    """
+    Custom tqdm progressbar where we append 0 to floating points/strings to
+    prevent the progress bar from flickering
+    """
+
+    @staticmethod
+    def format_num(n) -> str:
+        """ Add additional padding to the formatted numbers """
+        should_be_padded = isinstance(n, (float, str))
+        if not isinstance(n, str):
+            n = _tqdm.format_num(n)
+        if should_be_padded and 'e' not in n:
+            if '.' not in n and len(n) < _PAD_SIZE:
+                try:
+                    _ = float(n)
+                except ValueError:
+                    return n
+                n += '.'
+            n += "0" * (_PAD_SIZE - len(n))
+        return n
+
+
 class ProgressBarBase(Callback):
     r"""

tests/callbacks/test_progress_bar.py

@@ -20,6 +20,7 @@ import torch
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBar, ProgressBarBase
+from pytorch_lightning.callbacks.progress import tqdm
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
@@ -371,3 +372,12 @@ def test_tensor_to_float_conversion(tmpdir):
     pbar = trainer.progress_bar_callback.main_progress_bar
     actual = str(pbar.postfix)
     assert actual.endswith("foo=0.123, bar={'baz': tensor([1])}")
+
+
+@pytest.mark.parametrize(
+    "input_num, expected", [[1, '1'], [1.0, '1.000'], [0.1, '0.100'], [1e-3, '0.001'], [1e-5, '1e-5'], ['1.0', '1.000'],
+                            ['10000', '10000'], ['abc', 'abc']]
+)
+def test_tqdm_format_num(input_num, expected):
+    """ Check that the specialized tqdm.format_num appends 0 to floats and strings """
+    assert tqdm.format_num(input_num) == expected
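
As a usage note, the padded formatter is what keeps `set_postfix` output at a stable width during training, since tqdm routes numeric postfix values through `format_num`. A minimal sketch of the effect (a hypothetical standalone demo, not part of this commit):

```python
import time

from pytorch_lightning.callbacks.progress import tqdm

# With stock tqdm, the postfix would alternate between widths such as
# 'loss=0.1' and 'loss=0.123', making the bar jump between refreshes; the
# subclass pads every float to a fixed width ('loss=0.100', 'loss=0.123'),
# so the rendered line length never changes.
bar = tqdm(total=3)
for loss in (0.1, 0.123, 0.1234):
    bar.set_postfix(loss=loss)  # floats are formatted via format_num
    bar.update(1)
    time.sleep(0.1)
bar.close()
```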