From 45c45dc7b018f9a2db60f5df1a3f7dbbb45ccb36 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Mon, 1 Nov 2021 17:12:21 +0530 Subject: [PATCH] Deprecate `ProgressBar` and rename it to `TQDMProgressBar` (#10134) Co-authored-by: ananthsub Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- CHANGELOG.md | 3 + docs/source/common/trainer.rst | 2 +- pyproject.toml | 1 + pytorch_lightning/callbacks/__init__.py | 3 +- .../callbacks/progress/__init__.py | 3 +- .../callbacks/progress/progress.py | 24 ++++++++ .../callbacks/progress/tqdm_progress.py | 4 +- .../trainer/connectors/callback_connector.py | 8 +-- pytorch_lightning/trainer/trainer.py | 4 +- ...gress_bar.py => test_tqdm_progress_bar.py} | 60 +++++++++---------- tests/deprecated_api/test_remove_1-7.py | 6 ++ .../connectors/test_callback_connector.py | 8 +-- .../logging_/test_train_loop_logging.py | 4 +- 13 files changed, 83 insertions(+), 47 deletions(-) create mode 100644 pytorch_lightning/callbacks/progress/progress.py rename tests/callbacks/{test_progress_bar.py => test_tqdm_progress_bar.py} (93%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 16d7336de8..38f465bb5d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -209,6 +209,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `lr_sch_names` from `LearningRateMonitor` ([#10066](https://github.com/PyTorchLightning/pytorch-lightning/pull/10066)) +- Deprecated `ProgressBar` callback in favor of `TQDMProgressBar` ([#10134](https://github.com/PyTorchLightning/pytorch-lightning/pull/10134)) + + ### Removed - Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/)) diff --git a/docs/source/common/trainer.rst b/docs/source/common/trainer.rst index 3d1f4aae5d..006e14b64d 100644 --- a/docs/source/common/trainer.rst +++ b/docs/source/common/trainer.rst @@ -1240,7 +1240,7 @@ See the :doc:`profiler documentation <../advanced/profiler>`. for more details. progress_bar_refresh_rate ^^^^^^^^^^^^^^^^^^^^^^^^^ ``progress_bar_refresh_rate`` has been deprecated in v1.5 and will be removed in v1.7. -Please pass :class:`~pytorch_lightning.callbacks.progress.ProgressBar` with ``refresh_rate`` +Please pass :class:`~pytorch_lightning.callbacks.progress.TQDMProgressBar` with ``refresh_rate`` directly to the Trainer's ``callbacks`` argument instead. To disable the progress bar, pass ``enable_progress_bar = False`` to the Trainer. diff --git a/pyproject.toml b/pyproject.toml index 98c0962da8..a2b83ae93e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,6 +67,7 @@ module = [ "pytorch_lightning.callbacks.gradient_accumulation_scheduler", "pytorch_lightning.callbacks.lr_monitor", "pytorch_lightning.callbacks.model_summary", + "pytorch_lightning.callbacks.progress", "pytorch_lightning.callbacks.pruning", "pytorch_lightning.callbacks.rich_model_summary", "pytorch_lightning.core.optimizer", diff --git a/pytorch_lightning/callbacks/__init__.py b/pytorch_lightning/callbacks/__init__.py index b94fa969f6..f47bc115ec 100644 --- a/pytorch_lightning/callbacks/__init__.py +++ b/pytorch_lightning/callbacks/__init__.py @@ -22,7 +22,7 @@ from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.callbacks.model_summary import ModelSummary from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter -from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase, RichProgressBar +from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase, RichProgressBar, TQDMProgressBar from pytorch_lightning.callbacks.pruning import ModelPruning from pytorch_lightning.callbacks.quantization import QuantizationAwareTraining from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary @@ -52,4 +52,5 @@ __all__ = [ "RichProgressBar", "StochasticWeightAveraging", "Timer", + "TQDMProgressBar", ] diff --git a/pytorch_lightning/callbacks/progress/__init__.py b/pytorch_lightning/callbacks/progress/__init__.py index 3fa7b1afe6..6ccc181b95 100644 --- a/pytorch_lightning/callbacks/progress/__init__.py +++ b/pytorch_lightning/callbacks/progress/__init__.py @@ -19,5 +19,6 @@ Use or override one of the progress bar callbacks. """ from pytorch_lightning.callbacks.progress.base import ProgressBarBase # noqa: F401 +from pytorch_lightning.callbacks.progress.progress import ProgressBar # noqa: F401 from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBar # noqa: F401 -from pytorch_lightning.callbacks.progress.tqdm_progress import ProgressBar # noqa: F401 +from pytorch_lightning.callbacks.progress.tqdm_progress import TQDMProgressBar # noqa: F401 diff --git a/pytorch_lightning/callbacks/progress/progress.py b/pytorch_lightning/callbacks/progress/progress.py new file mode 100644 index 0000000000..e13e612805 --- /dev/null +++ b/pytorch_lightning/callbacks/progress/progress.py @@ -0,0 +1,24 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning.callbacks.progress.tqdm_progress import TQDMProgressBar +from pytorch_lightning.utilities import rank_zero_deprecation + + +class ProgressBar(TQDMProgressBar): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + rank_zero_deprecation( + "`ProgressBar` has been deprecated in v1.5 and will be removed in v1.7." + " It has been renamed to `TQDMProgressBar` instead." + ) diff --git a/pytorch_lightning/callbacks/progress/tqdm_progress.py b/pytorch_lightning/callbacks/progress/tqdm_progress.py index 7f3b902925..672d9d893a 100644 --- a/pytorch_lightning/callbacks/progress/tqdm_progress.py +++ b/pytorch_lightning/callbacks/progress/tqdm_progress.py @@ -54,7 +54,7 @@ class Tqdm(_tqdm): return n -class ProgressBar(ProgressBarBase): +class TQDMProgressBar(ProgressBarBase): r""" This is the default progress bar used by Lightning. It prints to ``stdout`` using the :mod:`tqdm` package and shows up to four different bars: @@ -75,7 +75,7 @@ class ProgressBar(ProgressBarBase): Example: - >>> class LitProgressBar(ProgressBar): + >>> class LitProgressBar(TQDMProgressBar): ... def init_validation_tqdm(self): ... bar = super().init_validation_tqdm() ... bar.set_description('running validation ...') diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 2f63e65340..4d41734ed9 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -20,9 +20,9 @@ from pytorch_lightning.callbacks import ( GradientAccumulationScheduler, ModelCheckpoint, ModelSummary, - ProgressBar, ProgressBarBase, RichProgressBar, + TQDMProgressBar, ) from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary from pytorch_lightning.callbacks.timer import Timer @@ -82,14 +82,14 @@ class CallbackConnector: if process_position != 0: rank_zero_deprecation( f"Setting `Trainer(process_position={process_position})` is deprecated in v1.5 and will be removed" - " in v1.7. Please pass `pytorch_lightning.callbacks.progress.ProgressBar` with" + " in v1.7. Please pass `pytorch_lightning.callbacks.progress.TQDMProgressBar` with" " `process_position` directly to the Trainer's `callbacks` argument instead." ) if progress_bar_refresh_rate is not None: rank_zero_deprecation( f"Setting `Trainer(progress_bar_refresh_rate={progress_bar_refresh_rate})` is deprecated in v1.5 and" - " will be removed in v1.7. Please pass `pytorch_lightning.callbacks.progress.ProgressBar` with" + " will be removed in v1.7. Please pass `pytorch_lightning.callbacks.progress.TQDMProgressBar` with" " `refresh_rate` directly to the Trainer's `callbacks` argument instead. Or, to disable the progress" " bar pass `enable_progress_bar = False` to the Trainer." ) @@ -230,7 +230,7 @@ class CallbackConnector: if len(progress_bars) == 1: progress_bar_callback = progress_bars[0] elif refresh_rate > 0: - progress_bar_callback = ProgressBar(refresh_rate=refresh_rate, process_position=process_position) + progress_bar_callback = TQDMProgressBar(refresh_rate=refresh_rate, process_position=process_position) self.trainer.callbacks.append(progress_bar_callback) else: progress_bar_callback = None diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 944739e1c3..a487d9b508 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -293,7 +293,7 @@ class Trainer( .. deprecated:: v1.5 ``process_position`` has been deprecated in v1.5 and will be removed in v1.7. - Please pass :class:`~pytorch_lightning.callbacks.progress.ProgressBar` with ``process_position`` + Please pass :class:`~pytorch_lightning.callbacks.progress.TQDMProgressBar` with ``process_position`` directly to the Trainer's ``callbacks`` argument instead. progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar. @@ -302,7 +302,7 @@ class Trainer( .. deprecated:: v1.5 ``progress_bar_refresh_rate`` has been deprecated in v1.5 and will be removed in v1.7. - Please pass :class:`~pytorch_lightning.callbacks.progress.ProgressBar` with ``refresh_rate`` + Please pass :class:`~pytorch_lightning.callbacks.progress.TQDMProgressBar` with ``refresh_rate`` directly to the Trainer's ``callbacks`` argument instead. To disable the progress bar, pass ``enable_progress_bar = False`` to the Trainer. diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_tqdm_progress_bar.py similarity index 93% rename from tests/callbacks/test_progress_bar.py rename to tests/callbacks/test_tqdm_progress_bar.py index 9cbf89b64f..b92fb18d54 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_tqdm_progress_bar.py @@ -23,7 +23,7 @@ import torch from torch.utils.data.dataloader import DataLoader from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBar, ProgressBarBase +from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBarBase, TQDMProgressBar from pytorch_lightning.callbacks.progress.tqdm_progress import Tqdm from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -37,12 +37,12 @@ from tests.helpers.runif import RunIf ([], None), ([], 1), ([], 2), - ([ProgressBar(refresh_rate=1)], 0), - ([ProgressBar(refresh_rate=2)], 0), - ([ProgressBar(refresh_rate=2)], 1), + ([TQDMProgressBar(refresh_rate=1)], 0), + ([TQDMProgressBar(refresh_rate=2)], 0), + ([TQDMProgressBar(refresh_rate=2)], 1), ], ) -def test_progress_bar_on(tmpdir, callbacks: list, refresh_rate: Optional[int]): +def test_tqdm_progress_bar_on(tmpdir, callbacks: list, refresh_rate: Optional[int]): """Test different ways the progress bar can be turned on.""" trainer = Trainer( @@ -63,7 +63,7 @@ def test_progress_bar_on(tmpdir, callbacks: list, refresh_rate: Optional[int]): "callbacks,refresh_rate,enable_progress_bar", [([], 0, True), ([], False, True), ([ModelCheckpoint(dirpath="../trainer")], 0, True), ([], 1, False)], ) -def test_progress_bar_off(tmpdir, callbacks: list, refresh_rate: Union[bool, int], enable_progress_bar: bool): +def test_tqdm_progress_bar_off(tmpdir, callbacks: list, refresh_rate: Union[bool, int], enable_progress_bar: bool): """Test different ways the progress bar can be turned off.""" trainer = Trainer( @@ -73,19 +73,19 @@ def test_progress_bar_off(tmpdir, callbacks: list, refresh_rate: Union[bool, int enable_progress_bar=enable_progress_bar, ) - progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBar)] + progress_bars = [c for c in trainer.callbacks if isinstance(c, TQDMProgressBar)] assert 0 == len(progress_bars) assert not trainer.progress_bar_callback -def test_progress_bar_misconfiguration(): +def test_tqdm_progress_bar_misconfiguration(): """Test that Trainer doesn't accept multiple progress bars.""" - callbacks = [ProgressBar(), ProgressBar(), ModelCheckpoint(dirpath="../trainer")] + callbacks = [TQDMProgressBar(), TQDMProgressBar(), ModelCheckpoint(dirpath="../trainer")] with pytest.raises(MisconfigurationException, match=r"^You added multiple progress bar callbacks"): Trainer(callbacks=callbacks) -def test_progress_bar_totals(tmpdir): +def test_tqdm_progress_bar_totals(tmpdir): """Test that the progress finishes with the correct total steps processed.""" model = BoringModel() @@ -138,7 +138,7 @@ def test_progress_bar_totals(tmpdir): assert bar.test_batch_idx == k -def test_progress_bar_fast_dev_run(tmpdir): +def test_tqdm_progress_bar_fast_dev_run(tmpdir): model = BoringModel() trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) @@ -174,12 +174,12 @@ def test_progress_bar_fast_dev_run(tmpdir): @pytest.mark.parametrize("refresh_rate", [0, 1, 50]) -def test_progress_bar_progress_refresh(tmpdir, refresh_rate: int): +def test_tqdm_progress_bar_progress_refresh(tmpdir, refresh_rate: int): """Test that the three progress bars get correctly updated when using different refresh rates.""" model = BoringModel() - class CurrentProgressBar(ProgressBar): + class CurrentProgressBar(TQDMProgressBar): train_batches_seen = 0 val_batches_seen = 0 @@ -239,7 +239,7 @@ def test_progress_bar_progress_refresh(tmpdir, refresh_rate: int): def test_num_sanity_val_steps_progress_bar(tmpdir, limit_val_batches: int): """Test val_progress_bar total with 'num_sanity_val_steps' Trainer argument.""" - class CurrentProgressBar(ProgressBar): + class CurrentProgressBar(TQDMProgressBar): val_pbar_total = 0 sanity_pbar_total = 0 @@ -271,7 +271,7 @@ def test_num_sanity_val_steps_progress_bar(tmpdir, limit_val_batches: int): assert progress_bar.val_pbar_total == limit_val_batches -def test_progress_bar_default_value(tmpdir): +def test_tqdm_progress_bar_default_value(tmpdir): """Test that a value of None defaults to refresh rate 1.""" trainer = Trainer(default_root_dir=tmpdir) assert trainer.progress_bar_callback.refresh_rate == 1 @@ -281,7 +281,7 @@ def test_progress_bar_default_value(tmpdir): @mock.patch.dict(os.environ, {"COLAB_GPU": "1"}) -def test_progress_bar_value_on_colab(tmpdir): +def test_tqdm_progress_bar_value_on_colab(tmpdir): """Test that Trainer will override the default in Google COLAB.""" trainer = Trainer(default_root_dir=tmpdir) assert trainer.progress_bar_callback.refresh_rate == 20 @@ -293,7 +293,7 @@ def test_progress_bar_value_on_colab(tmpdir): assert trainer.progress_bar_callback.refresh_rate == 19 -class MockedUpdateProgressBars(ProgressBar): +class MockedUpdateProgressBars(TQDMProgressBar): """Mocks the update method once bars get initializied.""" def _mock_bar_update(self, bar): @@ -428,10 +428,10 @@ class PrintModel(BoringModel): @mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm.write") -def test_progress_bar_print(tqdm_write, tmpdir): +def test_tqdm_progress_bar_print(tqdm_write, tmpdir): """Test that printing in the LightningModule redirects arguments to the progress bar.""" model = PrintModel() - bar = ProgressBar() + bar = TQDMProgressBar() trainer = Trainer( default_root_dir=tmpdir, num_sanity_val_steps=0, @@ -455,10 +455,10 @@ def test_progress_bar_print(tqdm_write, tmpdir): @mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm.write") -def test_progress_bar_print_no_train(tqdm_write, tmpdir): +def test_tqdm_progress_bar_print_no_train(tqdm_write, tmpdir): """Test that printing in the LightningModule redirects arguments to the progress bar without training.""" model = PrintModel() - bar = ProgressBar() + bar = TQDMProgressBar() trainer = Trainer( default_root_dir=tmpdir, num_sanity_val_steps=0, @@ -482,10 +482,10 @@ def test_progress_bar_print_no_train(tqdm_write, tmpdir): @mock.patch("builtins.print") @mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm.write") -def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir): +def test_tqdm_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir): """Test that printing in LightningModule goes through built-in print function when progress bar is disabled.""" model = PrintModel() - bar = ProgressBar() + bar = TQDMProgressBar() trainer = Trainer( default_root_dir=tmpdir, num_sanity_val_steps=0, @@ -507,8 +507,8 @@ def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir): tqdm_write.assert_not_called() -def test_progress_bar_can_be_pickled(): - bar = ProgressBar() +def test_tqdm_progress_bar_can_be_pickled(): + bar = TQDMProgressBar() trainer = Trainer(fast_dev_run=True, callbacks=[bar], max_steps=1) model = BoringModel() @@ -522,14 +522,14 @@ def test_progress_bar_can_be_pickled(): @RunIf(min_gpus=2, special=True) -def test_progress_bar_max_val_check_interval_0(tmpdir): +def test_tqdm_progress_bar_max_val_check_interval_0(tmpdir): _test_progress_bar_max_val_check_interval( tmpdir, total_train_samples=8, train_batch_size=4, total_val_samples=2, val_batch_size=1, val_check_interval=0.2 ) @RunIf(min_gpus=2, special=True) -def test_progress_bar_max_val_check_interval_1(tmpdir): +def test_tqdm_progress_bar_max_val_check_interval_1(tmpdir): _test_progress_bar_max_val_check_interval( tmpdir, total_train_samples=8, train_batch_size=4, total_val_samples=2, val_batch_size=1, val_check_interval=0.5 ) @@ -567,7 +567,7 @@ def _test_progress_bar_max_val_check_interval( def test_get_progress_bar_metrics(tmpdir: str): - class TestProgressBar(ProgressBar): + class TestProgressBar(TQDMProgressBar): def get_metrics(self, trainer: Trainer, model: LightningModule): items = super().get_metrics(trainer, model) items.pop("v_num", None) @@ -588,9 +588,9 @@ def test_get_progress_bar_metrics(tmpdir: str): assert "v_num" not in standard_metrics.keys() -def test_progress_bar_main_bar_resume(): +def test_tqdm_progress_bar_main_bar_resume(): """Test that the progress bar can resume its counters based on the Trainer state.""" - bar = ProgressBar() + bar = TQDMProgressBar() trainer = Mock() model = Mock() diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index d3328d7a3a..16c511b6ef 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -20,6 +20,7 @@ import torch from pytorch_lightning import Callback, LightningDataModule, Trainer from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor +from pytorch_lightning.callbacks.progress import ProgressBar from pytorch_lightning.callbacks.xla_stats_monitor import XLAStatsMonitor from pytorch_lightning.loggers import LoggerCollection, TestTubeLogger from tests.callbacks.test_callbacks import OldStatefulCallback @@ -391,6 +392,11 @@ def test_v1_7_0_deprecate_xla_stats_monitor(tmpdir): _ = XLAStatsMonitor() +def test_v1_7_0_progress_bar(): + with pytest.deprecated_call(match="has been deprecated in v1.5 and will be removed in v1.7."): + _ = ProgressBar() + + def test_v1_7_0_deprecated_max_steps_none(tmpdir): with pytest.deprecated_call(match="`max_steps = None` is deprecated in v1.5"): _ = Trainer(max_steps=None) diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py index 949723b022..7ec238acf5 100644 --- a/tests/trainer/connectors/test_callback_connector.py +++ b/tests/trainer/connectors/test_callback_connector.py @@ -22,7 +22,7 @@ from pytorch_lightning.callbacks import ( LearningRateMonitor, ModelCheckpoint, ModelSummary, - ProgressBar, + TQDMProgressBar, ) from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector from tests.helpers import BoringModel @@ -35,7 +35,7 @@ def test_checkpoint_callbacks_are_last(tmpdir): model_summary = ModelSummary() early_stopping = EarlyStopping() lr_monitor = LearningRateMonitor() - progress_bar = ProgressBar() + progress_bar = TQDMProgressBar() # no model reference trainer = Trainer(callbacks=[checkpoint1, progress_bar, lr_monitor, model_summary, checkpoint2]) @@ -155,7 +155,7 @@ def test_attach_model_callbacks(): return trainer early_stopping = EarlyStopping() - progress_bar = ProgressBar() + progress_bar = TQDMProgressBar() lr_monitor = LearningRateMonitor() grad_accumulation = GradientAccumulationScheduler({1: 1}) @@ -199,7 +199,7 @@ def test_attach_model_callbacks_override_info(caplog): """Test that the logs contain the info about overriding callbacks returned by configure_callbacks.""" model = LightningModule() model.configure_callbacks = lambda: [LearningRateMonitor(), EarlyStopping()] - trainer = Trainer(enable_checkpointing=False, callbacks=[EarlyStopping(), LearningRateMonitor(), ProgressBar()]) + trainer = Trainer(enable_checkpointing=False, callbacks=[EarlyStopping(), LearningRateMonitor(), TQDMProgressBar()]) trainer.model = model cb_connector = CallbackConnector(trainer) with caplog.at_level(logging.INFO): diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 899dfb38eb..6cad940171 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -24,7 +24,7 @@ from torch.utils.data import DataLoader from torchmetrics import Accuracy from pytorch_lightning import callbacks, Trainer -from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBar +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel, RandomDataset, RandomDictDataset @@ -481,7 +481,7 @@ def test_progress_bar_metrics_contains_values_on_train_epoch_end(tmpdir: str): ) self.on_train_epoch_end_called = True - class TestProgressBar(ProgressBar): + class TestProgressBar(TQDMProgressBar): def get_metrics(self, trainer: Trainer, model: LightningModule): items = super().get_metrics(trainer, model) items.pop("v_num", None)