Prevent flickering progress bar (#6009)
* add padding * fix * fix * Update pytorch_lightning/callbacks/progress.py Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> * updated based on suggestion * changelog * add test * fix pep8 * resolve test * fix code format Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> Co-authored-by: tchaton <thomas@grid.ai>
This commit is contained in:
parent
ad36c7b9ce
commit
68fd3086f1
|
@ -297,6 +297,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed missing `process_dataloader` call for `TPUSpawn` when in distributed mode ([#6015](https://github.com/PyTorchLightning/pytorch-lightning/pull/6015))
|
||||
|
||||
|
||||
- Fixed progress bar flickering by appending 0 to floats/strings ([#6009](https://github.com/PyTorchLightning/pytorch-lightning/pull/6009))
|
||||
|
||||
|
||||
- Fixed synchronization issues with TPU training ([#6027](https://github.com/PyTorchLightning/pytorch-lightning/pull/6027))
|
||||
|
||||
|
||||
|
|
|
@ -26,12 +26,37 @@ import sys
|
|||
from typing import Optional, Union
|
||||
|
||||
if importlib.util.find_spec('ipywidgets') is not None:
|
||||
from tqdm.auto import tqdm
|
||||
from tqdm.auto import tqdm as _tqdm
|
||||
else:
|
||||
from tqdm import tqdm
|
||||
from tqdm import tqdm as _tqdm
|
||||
|
||||
from pytorch_lightning.callbacks import Callback
|
||||
|
||||
_PAD_SIZE = 5
|
||||
|
||||
|
||||
class tqdm(_tqdm):
|
||||
"""
|
||||
Custom tqdm progressbar where we append 0 to floating points/strings to
|
||||
prevent the progress bar from flickering
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def format_num(n) -> str:
|
||||
""" Add additional padding to the formatted numbers """
|
||||
should_be_padded = isinstance(n, (float, str))
|
||||
if not isinstance(n, str):
|
||||
n = _tqdm.format_num(n)
|
||||
if should_be_padded and 'e' not in n:
|
||||
if '.' not in n and len(n) < _PAD_SIZE:
|
||||
try:
|
||||
_ = float(n)
|
||||
except ValueError:
|
||||
return n
|
||||
n += '.'
|
||||
n += "0" * (_PAD_SIZE - len(n))
|
||||
return n
|
||||
|
||||
|
||||
class ProgressBarBase(Callback):
|
||||
r"""
|
||||
|
|
|
@ -20,6 +20,7 @@ import torch
|
|||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBar, ProgressBarBase
|
||||
from pytorch_lightning.callbacks.progress import tqdm
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from tests.helpers import BoringModel
|
||||
|
||||
|
@ -371,3 +372,12 @@ def test_tensor_to_float_conversion(tmpdir):
|
|||
pbar = trainer.progress_bar_callback.main_progress_bar
|
||||
actual = str(pbar.postfix)
|
||||
assert actual.endswith("foo=0.123, bar={'baz': tensor([1])}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_num, expected", [[1, '1'], [1.0, '1.000'], [0.1, '0.100'], [1e-3, '0.001'], [1e-5, '1e-5'], ['1.0', '1.000'],
|
||||
['10000', '10000'], ['abc', 'abc']]
|
||||
)
|
||||
def test_tqdm_format_num(input_num, expected):
|
||||
""" Check that the specialized tqdm.format_num appends 0 to floats and strings """
|
||||
assert tqdm.format_num(input_num) == expected
|
||||
|
|
Loading…
Reference in New Issue