diff --git a/CHANGELOG.md b/CHANGELOG.md index 7bda865002..a33aa60fde 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -124,6 +124,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `auto_device_count` method to `Accelerators` ([#10222](https://github.com/PyTorchLightning/pytorch-lightning/pull/10222)) - Added support for `devices="auto"` ([#10264](https://github.com/PyTorchLightning/pytorch-lightning/pull/10264)) - Added a `filename` argument in `ModelCheckpoint.format_checkpoint_name` ([#9818](https://github.com/PyTorchLightning/pytorch-lightning/pull/9818)) +- Added `configure_columns` method to `RichProgressBar` ([#10288](https://github.com/PyTorchLightning/pytorch-lightning/pull/10288)) ### Changed diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index 1fb5307368..b042adf897 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -265,16 +265,7 @@ class RichProgressBar(ProgressBarBase): self._reset_progress_bar_ids() self._console.clear_live() self.progress = CustomProgress( - TextColumn("[progress.description]{task.description}"), - CustomBarColumn( - complete_style=self.theme.progress_bar_complete, - finished_style=self.theme.progress_bar_finished, - pulse_style=self.theme.progress_bar_pulse, - ), - BatchesProcessedColumn(style=self.theme.batch_process), - CustomTimeColumn(style=self.theme.time), - ProcessingSpeedColumn(style=self.theme.processing_speed), - MetricsTextColumn(trainer, pl_module), + *self.configure_columns(trainer, pl_module), refresh_per_second=self.refresh_rate_per_second, disable=self.is_disabled, console=self._console, @@ -435,3 +426,17 @@ class RichProgressBar(ProgressBarBase): @property def test_progress_bar(self) -> Task: return self.progress.tasks[self.test_progress_bar_id] + + def configure_columns(self, trainer, pl_module) -> list: + return [ + TextColumn("[progress.description]{task.description}"), + CustomBarColumn( + complete_style=self.theme.progress_bar_complete, + finished_style=self.theme.progress_bar_finished, + pulse_style=self.theme.progress_bar_pulse, + ), + BatchesProcessedColumn(style=self.theme.batch_process), + CustomTimeColumn(style=self.theme.time), + ProcessingSpeedColumn(style=self.theme.processing_speed), + MetricsTextColumn(trainer, pl_module), + ] diff --git a/tests/callbacks/test_rich_progress_bar.py b/tests/callbacks/test_rich_progress_bar.py index 708cbafa4d..ab0852c472 100644 --- a/tests/callbacks/test_rich_progress_bar.py +++ b/tests/callbacks/test_rich_progress_bar.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from unittest import mock -from unittest.mock import DEFAULT +from unittest.mock import DEFAULT, Mock import pytest from torch.utils.data import DataLoader @@ -141,3 +141,21 @@ def test_rich_progress_bar_keyboard_interrupt(tmpdir): trainer.fit(model) mock_progress_stop.assert_called_once() + + +@RunIf(rich=True) +def test_rich_progress_bar_configure_columns(tmpdir): + from rich.progress import TextColumn + + custom_column = TextColumn("[progress.description]Testing Rich!") + + class CustomRichProgressBar(RichProgressBar): + def configure_columns(self, trainer, pl_module): + return [custom_column] + + progress_bar = CustomRichProgressBar() + + progress_bar._init_progress(Mock(), Mock()) + + assert progress_bar.progress.columns[0] == custom_column + assert len(progress_bar.progress.columns) == 1