Add ability for TQDMProgressBar to retain prior epoch training bars (#19578)
Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
parent 83ba2dfc17
commit e61eafa671
@@ -36,6 +36,14 @@ You can update ``refresh_rate`` (rate (number of batches) at which the progress
 
     trainer = Trainer(callbacks=[TQDMProgressBar(refresh_rate=10)])
 
+By default the training progress bar is reset (overwritten) at each new epoch.
+If you wish for a new progress bar to be displayed at the end of every epoch, set
+:paramref:`TQDMProgressBar.leave <lightning.pytorch.callbacks.TQDMProgressBar.leave>` to ``True``.
+
+.. code-block:: python
+
+    trainer = Trainer(callbacks=[TQDMProgressBar(leave=True)])
+
 If you want to customize the default :class:`~lightning.pytorch.callbacks.TQDMProgressBar` used by Lightning, you can override
 specific methods of the callback class and pass your custom implementation to the :class:`~lightning.pytorch.trainer.trainer.Trainer`.
 
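A note on the customization paragraph above: overriding one of the callback's `init_*_tqdm` hooks is the intended extension point. A minimal sketch, not part of this commit (the subclass name and description string are arbitrary examples):

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import TQDMProgressBar


    class MyProgressBar(TQDMProgressBar):
        def init_train_tqdm(self):
            # start from the default training bar and tweak it
            bar = super().init_train_tqdm()
            bar.set_description("running training")
            return bar


    trainer = Trainer(callbacks=[MyProgressBar(leave=True)])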
@@ -15,7 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added a flag `verbose` to the `seed_everything()` function ([#20108](https://github.com/Lightning-AI/pytorch-lightning/pull/20108))
 
--
+- The `TQDMProgressBar` now provides an option to retain prior training epoch bars ([#19578](https://github.com/Lightning-AI/pytorch-lightning/pull/19578))
 
 ### Changed
 
@@ -99,12 +99,14 @@ class TQDMProgressBar(ProgressBar):
             together. This corresponds to
             :paramref:`~lightning.pytorch.trainer.trainer.Trainer.process_position` in the
             :class:`~lightning.pytorch.trainer.trainer.Trainer`.
+        leave: If set to ``True``, leaves the finished progress bar in the terminal at the end of the epoch.
+            Default: ``False``
 
     """
 
     BAR_FORMAT = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_noinv_fmt}{postfix}]"
 
-    def __init__(self, refresh_rate: int = 1, process_position: int = 0):
+    def __init__(self, refresh_rate: int = 1, process_position: int = 0, leave: bool = False):
         super().__init__()
         self._refresh_rate = self._resolve_refresh_rate(refresh_rate)
         self._process_position = process_position
@@ -113,6 +115,7 @@ class TQDMProgressBar(ProgressBar):
         self._val_progress_bar: Optional[_tqdm] = None
         self._test_progress_bar: Optional[_tqdm] = None
         self._predict_progress_bar: Optional[_tqdm] = None
+        self._leave = leave
 
     def __getstate__(self) -> Dict:
         # can't pickle the tqdm objects
@@ -262,6 +265,8 @@ class TQDMProgressBar(ProgressBar):
 
     @override
     def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
+        if self._leave:
+            self.train_progress_bar = self.init_train_tqdm()
         self.train_progress_bar.reset(convert_inf(self.total_train_batches))
         self.train_progress_bar.initial = 0
         self.train_progress_bar.set_description(f"Epoch {trainer.current_epoch}")
@@ -279,6 +284,8 @@ class TQDMProgressBar(ProgressBar):
     def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if not self.train_progress_bar.disable:
             self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
+        if self._leave:
+            self.train_progress_bar.close()
 
     @override
     def on_train_end(self, *_: Any) -> None:
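Taken together, the two hunks above implement the retention behavior: with `leave=True`, `on_train_epoch_start` replaces the bar with a fresh one from `init_train_tqdm()`, and `on_train_epoch_end` closes the finished bar, which tqdm then leaves on screen. A minimal standalone sketch of the same idea using tqdm directly (illustrative only, outside Lightning; the epoch and batch counts are arbitrary):

    import time

    from tqdm import tqdm

    for epoch in range(3):
        # one fresh bar per epoch; leave=True keeps the completed bar visible
        bar = tqdm(total=10, desc=f"Epoch {epoch}", leave=True)
        for _ in range(10):
            time.sleep(0.01)  # stand-in for a training step
            bar.update(1)
        bar.close()  # finalize this epoch's bar; the next iteration creates a new one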
@@ -18,7 +18,7 @@ import sys
 from collections import defaultdict
 from typing import Union
 from unittest import mock
-from unittest.mock import ANY, PropertyMock, call
+from unittest.mock import ANY, Mock, PropertyMock, call
 
 import pytest
 import torch
@@ -783,3 +783,20 @@ def test_tqdm_progress_bar_disabled_when_not_rank_zero(is_global_zero):
     pbar.enable()
     trainer.test(model)
     assert pbar.is_disabled
+
+
+@pytest.mark.parametrize("leave", [True, False])
+def test_tqdm_leave(leave, tmp_path):
+    pbar = TQDMProgressBar(leave=leave)
+    pbar.init_train_tqdm = Mock(wraps=pbar.init_train_tqdm)
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        callbacks=[pbar],
+        max_epochs=3,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        benchmark=True,
+    )
+    trainer.fit(model)
+    assert pbar.init_train_tqdm.call_count == (4 if leave else 1)
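The expected call counts in the final assertion follow from the hunks above: with `leave=True`, `init_train_tqdm` runs once when the bar is first created at the start of training and then once per `on_train_epoch_start` over the three epochs (1 + 3 = 4); with `leave=False` the single bar is reused across epochs, so it is only created once.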