Deprecate `ProgressBar` and rename it to `TQDMProgressBar` (#10134)

Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Kaushik B 2021-11-01 17:12:21 +05:30 committed by GitHub
parent edea0d4bc3
commit 45c45dc7b0
13 changed files with 83 additions and 47 deletions

View File

@@ -209,6 +209,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `lr_sch_names` from `LearningRateMonitor` ([#10066](https://github.com/PyTorchLightning/pytorch-lightning/pull/10066))
- Deprecated `ProgressBar` callback in favor of `TQDMProgressBar` ([#10134](https://github.com/PyTorchLightning/pytorch-lightning/pull/10134))
### Removed
- Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))

View File

@@ -1240,7 +1240,7 @@ See the :doc:`profiler documentation <../advanced/profiler>` for more details.
progress_bar_refresh_rate
^^^^^^^^^^^^^^^^^^^^^^^^^
``progress_bar_refresh_rate`` has been deprecated in v1.5 and will be removed in v1.7.
Please pass :class:`~pytorch_lightning.callbacks.progress.ProgressBar` with ``refresh_rate``
Please pass :class:`~pytorch_lightning.callbacks.progress.TQDMProgressBar` with ``refresh_rate``
directly to the Trainer's ``callbacks`` argument instead. To disable the progress bar,
pass ``enable_progress_bar = False`` to the Trainer.
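
For readers migrating, a minimal sketch of the change described above (using only the parameter names this diff touches):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import TQDMProgressBar

# Before (deprecated in v1.5, removed in v1.7):
# trainer = Trainer(progress_bar_refresh_rate=20)

# After: pass the callback with the refresh rate (in steps) directly.
trainer = Trainer(callbacks=[TQDMProgressBar(refresh_rate=20)])

# Disabling the bar is now a separate flag rather than refresh_rate=0.
trainer = Trainer(enable_progress_bar=False)
```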

View File

@@ -67,6 +67,7 @@ module = [
"pytorch_lightning.callbacks.gradient_accumulation_scheduler",
"pytorch_lightning.callbacks.lr_monitor",
"pytorch_lightning.callbacks.model_summary",
"pytorch_lightning.callbacks.progress",
"pytorch_lightning.callbacks.pruning",
"pytorch_lightning.callbacks.rich_model_summary",
"pytorch_lightning.core.optimizer",

View File

@@ -22,7 +22,7 @@ from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.model_summary import ModelSummary
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase, RichProgressBar
from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase, RichProgressBar, TQDMProgressBar
from pytorch_lightning.callbacks.pruning import ModelPruning
from pytorch_lightning.callbacks.quantization import QuantizationAwareTraining
from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary
@@ -52,4 +52,5 @@ __all__ = [
"RichProgressBar",
"StochasticWeightAveraging",
"Timer",
"TQDMProgressBar",
]

View File

@@ -19,5 +19,6 @@ Use or override one of the progress bar callbacks.
"""
from pytorch_lightning.callbacks.progress.base import ProgressBarBase # noqa: F401
from pytorch_lightning.callbacks.progress.progress import ProgressBar # noqa: F401
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBar # noqa: F401
from pytorch_lightning.callbacks.progress.tqdm_progress import ProgressBar # noqa: F401
from pytorch_lightning.callbacks.progress.tqdm_progress import TQDMProgressBar # noqa: F401

View File

@@ -0,0 +1,24 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.callbacks.progress.tqdm_progress import TQDMProgressBar
from pytorch_lightning.utilities import rank_zero_deprecation


class ProgressBar(TQDMProgressBar):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        rank_zero_deprecation(
            "`ProgressBar` has been deprecated in v1.5 and will be removed in v1.7."
            " It has been renamed to `TQDMProgressBar` instead."
        )
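
The shim keeps the old name importable and functional, so existing user code and subclasses keep working until v1.7; constructing it merely emits the deprecation message once. A sketch of the observable behavior, assuming `rank_zero_deprecation` goes through the standard `warnings` machinery (which the `pytest.deprecated_call` test further down relies on):

```python
import warnings

from pytorch_lightning.callbacks.progress import ProgressBar, TQDMProgressBar

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    bar = ProgressBar(refresh_rate=2)  # still a fully working bar

assert isinstance(bar, TQDMProgressBar)  # the old name is now a subclass alias
assert any("deprecated in v1.5" in str(w.message) for w in caught)
```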

View File

@@ -54,7 +54,7 @@ class Tqdm(_tqdm):
return n
class ProgressBar(ProgressBarBase):
class TQDMProgressBar(ProgressBarBase):
r"""
This is the default progress bar used by Lightning. It prints to ``stdout`` using the
:mod:`tqdm` package and shows up to four different bars:
@@ -75,7 +75,7 @@ class ProgressBar(ProgressBarBase):
Example:
>>> class LitProgressBar(ProgressBar):
>>> class LitProgressBar(TQDMProgressBar):
... def init_validation_tqdm(self):
... bar = super().init_validation_tqdm()
... bar.set_description('running validation ...')
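
The hunk cuts the docstring example short; a complete, runnable version of the same idea (the original docstring presumably continues with a `return bar`, which the override needs so the customized bar is actually used):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import TQDMProgressBar


class LitProgressBar(TQDMProgressBar):
    def init_validation_tqdm(self):
        bar = super().init_validation_tqdm()
        bar.set_description("running validation ...")
        return bar


trainer = Trainer(callbacks=[LitProgressBar()])
```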

View File

@@ -20,9 +20,9 @@ from pytorch_lightning.callbacks import (
GradientAccumulationScheduler,
ModelCheckpoint,
ModelSummary,
ProgressBar,
ProgressBarBase,
RichProgressBar,
TQDMProgressBar,
)
from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary
from pytorch_lightning.callbacks.timer import Timer
@@ -82,14 +82,14 @@ class CallbackConnector:
if process_position != 0:
rank_zero_deprecation(
f"Setting `Trainer(process_position={process_position})` is deprecated in v1.5 and will be removed"
" in v1.7. Please pass `pytorch_lightning.callbacks.progress.ProgressBar` with"
" in v1.7. Please pass `pytorch_lightning.callbacks.progress.TQDMProgressBar` with"
" `process_position` directly to the Trainer's `callbacks` argument instead."
)
if progress_bar_refresh_rate is not None:
rank_zero_deprecation(
f"Setting `Trainer(progress_bar_refresh_rate={progress_bar_refresh_rate})` is deprecated in v1.5 and"
" will be removed in v1.7. Please pass `pytorch_lightning.callbacks.progress.ProgressBar` with"
" will be removed in v1.7. Please pass `pytorch_lightning.callbacks.progress.TQDMProgressBar` with"
" `refresh_rate` directly to the Trainer's `callbacks` argument instead. Or, to disable the progress"
" bar pass `enable_progress_bar = False` to the Trainer."
)
@@ -230,7 +230,7 @@ class CallbackConnector:
if len(progress_bars) == 1:
progress_bar_callback = progress_bars[0]
elif refresh_rate > 0:
progress_bar_callback = ProgressBar(refresh_rate=refresh_rate, process_position=process_position)
progress_bar_callback = TQDMProgressBar(refresh_rate=refresh_rate, process_position=process_position)
self.trainer.callbacks.append(progress_bar_callback)
else:
progress_bar_callback = None

View File

@@ -293,7 +293,7 @@ class Trainer(
.. deprecated:: v1.5
``process_position`` has been deprecated in v1.5 and will be removed in v1.7.
Please pass :class:`~pytorch_lightning.callbacks.progress.ProgressBar` with ``process_position``
Please pass :class:`~pytorch_lightning.callbacks.progress.TQDMProgressBar` with ``process_position``
directly to the Trainer's ``callbacks`` argument instead.
progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar.
@@ -302,7 +302,7 @@ class Trainer(
.. deprecated:: v1.5
``progress_bar_refresh_rate`` has been deprecated in v1.5 and will be removed in v1.7.
Please pass :class:`~pytorch_lightning.callbacks.progress.ProgressBar` with ``refresh_rate``
Please pass :class:`~pytorch_lightning.callbacks.progress.TQDMProgressBar` with ``refresh_rate``
directly to the Trainer's ``callbacks`` argument instead. To disable the progress bar,
pass ``enable_progress_bar = False`` to the Trainer.
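
`process_position` migrates the same way; a short sketch (this parameter offsets the bar by that many lines, for running more than one Trainer on the same machine):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import TQDMProgressBar

# Instead of Trainer(process_position=1):
trainer = Trainer(callbacks=[TQDMProgressBar(process_position=1)])
```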

View File

@@ -23,7 +23,7 @@ import torch
from torch.utils.data.dataloader import DataLoader
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBar, ProgressBarBase
from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBarBase, TQDMProgressBar
from pytorch_lightning.callbacks.progress.tqdm_progress import Tqdm
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -37,12 +37,12 @@ from tests.helpers.runif import RunIf
([], None),
([], 1),
([], 2),
([ProgressBar(refresh_rate=1)], 0),
([ProgressBar(refresh_rate=2)], 0),
([ProgressBar(refresh_rate=2)], 1),
([TQDMProgressBar(refresh_rate=1)], 0),
([TQDMProgressBar(refresh_rate=2)], 0),
([TQDMProgressBar(refresh_rate=2)], 1),
],
)
def test_progress_bar_on(tmpdir, callbacks: list, refresh_rate: Optional[int]):
def test_tqdm_progress_bar_on(tmpdir, callbacks: list, refresh_rate: Optional[int]):
"""Test different ways the progress bar can be turned on."""
trainer = Trainer(
@@ -63,7 +63,7 @@ def test_progress_bar_on(tmpdir, callbacks: list, refresh_rate: Optional[int]):
"callbacks,refresh_rate,enable_progress_bar",
[([], 0, True), ([], False, True), ([ModelCheckpoint(dirpath="../trainer")], 0, True), ([], 1, False)],
)
def test_progress_bar_off(tmpdir, callbacks: list, refresh_rate: Union[bool, int], enable_progress_bar: bool):
def test_tqdm_progress_bar_off(tmpdir, callbacks: list, refresh_rate: Union[bool, int], enable_progress_bar: bool):
"""Test different ways the progress bar can be turned off."""
trainer = Trainer(
@@ -73,19 +73,19 @@ def test_progress_bar_off(tmpdir, callbacks: list, refresh_rate: Union[bool, int], enable_progress_bar: bool):
enable_progress_bar=enable_progress_bar,
)
progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBar)]
progress_bars = [c for c in trainer.callbacks if isinstance(c, TQDMProgressBar)]
assert 0 == len(progress_bars)
assert not trainer.progress_bar_callback
def test_progress_bar_misconfiguration():
def test_tqdm_progress_bar_misconfiguration():
"""Test that Trainer doesn't accept multiple progress bars."""
callbacks = [ProgressBar(), ProgressBar(), ModelCheckpoint(dirpath="../trainer")]
callbacks = [TQDMProgressBar(), TQDMProgressBar(), ModelCheckpoint(dirpath="../trainer")]
with pytest.raises(MisconfigurationException, match=r"^You added multiple progress bar callbacks"):
Trainer(callbacks=callbacks)
def test_progress_bar_totals(tmpdir):
def test_tqdm_progress_bar_totals(tmpdir):
"""Test that the progress finishes with the correct total steps processed."""
model = BoringModel()
@@ -138,7 +138,7 @@ def test_progress_bar_totals(tmpdir):
assert bar.test_batch_idx == k
def test_progress_bar_fast_dev_run(tmpdir):
def test_tqdm_progress_bar_fast_dev_run(tmpdir):
model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
@@ -174,12 +174,12 @@ def test_progress_bar_fast_dev_run(tmpdir):
@pytest.mark.parametrize("refresh_rate", [0, 1, 50])
def test_progress_bar_progress_refresh(tmpdir, refresh_rate: int):
def test_tqdm_progress_bar_progress_refresh(tmpdir, refresh_rate: int):
"""Test that the three progress bars get correctly updated when using different refresh rates."""
model = BoringModel()
class CurrentProgressBar(ProgressBar):
class CurrentProgressBar(TQDMProgressBar):
train_batches_seen = 0
val_batches_seen = 0
@@ -239,7 +239,7 @@ def test_progress_bar_progress_refresh(tmpdir, refresh_rate: int):
def test_num_sanity_val_steps_progress_bar(tmpdir, limit_val_batches: int):
"""Test val_progress_bar total with 'num_sanity_val_steps' Trainer argument."""
class CurrentProgressBar(ProgressBar):
class CurrentProgressBar(TQDMProgressBar):
val_pbar_total = 0
sanity_pbar_total = 0
@@ -271,7 +271,7 @@ def test_num_sanity_val_steps_progress_bar(tmpdir, limit_val_batches: int):
assert progress_bar.val_pbar_total == limit_val_batches
def test_progress_bar_default_value(tmpdir):
def test_tqdm_progress_bar_default_value(tmpdir):
"""Test that a value of None defaults to refresh rate 1."""
trainer = Trainer(default_root_dir=tmpdir)
assert trainer.progress_bar_callback.refresh_rate == 1
@@ -281,7 +281,7 @@ def test_progress_bar_default_value(tmpdir):
@mock.patch.dict(os.environ, {"COLAB_GPU": "1"})
def test_progress_bar_value_on_colab(tmpdir):
def test_tqdm_progress_bar_value_on_colab(tmpdir):
"""Test that Trainer will override the default in Google COLAB."""
trainer = Trainer(default_root_dir=tmpdir)
assert trainer.progress_bar_callback.refresh_rate == 20
@@ -293,7 +293,7 @@ def test_progress_bar_value_on_colab(tmpdir):
assert trainer.progress_bar_callback.refresh_rate == 19
class MockedUpdateProgressBars(ProgressBar):
class MockedUpdateProgressBars(TQDMProgressBar):
"""Mocks the update method once bars get initializied."""
def _mock_bar_update(self, bar):
@@ -428,10 +428,10 @@ class PrintModel(BoringModel):
@mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm.write")
def test_progress_bar_print(tqdm_write, tmpdir):
def test_tqdm_progress_bar_print(tqdm_write, tmpdir):
"""Test that printing in the LightningModule redirects arguments to the progress bar."""
model = PrintModel()
bar = ProgressBar()
bar = TQDMProgressBar()
trainer = Trainer(
default_root_dir=tmpdir,
num_sanity_val_steps=0,
@@ -455,10 +455,10 @@ def test_progress_bar_print(tqdm_write, tmpdir):
@mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm.write")
def test_progress_bar_print_no_train(tqdm_write, tmpdir):
def test_tqdm_progress_bar_print_no_train(tqdm_write, tmpdir):
"""Test that printing in the LightningModule redirects arguments to the progress bar without training."""
model = PrintModel()
bar = ProgressBar()
bar = TQDMProgressBar()
trainer = Trainer(
default_root_dir=tmpdir,
num_sanity_val_steps=0,
@@ -482,10 +482,10 @@ def test_progress_bar_print_no_train(tqdm_write, tmpdir):
@mock.patch("builtins.print")
@mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm.write")
def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir):
def test_tqdm_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir):
"""Test that printing in LightningModule goes through built-in print function when progress bar is disabled."""
model = PrintModel()
bar = ProgressBar()
bar = TQDMProgressBar()
trainer = Trainer(
default_root_dir=tmpdir,
num_sanity_val_steps=0,
@@ -507,8 +507,8 @@ def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir):
tqdm_write.assert_not_called()
def test_progress_bar_can_be_pickled():
bar = ProgressBar()
def test_tqdm_progress_bar_can_be_pickled():
bar = TQDMProgressBar()
trainer = Trainer(fast_dev_run=True, callbacks=[bar], max_steps=1)
model = BoringModel()
@@ -522,14 +522,14 @@ def test_progress_bar_can_be_pickled():
@RunIf(min_gpus=2, special=True)
def test_progress_bar_max_val_check_interval_0(tmpdir):
def test_tqdm_progress_bar_max_val_check_interval_0(tmpdir):
_test_progress_bar_max_val_check_interval(
tmpdir, total_train_samples=8, train_batch_size=4, total_val_samples=2, val_batch_size=1, val_check_interval=0.2
)
@RunIf(min_gpus=2, special=True)
def test_progress_bar_max_val_check_interval_1(tmpdir):
def test_tqdm_progress_bar_max_val_check_interval_1(tmpdir):
_test_progress_bar_max_val_check_interval(
tmpdir, total_train_samples=8, train_batch_size=4, total_val_samples=2, val_batch_size=1, val_check_interval=0.5
)
@@ -567,7 +567,7 @@ def _test_progress_bar_max_val_check_interval(
def test_get_progress_bar_metrics(tmpdir: str):
class TestProgressBar(ProgressBar):
class TestProgressBar(TQDMProgressBar):
def get_metrics(self, trainer: Trainer, model: LightningModule):
items = super().get_metrics(trainer, model)
items.pop("v_num", None)
@@ -588,9 +588,9 @@ def test_get_progress_bar_metrics(tmpdir: str):
assert "v_num" not in standard_metrics.keys()
def test_progress_bar_main_bar_resume():
def test_tqdm_progress_bar_main_bar_resume():
"""Test that the progress bar can resume its counters based on the Trainer state."""
bar = ProgressBar()
bar = TQDMProgressBar()
trainer = Mock()
model = Mock()

View File

@@ -20,6 +20,7 @@ import torch
from pytorch_lightning import Callback, LightningDataModule, Trainer
from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.progress import ProgressBar
from pytorch_lightning.callbacks.xla_stats_monitor import XLAStatsMonitor
from pytorch_lightning.loggers import LoggerCollection, TestTubeLogger
from tests.callbacks.test_callbacks import OldStatefulCallback
@@ -391,6 +392,11 @@ def test_v1_7_0_deprecate_xla_stats_monitor(tmpdir):
_ = XLAStatsMonitor()
def test_v1_7_0_progress_bar():
    with pytest.deprecated_call(match="has been deprecated in v1.5 and will be removed in v1.7."):
        _ = ProgressBar()


def test_v1_7_0_deprecated_max_steps_none(tmpdir):
    with pytest.deprecated_call(match="`max_steps = None` is deprecated in v1.5"):
        _ = Trainer(max_steps=None)

View File

@@ -22,7 +22,7 @@ from pytorch_lightning.callbacks import (
LearningRateMonitor,
ModelCheckpoint,
ModelSummary,
ProgressBar,
TQDMProgressBar,
)
from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
from tests.helpers import BoringModel
@@ -35,7 +35,7 @@ def test_checkpoint_callbacks_are_last(tmpdir):
model_summary = ModelSummary()
early_stopping = EarlyStopping()
lr_monitor = LearningRateMonitor()
progress_bar = ProgressBar()
progress_bar = TQDMProgressBar()
# no model reference
trainer = Trainer(callbacks=[checkpoint1, progress_bar, lr_monitor, model_summary, checkpoint2])
@@ -155,7 +155,7 @@ def test_attach_model_callbacks():
return trainer
early_stopping = EarlyStopping()
progress_bar = ProgressBar()
progress_bar = TQDMProgressBar()
lr_monitor = LearningRateMonitor()
grad_accumulation = GradientAccumulationScheduler({1: 1})
@@ -199,7 +199,7 @@ def test_attach_model_callbacks_override_info(caplog):
"""Test that the logs contain the info about overriding callbacks returned by configure_callbacks."""
model = LightningModule()
model.configure_callbacks = lambda: [LearningRateMonitor(), EarlyStopping()]
trainer = Trainer(enable_checkpointing=False, callbacks=[EarlyStopping(), LearningRateMonitor(), ProgressBar()])
trainer = Trainer(enable_checkpointing=False, callbacks=[EarlyStopping(), LearningRateMonitor(), TQDMProgressBar()])
trainer.model = model
cb_connector = CallbackConnector(trainer)
with caplog.at_level(logging.INFO):

View File

@@ -24,7 +24,7 @@ from torch.utils.data import DataLoader
from torchmetrics import Accuracy
from pytorch_lightning import callbacks, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBar
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel, RandomDataset, RandomDictDataset
@@ -481,7 +481,7 @@ def test_progress_bar_metrics_contains_values_on_train_epoch_end(tmpdir: str):
)
self.on_train_epoch_end_called = True
class TestProgressBar(ProgressBar):
class TestProgressBar(TQDMProgressBar):
def get_metrics(self, trainer: Trainer, model: LightningModule):
items = super().get_metrics(trainer, model)
items.pop("v_num", None)