Deprecate `ProgressBar` and rename it to `TQDMProgressBar` (#10134)
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
edea0d4bc3
commit
45c45dc7b0
|
@ -209,6 +209,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Deprecated `lr_sch_names` from `LearningRateMonitor` ([#10066](https://github.com/PyTorchLightning/pytorch-lightning/pull/10066))
|
||||
|
||||
|
||||
- Deprecated `ProgressBar` callback in favor of `TQDMProgressBar` ([#10134](https://github.com/PyTorchLightning/pytorch-lightning/pull/10134))
|
||||
|
||||
|
||||
### Removed
|
||||
|
||||
- Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))
|
||||
|
|
|
@ -1240,7 +1240,7 @@ See the :doc:`profiler documentation <../advanced/profiler>`. for more details.
|
|||
progress_bar_refresh_rate
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
``progress_bar_refresh_rate`` has been deprecated in v1.5 and will be removed in v1.7.
|
||||
Please pass :class:`~pytorch_lightning.callbacks.progress.ProgressBar` with ``refresh_rate``
|
||||
Please pass :class:`~pytorch_lightning.callbacks.progress.TQDMProgressBar` with ``refresh_rate``
|
||||
directly to the Trainer's ``callbacks`` argument instead. To disable the progress bar,
|
||||
pass ``enable_progress_bar = False`` to the Trainer.
|
||||
|
||||
|
|
|
@ -67,6 +67,7 @@ module = [
|
|||
"pytorch_lightning.callbacks.gradient_accumulation_scheduler",
|
||||
"pytorch_lightning.callbacks.lr_monitor",
|
||||
"pytorch_lightning.callbacks.model_summary",
|
||||
"pytorch_lightning.callbacks.progress",
|
||||
"pytorch_lightning.callbacks.pruning",
|
||||
"pytorch_lightning.callbacks.rich_model_summary",
|
||||
"pytorch_lightning.core.optimizer",
|
||||
|
|
|
@ -22,7 +22,7 @@ from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
|
|||
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
|
||||
from pytorch_lightning.callbacks.model_summary import ModelSummary
|
||||
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
|
||||
from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase, RichProgressBar
|
||||
from pytorch_lightning.callbacks.progress import ProgressBar, ProgressBarBase, RichProgressBar, TQDMProgressBar
|
||||
from pytorch_lightning.callbacks.pruning import ModelPruning
|
||||
from pytorch_lightning.callbacks.quantization import QuantizationAwareTraining
|
||||
from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary
|
||||
|
@ -52,4 +52,5 @@ __all__ = [
|
|||
"RichProgressBar",
|
||||
"StochasticWeightAveraging",
|
||||
"Timer",
|
||||
"TQDMProgressBar",
|
||||
]
|
||||
|
|
|
@ -19,5 +19,6 @@ Use or override one of the progress bar callbacks.
|
|||
|
||||
"""
|
||||
from pytorch_lightning.callbacks.progress.base import ProgressBarBase # noqa: F401
|
||||
from pytorch_lightning.callbacks.progress.progress import ProgressBar # noqa: F401
|
||||
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBar # noqa: F401
|
||||
from pytorch_lightning.callbacks.progress.tqdm_progress import ProgressBar # noqa: F401
|
||||
from pytorch_lightning.callbacks.progress.tqdm_progress import TQDMProgressBar # noqa: F401
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from pytorch_lightning.callbacks.progress.tqdm_progress import TQDMProgressBar
|
||||
from pytorch_lightning.utilities import rank_zero_deprecation
|
||||
|
||||
|
||||
class ProgressBar(TQDMProgressBar):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
rank_zero_deprecation(
|
||||
"`ProgressBar` has been deprecated in v1.5 and will be removed in v1.7."
|
||||
" It has been renamed to `TQDMProgressBar` instead."
|
||||
)
|
|
@ -54,7 +54,7 @@ class Tqdm(_tqdm):
|
|||
return n
|
||||
|
||||
|
||||
class ProgressBar(ProgressBarBase):
|
||||
class TQDMProgressBar(ProgressBarBase):
|
||||
r"""
|
||||
This is the default progress bar used by Lightning. It prints to ``stdout`` using the
|
||||
:mod:`tqdm` package and shows up to four different bars:
|
||||
|
@ -75,7 +75,7 @@ class ProgressBar(ProgressBarBase):
|
|||
|
||||
Example:
|
||||
|
||||
>>> class LitProgressBar(ProgressBar):
|
||||
>>> class LitProgressBar(TQDMProgressBar):
|
||||
... def init_validation_tqdm(self):
|
||||
... bar = super().init_validation_tqdm()
|
||||
... bar.set_description('running validation ...')
|
||||
|
|
|
@ -20,9 +20,9 @@ from pytorch_lightning.callbacks import (
|
|||
GradientAccumulationScheduler,
|
||||
ModelCheckpoint,
|
||||
ModelSummary,
|
||||
ProgressBar,
|
||||
ProgressBarBase,
|
||||
RichProgressBar,
|
||||
TQDMProgressBar,
|
||||
)
|
||||
from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary
|
||||
from pytorch_lightning.callbacks.timer import Timer
|
||||
|
@ -82,14 +82,14 @@ class CallbackConnector:
|
|||
if process_position != 0:
|
||||
rank_zero_deprecation(
|
||||
f"Setting `Trainer(process_position={process_position})` is deprecated in v1.5 and will be removed"
|
||||
" in v1.7. Please pass `pytorch_lightning.callbacks.progress.ProgressBar` with"
|
||||
" in v1.7. Please pass `pytorch_lightning.callbacks.progress.TQDMProgressBar` with"
|
||||
" `process_position` directly to the Trainer's `callbacks` argument instead."
|
||||
)
|
||||
|
||||
if progress_bar_refresh_rate is not None:
|
||||
rank_zero_deprecation(
|
||||
f"Setting `Trainer(progress_bar_refresh_rate={progress_bar_refresh_rate})` is deprecated in v1.5 and"
|
||||
" will be removed in v1.7. Please pass `pytorch_lightning.callbacks.progress.ProgressBar` with"
|
||||
" will be removed in v1.7. Please pass `pytorch_lightning.callbacks.progress.TQDMProgressBar` with"
|
||||
" `refresh_rate` directly to the Trainer's `callbacks` argument instead. Or, to disable the progress"
|
||||
" bar pass `enable_progress_bar = False` to the Trainer."
|
||||
)
|
||||
|
@ -230,7 +230,7 @@ class CallbackConnector:
|
|||
if len(progress_bars) == 1:
|
||||
progress_bar_callback = progress_bars[0]
|
||||
elif refresh_rate > 0:
|
||||
progress_bar_callback = ProgressBar(refresh_rate=refresh_rate, process_position=process_position)
|
||||
progress_bar_callback = TQDMProgressBar(refresh_rate=refresh_rate, process_position=process_position)
|
||||
self.trainer.callbacks.append(progress_bar_callback)
|
||||
else:
|
||||
progress_bar_callback = None
|
||||
|
|
|
@ -293,7 +293,7 @@ class Trainer(
|
|||
|
||||
.. deprecated:: v1.5
|
||||
``process_position`` has been deprecated in v1.5 and will be removed in v1.7.
|
||||
Please pass :class:`~pytorch_lightning.callbacks.progress.ProgressBar` with ``process_position``
|
||||
Please pass :class:`~pytorch_lightning.callbacks.progress.TQDMProgressBar` with ``process_position``
|
||||
directly to the Trainer's ``callbacks`` argument instead.
|
||||
|
||||
progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar.
|
||||
|
@ -302,7 +302,7 @@ class Trainer(
|
|||
|
||||
.. deprecated:: v1.5
|
||||
``progress_bar_refresh_rate`` has been deprecated in v1.5 and will be removed in v1.7.
|
||||
Please pass :class:`~pytorch_lightning.callbacks.progress.ProgressBar` with ``refresh_rate``
|
||||
Please pass :class:`~pytorch_lightning.callbacks.progress.TQDMProgressBar` with ``refresh_rate``
|
||||
directly to the Trainer's ``callbacks`` argument instead. To disable the progress bar,
|
||||
pass ``enable_progress_bar = False`` to the Trainer.
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ import torch
|
|||
from torch.utils.data.dataloader import DataLoader
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBar, ProgressBarBase
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBarBase, TQDMProgressBar
|
||||
from pytorch_lightning.callbacks.progress.tqdm_progress import Tqdm
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
@ -37,12 +37,12 @@ from tests.helpers.runif import RunIf
|
|||
([], None),
|
||||
([], 1),
|
||||
([], 2),
|
||||
([ProgressBar(refresh_rate=1)], 0),
|
||||
([ProgressBar(refresh_rate=2)], 0),
|
||||
([ProgressBar(refresh_rate=2)], 1),
|
||||
([TQDMProgressBar(refresh_rate=1)], 0),
|
||||
([TQDMProgressBar(refresh_rate=2)], 0),
|
||||
([TQDMProgressBar(refresh_rate=2)], 1),
|
||||
],
|
||||
)
|
||||
def test_progress_bar_on(tmpdir, callbacks: list, refresh_rate: Optional[int]):
|
||||
def test_tqdm_progress_bar_on(tmpdir, callbacks: list, refresh_rate: Optional[int]):
|
||||
"""Test different ways the progress bar can be turned on."""
|
||||
|
||||
trainer = Trainer(
|
||||
|
@ -63,7 +63,7 @@ def test_progress_bar_on(tmpdir, callbacks: list, refresh_rate: Optional[int]):
|
|||
"callbacks,refresh_rate,enable_progress_bar",
|
||||
[([], 0, True), ([], False, True), ([ModelCheckpoint(dirpath="../trainer")], 0, True), ([], 1, False)],
|
||||
)
|
||||
def test_progress_bar_off(tmpdir, callbacks: list, refresh_rate: Union[bool, int], enable_progress_bar: bool):
|
||||
def test_tqdm_progress_bar_off(tmpdir, callbacks: list, refresh_rate: Union[bool, int], enable_progress_bar: bool):
|
||||
"""Test different ways the progress bar can be turned off."""
|
||||
|
||||
trainer = Trainer(
|
||||
|
@ -73,19 +73,19 @@ def test_progress_bar_off(tmpdir, callbacks: list, refresh_rate: Union[bool, int
|
|||
enable_progress_bar=enable_progress_bar,
|
||||
)
|
||||
|
||||
progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBar)]
|
||||
progress_bars = [c for c in trainer.callbacks if isinstance(c, TQDMProgressBar)]
|
||||
assert 0 == len(progress_bars)
|
||||
assert not trainer.progress_bar_callback
|
||||
|
||||
|
||||
def test_progress_bar_misconfiguration():
|
||||
def test_tqdm_progress_bar_misconfiguration():
|
||||
"""Test that Trainer doesn't accept multiple progress bars."""
|
||||
callbacks = [ProgressBar(), ProgressBar(), ModelCheckpoint(dirpath="../trainer")]
|
||||
callbacks = [TQDMProgressBar(), TQDMProgressBar(), ModelCheckpoint(dirpath="../trainer")]
|
||||
with pytest.raises(MisconfigurationException, match=r"^You added multiple progress bar callbacks"):
|
||||
Trainer(callbacks=callbacks)
|
||||
|
||||
|
||||
def test_progress_bar_totals(tmpdir):
|
||||
def test_tqdm_progress_bar_totals(tmpdir):
|
||||
"""Test that the progress finishes with the correct total steps processed."""
|
||||
|
||||
model = BoringModel()
|
||||
|
@ -138,7 +138,7 @@ def test_progress_bar_totals(tmpdir):
|
|||
assert bar.test_batch_idx == k
|
||||
|
||||
|
||||
def test_progress_bar_fast_dev_run(tmpdir):
|
||||
def test_tqdm_progress_bar_fast_dev_run(tmpdir):
|
||||
model = BoringModel()
|
||||
|
||||
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
|
||||
|
@ -174,12 +174,12 @@ def test_progress_bar_fast_dev_run(tmpdir):
|
|||
|
||||
|
||||
@pytest.mark.parametrize("refresh_rate", [0, 1, 50])
|
||||
def test_progress_bar_progress_refresh(tmpdir, refresh_rate: int):
|
||||
def test_tqdm_progress_bar_progress_refresh(tmpdir, refresh_rate: int):
|
||||
"""Test that the three progress bars get correctly updated when using different refresh rates."""
|
||||
|
||||
model = BoringModel()
|
||||
|
||||
class CurrentProgressBar(ProgressBar):
|
||||
class CurrentProgressBar(TQDMProgressBar):
|
||||
|
||||
train_batches_seen = 0
|
||||
val_batches_seen = 0
|
||||
|
@ -239,7 +239,7 @@ def test_progress_bar_progress_refresh(tmpdir, refresh_rate: int):
|
|||
def test_num_sanity_val_steps_progress_bar(tmpdir, limit_val_batches: int):
|
||||
"""Test val_progress_bar total with 'num_sanity_val_steps' Trainer argument."""
|
||||
|
||||
class CurrentProgressBar(ProgressBar):
|
||||
class CurrentProgressBar(TQDMProgressBar):
|
||||
val_pbar_total = 0
|
||||
sanity_pbar_total = 0
|
||||
|
||||
|
@ -271,7 +271,7 @@ def test_num_sanity_val_steps_progress_bar(tmpdir, limit_val_batches: int):
|
|||
assert progress_bar.val_pbar_total == limit_val_batches
|
||||
|
||||
|
||||
def test_progress_bar_default_value(tmpdir):
|
||||
def test_tqdm_progress_bar_default_value(tmpdir):
|
||||
"""Test that a value of None defaults to refresh rate 1."""
|
||||
trainer = Trainer(default_root_dir=tmpdir)
|
||||
assert trainer.progress_bar_callback.refresh_rate == 1
|
||||
|
@ -281,7 +281,7 @@ def test_progress_bar_default_value(tmpdir):
|
|||
|
||||
|
||||
@mock.patch.dict(os.environ, {"COLAB_GPU": "1"})
|
||||
def test_progress_bar_value_on_colab(tmpdir):
|
||||
def test_tqdm_progress_bar_value_on_colab(tmpdir):
|
||||
"""Test that Trainer will override the default in Google COLAB."""
|
||||
trainer = Trainer(default_root_dir=tmpdir)
|
||||
assert trainer.progress_bar_callback.refresh_rate == 20
|
||||
|
@ -293,7 +293,7 @@ def test_progress_bar_value_on_colab(tmpdir):
|
|||
assert trainer.progress_bar_callback.refresh_rate == 19
|
||||
|
||||
|
||||
class MockedUpdateProgressBars(ProgressBar):
|
||||
class MockedUpdateProgressBars(TQDMProgressBar):
|
||||
"""Mocks the update method once bars get initializied."""
|
||||
|
||||
def _mock_bar_update(self, bar):
|
||||
|
@ -428,10 +428,10 @@ class PrintModel(BoringModel):
|
|||
|
||||
|
||||
@mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm.write")
|
||||
def test_progress_bar_print(tqdm_write, tmpdir):
|
||||
def test_tqdm_progress_bar_print(tqdm_write, tmpdir):
|
||||
"""Test that printing in the LightningModule redirects arguments to the progress bar."""
|
||||
model = PrintModel()
|
||||
bar = ProgressBar()
|
||||
bar = TQDMProgressBar()
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
num_sanity_val_steps=0,
|
||||
|
@ -455,10 +455,10 @@ def test_progress_bar_print(tqdm_write, tmpdir):
|
|||
|
||||
|
||||
@mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm.write")
|
||||
def test_progress_bar_print_no_train(tqdm_write, tmpdir):
|
||||
def test_tqdm_progress_bar_print_no_train(tqdm_write, tmpdir):
|
||||
"""Test that printing in the LightningModule redirects arguments to the progress bar without training."""
|
||||
model = PrintModel()
|
||||
bar = ProgressBar()
|
||||
bar = TQDMProgressBar()
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
num_sanity_val_steps=0,
|
||||
|
@ -482,10 +482,10 @@ def test_progress_bar_print_no_train(tqdm_write, tmpdir):
|
|||
|
||||
@mock.patch("builtins.print")
|
||||
@mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm.write")
|
||||
def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir):
|
||||
def test_tqdm_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir):
|
||||
"""Test that printing in LightningModule goes through built-in print function when progress bar is disabled."""
|
||||
model = PrintModel()
|
||||
bar = ProgressBar()
|
||||
bar = TQDMProgressBar()
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
num_sanity_val_steps=0,
|
||||
|
@ -507,8 +507,8 @@ def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir):
|
|||
tqdm_write.assert_not_called()
|
||||
|
||||
|
||||
def test_progress_bar_can_be_pickled():
|
||||
bar = ProgressBar()
|
||||
def test_tqdm_progress_bar_can_be_pickled():
|
||||
bar = TQDMProgressBar()
|
||||
trainer = Trainer(fast_dev_run=True, callbacks=[bar], max_steps=1)
|
||||
model = BoringModel()
|
||||
|
||||
|
@ -522,14 +522,14 @@ def test_progress_bar_can_be_pickled():
|
|||
|
||||
|
||||
@RunIf(min_gpus=2, special=True)
|
||||
def test_progress_bar_max_val_check_interval_0(tmpdir):
|
||||
def test_tqdm_progress_bar_max_val_check_interval_0(tmpdir):
|
||||
_test_progress_bar_max_val_check_interval(
|
||||
tmpdir, total_train_samples=8, train_batch_size=4, total_val_samples=2, val_batch_size=1, val_check_interval=0.2
|
||||
)
|
||||
|
||||
|
||||
@RunIf(min_gpus=2, special=True)
|
||||
def test_progress_bar_max_val_check_interval_1(tmpdir):
|
||||
def test_tqdm_progress_bar_max_val_check_interval_1(tmpdir):
|
||||
_test_progress_bar_max_val_check_interval(
|
||||
tmpdir, total_train_samples=8, train_batch_size=4, total_val_samples=2, val_batch_size=1, val_check_interval=0.5
|
||||
)
|
||||
|
@ -567,7 +567,7 @@ def _test_progress_bar_max_val_check_interval(
|
|||
|
||||
|
||||
def test_get_progress_bar_metrics(tmpdir: str):
|
||||
class TestProgressBar(ProgressBar):
|
||||
class TestProgressBar(TQDMProgressBar):
|
||||
def get_metrics(self, trainer: Trainer, model: LightningModule):
|
||||
items = super().get_metrics(trainer, model)
|
||||
items.pop("v_num", None)
|
||||
|
@ -588,9 +588,9 @@ def test_get_progress_bar_metrics(tmpdir: str):
|
|||
assert "v_num" not in standard_metrics.keys()
|
||||
|
||||
|
||||
def test_progress_bar_main_bar_resume():
|
||||
def test_tqdm_progress_bar_main_bar_resume():
|
||||
"""Test that the progress bar can resume its counters based on the Trainer state."""
|
||||
bar = ProgressBar()
|
||||
bar = TQDMProgressBar()
|
||||
trainer = Mock()
|
||||
model = Mock()
|
||||
|
|
@ -20,6 +20,7 @@ import torch
|
|||
from pytorch_lightning import Callback, LightningDataModule, Trainer
|
||||
from pytorch_lightning.callbacks.gpu_stats_monitor import GPUStatsMonitor
|
||||
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
|
||||
from pytorch_lightning.callbacks.progress import ProgressBar
|
||||
from pytorch_lightning.callbacks.xla_stats_monitor import XLAStatsMonitor
|
||||
from pytorch_lightning.loggers import LoggerCollection, TestTubeLogger
|
||||
from tests.callbacks.test_callbacks import OldStatefulCallback
|
||||
|
@ -391,6 +392,11 @@ def test_v1_7_0_deprecate_xla_stats_monitor(tmpdir):
|
|||
_ = XLAStatsMonitor()
|
||||
|
||||
|
||||
def test_v1_7_0_progress_bar():
|
||||
with pytest.deprecated_call(match="has been deprecated in v1.5 and will be removed in v1.7."):
|
||||
_ = ProgressBar()
|
||||
|
||||
|
||||
def test_v1_7_0_deprecated_max_steps_none(tmpdir):
|
||||
with pytest.deprecated_call(match="`max_steps = None` is deprecated in v1.5"):
|
||||
_ = Trainer(max_steps=None)
|
||||
|
|
|
@ -22,7 +22,7 @@ from pytorch_lightning.callbacks import (
|
|||
LearningRateMonitor,
|
||||
ModelCheckpoint,
|
||||
ModelSummary,
|
||||
ProgressBar,
|
||||
TQDMProgressBar,
|
||||
)
|
||||
from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
|
||||
from tests.helpers import BoringModel
|
||||
|
@ -35,7 +35,7 @@ def test_checkpoint_callbacks_are_last(tmpdir):
|
|||
model_summary = ModelSummary()
|
||||
early_stopping = EarlyStopping()
|
||||
lr_monitor = LearningRateMonitor()
|
||||
progress_bar = ProgressBar()
|
||||
progress_bar = TQDMProgressBar()
|
||||
|
||||
# no model reference
|
||||
trainer = Trainer(callbacks=[checkpoint1, progress_bar, lr_monitor, model_summary, checkpoint2])
|
||||
|
@ -155,7 +155,7 @@ def test_attach_model_callbacks():
|
|||
return trainer
|
||||
|
||||
early_stopping = EarlyStopping()
|
||||
progress_bar = ProgressBar()
|
||||
progress_bar = TQDMProgressBar()
|
||||
lr_monitor = LearningRateMonitor()
|
||||
grad_accumulation = GradientAccumulationScheduler({1: 1})
|
||||
|
||||
|
@ -199,7 +199,7 @@ def test_attach_model_callbacks_override_info(caplog):
|
|||
"""Test that the logs contain the info about overriding callbacks returned by configure_callbacks."""
|
||||
model = LightningModule()
|
||||
model.configure_callbacks = lambda: [LearningRateMonitor(), EarlyStopping()]
|
||||
trainer = Trainer(enable_checkpointing=False, callbacks=[EarlyStopping(), LearningRateMonitor(), ProgressBar()])
|
||||
trainer = Trainer(enable_checkpointing=False, callbacks=[EarlyStopping(), LearningRateMonitor(), TQDMProgressBar()])
|
||||
trainer.model = model
|
||||
cb_connector = CallbackConnector(trainer)
|
||||
with caplog.at_level(logging.INFO):
|
||||
|
|
|
@ -24,7 +24,7 @@ from torch.utils.data import DataLoader
|
|||
from torchmetrics import Accuracy
|
||||
|
||||
from pytorch_lightning import callbacks, Trainer
|
||||
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ProgressBar
|
||||
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from tests.helpers.boring_model import BoringModel, RandomDataset, RandomDictDataset
|
||||
|
@ -481,7 +481,7 @@ def test_progress_bar_metrics_contains_values_on_train_epoch_end(tmpdir: str):
|
|||
)
|
||||
self.on_train_epoch_end_called = True
|
||||
|
||||
class TestProgressBar(ProgressBar):
|
||||
class TestProgressBar(TQDMProgressBar):
|
||||
def get_metrics(self, trainer: Trainer, model: LightningModule):
|
||||
items = super().get_metrics(trainer, model)
|
||||
items.pop("v_num", None)
|
||||
|
|
Loading…
Reference in New Issue