Add ability for TQDMProgressBar to retain prior epoch training bars (#19578)

Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
Jonas Tingeborn 2024-08-04 09:28:26 +02:00 committed by GitHub
parent 83ba2dfc17
commit e61eafa671
4 changed files with 35 additions and 3 deletions

docs/source-pytorch/common/progress_bar.rst

@@ -36,6 +36,14 @@ You can update ``refresh_rate`` (rate (number of batches) at which the progress
    trainer = Trainer(callbacks=[TQDMProgressBar(refresh_rate=10)])

By default the training progress bar is reset (overwritten) at each new epoch.
If you instead wish to retain each epoch's finished progress bar and have a new one
started for the next epoch, set
:paramref:`TQDMProgressBar.leave <lightning.pytorch.callbacks.TQDMProgressBar.leave>` to ``True``.

.. code-block:: python

    trainer = Trainer(callbacks=[TQDMProgressBar(leave=True)])

If you want to customize the default :class:`~lightning.pytorch.callbacks.TQDMProgressBar` used by Lightning, you can override
specific methods of the callback class and pass your custom implementation to the :class:`~lightning.pytorch.trainer.trainer.Trainer`.
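For example, a minimal sketch of such an override (the subclass name ``LitProgressBar`` and the description text are illustrative) that customizes the validation bar:

.. code-block:: python

    class LitProgressBar(TQDMProgressBar):
        def init_validation_tqdm(self):
            bar = super().init_validation_tqdm()
            bar.set_description("running validation...")
            return bar


    trainer = Trainer(callbacks=[LitProgressBar()])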

src/lightning/pytorch/CHANGELOG.md

@@ -15,7 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a flag `verbose` to the `seed_everything()` function ([#20108](https://github.com/Lightning-AI/pytorch-lightning/pull/20108))
-
- The `TQDMProgressBar` now provides an option to retain prior training epoch bars ([#19578](https://github.com/Lightning-AI/pytorch-lightning/pull/19578))
### Changed

src/lightning/pytorch/callbacks/progress/tqdm_progress.py

@@ -99,12 +99,14 @@ class TQDMProgressBar(ProgressBar):
            together. This corresponds to
            :paramref:`~lightning.pytorch.trainer.trainer.Trainer.process_position` in the
            :class:`~lightning.pytorch.trainer.trainer.Trainer`.
        leave: If set to ``True``, leaves the finished progress bar in the terminal at the end of the epoch.
            Default: ``False``

    """

    BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]"

    def __init__(self, refresh_rate: int = 1, process_position: int = 0):
    def __init__(self, refresh_rate: int = 1, process_position: int = 0, leave: bool = False):
        super().__init__()
        self._refresh_rate = self._resolve_refresh_rate(refresh_rate)
        self._process_position = process_position
@@ -113,6 +115,7 @@ class TQDMProgressBar(ProgressBar):
        self._val_progress_bar: Optional[_tqdm] = None
        self._test_progress_bar: Optional[_tqdm] = None
        self._predict_progress_bar: Optional[_tqdm] = None
        self._leave = leave

    def __getstate__(self) -> Dict:
        # can't pickle the tqdm objects
@@ -262,6 +265,8 @@ class TQDMProgressBar(ProgressBar):
    @override
    def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
        if self._leave:
            self.train_progress_bar = self.init_train_tqdm()
        self.train_progress_bar.reset(convert_inf(self.total_train_batches))
        self.train_progress_bar.initial = 0
        self.train_progress_bar.set_description(f"Epoch {trainer.current_epoch}")
@@ -279,6 +284,8 @@
    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if not self.train_progress_bar.disable:
            self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
        if self._leave:
            self.train_progress_bar.close()

    @override
    def on_train_end(self, *_: Any) -> None:
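The retention effect itself comes from `tqdm`: a bar whose `leave` flag is true stays in the terminal after `close()`, so creating a fresh bar per epoch (instead of resetting the old one) preserves each finished epoch's bar. A minimal standalone sketch of that pattern, independent of Lightning (the loop bounds and the sleep are illustrative):

```python
import time

from tqdm import tqdm

for epoch in range(3):
    # fresh bar per epoch; leave=True (tqdm's default) keeps closed bars on screen
    bar = tqdm(total=10, desc=f"Epoch {epoch}", leave=True)
    for _ in range(10):
        time.sleep(0.01)  # stand-in for a training step
        bar.update(1)
    bar.close()  # the finished bar stays put; the next epoch renders on a new line
```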

tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py

@@ -18,7 +18,7 @@ import sys
from collections import defaultdict
from typing import Union
from unittest import mock
from unittest.mock import ANY, PropertyMock, call
from unittest.mock import ANY, Mock, PropertyMock, call
import pytest
import torch
@@ -783,3 +783,20 @@ def test_tqdm_progress_bar_disabled_when_not_rank_zero(is_global_zero):
    pbar.enable()
    trainer.test(model)
    assert pbar.is_disabled


@pytest.mark.parametrize("leave", [True, False])
def test_tqdm_leave(leave, tmp_path):
    pbar = TQDMProgressBar(leave=leave)
    pbar.init_train_tqdm = Mock(wraps=pbar.init_train_tqdm)
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmp_path,
        callbacks=[pbar],
        max_epochs=3,
        limit_train_batches=1,
        limit_val_batches=1,
        benchmark=True,
    )
    trainer.fit(model)
    assert pbar.init_train_tqdm.call_count == (4 if leave else 1)
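The expected counts follow from the hooks above: `init_train_tqdm` runs once when training starts and, with `leave=True`, again at each of the three epoch starts, giving 1 + 3 = 4 calls; with `leave=False` the single bar created at the start of training is only reset between epochs, so it is called exactly once.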