Add `configure_columns` method to RichProgressBar (#10288)
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
This commit is contained in:
parent
6fd6283e07
commit
c52d7ba73d
|
@ -124,6 +124,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Added `auto_device_count` method to `Accelerators` ([#10222](https://github.com/PyTorchLightning/pytorch-lightning/pull/10222))
|
||||
- Added support for `devices="auto"` ([#10264](https://github.com/PyTorchLightning/pytorch-lightning/pull/10264))
|
||||
- Added a `filename` argument in `ModelCheckpoint.format_checkpoint_name` ([#9818](https://github.com/PyTorchLightning/pytorch-lightning/pull/9818))
|
||||
- Added `configure_columns` method to `RichProgressBar` ([#10288](https://github.com/PyTorchLightning/pytorch-lightning/pull/10288))
|
||||
|
||||
|
||||
### Changed
|
||||
|
|
|
@ -265,16 +265,7 @@ class RichProgressBar(ProgressBarBase):
|
|||
self._reset_progress_bar_ids()
|
||||
self._console.clear_live()
|
||||
self.progress = CustomProgress(
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
CustomBarColumn(
|
||||
complete_style=self.theme.progress_bar_complete,
|
||||
finished_style=self.theme.progress_bar_finished,
|
||||
pulse_style=self.theme.progress_bar_pulse,
|
||||
),
|
||||
BatchesProcessedColumn(style=self.theme.batch_process),
|
||||
CustomTimeColumn(style=self.theme.time),
|
||||
ProcessingSpeedColumn(style=self.theme.processing_speed),
|
||||
MetricsTextColumn(trainer, pl_module),
|
||||
*self.configure_columns(trainer, pl_module),
|
||||
refresh_per_second=self.refresh_rate_per_second,
|
||||
disable=self.is_disabled,
|
||||
console=self._console,
|
||||
|
@ -435,3 +426,17 @@ class RichProgressBar(ProgressBarBase):
|
|||
@property
|
||||
def test_progress_bar(self) -> Task:
|
||||
return self.progress.tasks[self.test_progress_bar_id]
|
||||
|
||||
def configure_columns(self, trainer, pl_module) -> list:
|
||||
return [
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
CustomBarColumn(
|
||||
complete_style=self.theme.progress_bar_complete,
|
||||
finished_style=self.theme.progress_bar_finished,
|
||||
pulse_style=self.theme.progress_bar_pulse,
|
||||
),
|
||||
BatchesProcessedColumn(style=self.theme.batch_process),
|
||||
CustomTimeColumn(style=self.theme.time),
|
||||
ProcessingSpeedColumn(style=self.theme.processing_speed),
|
||||
MetricsTextColumn(trainer, pl_module),
|
||||
]
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from unittest import mock
|
||||
from unittest.mock import DEFAULT
|
||||
from unittest.mock import DEFAULT, Mock
|
||||
|
||||
import pytest
|
||||
from torch.utils.data import DataLoader
|
||||
|
@ -141,3 +141,21 @@ def test_rich_progress_bar_keyboard_interrupt(tmpdir):
|
|||
|
||||
trainer.fit(model)
|
||||
mock_progress_stop.assert_called_once()
|
||||
|
||||
|
||||
@RunIf(rich=True)
|
||||
def test_rich_progress_bar_configure_columns(tmpdir):
|
||||
from rich.progress import TextColumn
|
||||
|
||||
custom_column = TextColumn("[progress.description]Testing Rich!")
|
||||
|
||||
class CustomRichProgressBar(RichProgressBar):
|
||||
def configure_columns(self, trainer, pl_module):
|
||||
return [custom_column]
|
||||
|
||||
progress_bar = CustomRichProgressBar()
|
||||
|
||||
progress_bar._init_progress(Mock(), Mock())
|
||||
|
||||
assert progress_bar.progress.columns[0] == custom_column
|
||||
assert len(progress_bar.progress.columns) == 1
|
||||
|
|
Loading…
Reference in New Issue