diff --git a/CHANGELOG.md b/CHANGELOG.md index 57293d6ea8..8ab18a66d3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -297,6 +297,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed missing `process_dataloader` call for `TPUSpawn` when in distributed mode ([#6015](https://github.com/PyTorchLightning/pytorch-lightning/pull/6015)) +- Fixed progress bar flickering by appending 0 to floats/strings ([#6009](https://github.com/PyTorchLightning/pytorch-lightning/pull/6009)) + + - Fixed synchronization issues with TPU training ([#6027](https://github.com/PyTorchLightning/pytorch-lightning/pull/6027)) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 7de7982b4a..3f401669c3 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -26,12 +26,37 @@ import sys from typing import Optional, Union if importlib.util.find_spec('ipywidgets') is not None: - from tqdm.auto import tqdm + from tqdm.auto import tqdm as _tqdm else: - from tqdm import tqdm + from tqdm import tqdm as _tqdm from pytorch_lightning.callbacks import Callback +_PAD_SIZE = 5 + + +class tqdm(_tqdm): + """ + Custom tqdm progressbar where we append 0 to floating points/strings to + prevent the progress bar from flickering + """ + + @staticmethod + def format_num(n) -> str: + """ Add additional padding to the formatted numbers """ + should_be_padded = isinstance(n, (float, str)) + if not isinstance(n, str): + n = _tqdm.format_num(n) + if should_be_padded and 'e' not in n: + if '.' not in n and len(n) < _PAD_SIZE: + try: + _ = float(n) + except ValueError: + return n + n += '.' 
+ n += '0' * (_PAD_SIZE - len(n)) + return n + class ProgressBarBase(Callback): r""" diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 8398aec88f..9ec4800851 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -20,6 +20,7 @@ import torch from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBar, ProgressBarBase +from pytorch_lightning.callbacks.progress import tqdm from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel @@ -371,3 +372,12 @@ def test_tensor_to_float_conversion(tmpdir): pbar = trainer.progress_bar_callback.main_progress_bar actual = str(pbar.postfix) assert actual.endswith("foo=0.123, bar={'baz': tensor([1])}") + + +@pytest.mark.parametrize( + "input_num, expected", [[1, '1'], [1.0, '1.000'], [0.1, '0.100'], [1e-3, '0.001'], [1e-5, '1e-5'], ['1.0', '1.000'], + ['10000', '10000'], ['abc', 'abc']] +) +def test_tqdm_format_num(input_num, expected): + """ Check that the specialized tqdm.format_num appends 0 to floats and strings """ + assert tqdm.format_num(input_num) == expected