Update RichProgressBarTheme after detecting light theme on colab (#10993)

This commit is contained in:
Kaushik B 2022-02-22 11:02:27 +05:30 committed by GitHub
parent 11bd176d2f
commit 6ff38d4c8e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 48 additions and 6 deletions

View File

@ -621,6 +621,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Configure native Deepspeed schedulers with interval='step' ([#11788](https://github.com/PyTorchLightning/pytorch-lightning/pull/11788))
- Update `RichProgressBarTheme` styles after detecting light theme on colab ([#10993](https://github.com/PyTorchLightning/pytorch-lightning/pull/10993))
## [1.5.10] - 2022-02-08
### Fixed

View File

@ -248,6 +248,7 @@ class RichProgressBar(ProgressBarBase):
self._metric_component = None
self._progress_stopped: bool = False
self.theme = theme
self._update_for_light_colab_theme()
@property
def refresh_rate(self) -> float:
@ -261,12 +262,6 @@ class RichProgressBar(ProgressBarBase):
def is_disabled(self) -> bool:
return not self.is_enabled
def disable(self) -> None:
self._enabled = False
def enable(self) -> None:
self._enabled = True
@property
def sanity_check_description(self) -> str:
return "Validation Sanity Check"
@ -283,6 +278,19 @@ class RichProgressBar(ProgressBarBase):
def predict_description(self) -> str:
return "Predicting"
def _update_for_light_colab_theme(self) -> None:
if _detect_light_colab_theme():
attributes = ["description", "batch_progress", "metrics"]
for attr in attributes:
if getattr(self.theme, attr) == "white":
setattr(self.theme, attr, "black")
def disable(self) -> None:
self._enabled = False
def enable(self) -> None:
self._enabled = True
def _init_progress(self, trainer):
if self.is_enabled and (self.progress is None or self._progress_stopped):
self._reset_progress_bar_ids()
@ -476,3 +484,20 @@ class RichProgressBar(ProgressBarBase):
CustomTimeColumn(style=self.theme.time),
ProcessingSpeedColumn(style=self.theme.processing_speed),
]
def _detect_light_colab_theme() -> bool:
"""Detect if it's light theme in Colab."""
try:
get_ipython # type: ignore
except NameError:
return False
ipython = get_ipython() # noqa: F821
if "google.colab" in str(ipython.__class__):
try:
from google.colab import output
return output.eval_js('document.documentElement.matches("[theme=light]")')
except ModuleNotFoundError:
return False
return False

View File

@ -285,6 +285,20 @@ def test_rich_progress_bar_counter_with_val_check_interval(tmpdir):
assert val_bar.total == 4
@RunIf(rich=True)
@mock.patch("pytorch_lightning.callbacks.progress.rich_progress._detect_light_colab_theme", return_value=True)
def test_rich_progress_bar_colab_light_theme_update(*_):
theme = RichProgressBar().theme
assert theme.description == "black"
assert theme.batch_progress == "black"
assert theme.metrics == "black"
theme = RichProgressBar(theme=RichProgressBarTheme(description="blue", metrics="red")).theme
assert theme.description == "blue"
assert theme.batch_progress == "black"
assert theme.metrics == "red"
@RunIf(rich=True)
def test_rich_progress_bar_metric_display_task_id(tmpdir):
class CustomModel(BoringModel):