diff --git a/docs/source-pytorch/common/progress_bar.rst b/docs/source-pytorch/common/progress_bar.rst index f16846e8c9..e0c29fccdc 100644 --- a/docs/source-pytorch/common/progress_bar.rst +++ b/docs/source-pytorch/common/progress_bar.rst @@ -36,6 +36,14 @@ You can update ``refresh_rate`` (rate (number of batches) at which the progress trainer = Trainer(callbacks=[TQDMProgressBar(refresh_rate=10)]) +By default the training progress bar is reset (overwritten) at each new epoch. +If you wish for a new progress bar to be displayed at the end of every epoch, set +:paramref:`TQDMProgressBar.leave ` to ``True``. + +.. code-block:: python + + trainer = Trainer(callbacks=[TQDMProgressBar(leave=True)]) + If you want to customize the default :class:`~lightning.pytorch.callbacks.TQDMProgressBar` used by Lightning, you can override specific methods of the callback class and pass your custom implementation to the :class:`~lightning.pytorch.trainer.trainer.Trainer`. diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index a085f0adb6..3091e1ea9b 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -15,7 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a flag `verbose` to the `seed_everything()` function ([#20108](https://github.com/Lightning-AI/pytorch-lightning/pull/20108)) -- +- The `TQDMProgressBar` now provides an option to retain prior training epoch bars ([#19578](https://github.com/Lightning-AI/pytorch-lightning/pull/19578)) ### Changed diff --git a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py index bf9e238a01..b28de65c9b 100644 --- a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py +++ b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py @@ -99,12 +99,14 @@ class TQDMProgressBar(ProgressBar): together. This corresponds to :paramref:`~lightning.pytorch.trainer.trainer.Trainer.process_position` in the :class:`~lightning.pytorch.trainer.trainer.Trainer`. + leave: If set to ``True``, leaves the finished progress bar in the terminal at the end of the epoch. + Default: ``False`` """ BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]" - def __init__(self, refresh_rate: int = 1, process_position: int = 0): + def __init__(self, refresh_rate: int = 1, process_position: int = 0, leave: bool = False): super().__init__() self._refresh_rate = self._resolve_refresh_rate(refresh_rate) self._process_position = process_position @@ -113,6 +115,7 @@ class TQDMProgressBar(ProgressBar): self._val_progress_bar: Optional[_tqdm] = None self._test_progress_bar: Optional[_tqdm] = None self._predict_progress_bar: Optional[_tqdm] = None + self._leave = leave def __getstate__(self) -> Dict: # can't pickle the tqdm objects @@ -262,6 +265,8 @@ class TQDMProgressBar(ProgressBar): @override def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None: + if self._leave: + self.train_progress_bar = self.init_train_tqdm() self.train_progress_bar.reset(convert_inf(self.total_train_batches)) self.train_progress_bar.initial = 0 self.train_progress_bar.set_description(f"Epoch {trainer.current_epoch}") @@ -279,6 +284,8 @@ class TQDMProgressBar(ProgressBar): def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if not self.train_progress_bar.disable: self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) + if self._leave: + self.train_progress_bar.close() @override def on_train_end(self, *_: Any) -> None: diff --git a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py index e98f9077f8..d5187d5a1e 100644 --- a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py @@ -18,7 +18,7 @@ import sys from collections import defaultdict from typing import Union from unittest import mock -from unittest.mock import ANY, PropertyMock, call +from unittest.mock import ANY, Mock, PropertyMock, call import pytest import torch @@ -783,3 +783,20 @@ def test_tqdm_progress_bar_disabled_when_not_rank_zero(is_global_zero): pbar.enable() trainer.test(model) assert pbar.is_disabled + + +@pytest.mark.parametrize("leave", [True, False]) +def test_tqdm_leave(leave, tmp_path): + pbar = TQDMProgressBar(leave=leave) + pbar.init_train_tqdm = Mock(wraps=pbar.init_train_tqdm) + model = BoringModel() + trainer = Trainer( + default_root_dir=tmp_path, + callbacks=[pbar], + max_epochs=3, + limit_train_batches=1, + limit_val_batches=1, + benchmark=True, + ) + trainer.fit(model) + assert pbar.init_train_tqdm.call_count == (4 if leave else 1)