Add `configure_columns` method to RichProgressBar (#10288)

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
This commit is contained in:
Kaushik B 2021-11-01 22:52:53 +05:30 committed by GitHub
parent 6fd6283e07
commit c52d7ba73d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 35 additions and 11 deletions

View File

@ -124,6 +124,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `auto_device_count` method to `Accelerators` ([#10222](https://github.com/PyTorchLightning/pytorch-lightning/pull/10222))
- Added support for `devices="auto"` ([#10264](https://github.com/PyTorchLightning/pytorch-lightning/pull/10264))
- Added a `filename` argument in `ModelCheckpoint.format_checkpoint_name` ([#9818](https://github.com/PyTorchLightning/pytorch-lightning/pull/9818))
- Added `configure_columns` method to `RichProgressBar` ([#10288](https://github.com/PyTorchLightning/pytorch-lightning/pull/10288))
### Changed

View File

@ -265,16 +265,7 @@ class RichProgressBar(ProgressBarBase):
self._reset_progress_bar_ids()
self._console.clear_live()
self.progress = CustomProgress(
TextColumn("[progress.description]{task.description}"),
CustomBarColumn(
complete_style=self.theme.progress_bar_complete,
finished_style=self.theme.progress_bar_finished,
pulse_style=self.theme.progress_bar_pulse,
),
BatchesProcessedColumn(style=self.theme.batch_process),
CustomTimeColumn(style=self.theme.time),
ProcessingSpeedColumn(style=self.theme.processing_speed),
MetricsTextColumn(trainer, pl_module),
*self.configure_columns(trainer, pl_module),
refresh_per_second=self.refresh_rate_per_second,
disable=self.is_disabled,
console=self._console,
@ -435,3 +426,17 @@ class RichProgressBar(ProgressBarBase):
@property
def test_progress_bar(self) -> Task:
return self.progress.tasks[self.test_progress_bar_id]
def configure_columns(self, trainer, pl_module) -> list:
return [
TextColumn("[progress.description]{task.description}"),
CustomBarColumn(
complete_style=self.theme.progress_bar_complete,
finished_style=self.theme.progress_bar_finished,
pulse_style=self.theme.progress_bar_pulse,
),
BatchesProcessedColumn(style=self.theme.batch_process),
CustomTimeColumn(style=self.theme.time),
ProcessingSpeedColumn(style=self.theme.processing_speed),
MetricsTextColumn(trainer, pl_module),
]

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from unittest.mock import DEFAULT
from unittest.mock import DEFAULT, Mock
import pytest
from torch.utils.data import DataLoader
@ -141,3 +141,21 @@ def test_rich_progress_bar_keyboard_interrupt(tmpdir):
trainer.fit(model)
mock_progress_stop.assert_called_once()
@RunIf(rich=True)
def test_rich_progress_bar_configure_columns(tmpdir):
from rich.progress import TextColumn
custom_column = TextColumn("[progress.description]Testing Rich!")
class CustomRichProgressBar(RichProgressBar):
def configure_columns(self, trainer, pl_module):
return [custom_column]
progress_bar = CustomRichProgressBar()
progress_bar._init_progress(Mock(), Mock())
assert progress_bar.progress.columns[0] == custom_column
assert len(progress_bar.progress.columns) == 1