Update RichProgressBarTheme after detecting light theme on colab (#10993)
This commit is contained in:
parent
11bd176d2f
commit
6ff38d4c8e
|
@ -621,6 +621,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Configure native Deepspeed schedulers with interval='step' ([#11788](https://github.com/PyTorchLightning/pytorch-lightning/pull/11788))
|
||||
|
||||
|
||||
- Update `RichProgressBarTheme` styles after detecting light theme on colab ([#10993](https://github.com/PyTorchLightning/pytorch-lightning/pull/10993))
|
||||
|
||||
|
||||
## [1.5.10] - 2022-02-08
|
||||
|
||||
### Fixed
|
||||
|
|
|
@ -248,6 +248,7 @@ class RichProgressBar(ProgressBarBase):
|
|||
self._metric_component = None
|
||||
self._progress_stopped: bool = False
|
||||
self.theme = theme
|
||||
self._update_for_light_colab_theme()
|
||||
|
||||
@property
|
||||
def refresh_rate(self) -> float:
|
||||
|
@ -261,12 +262,6 @@ class RichProgressBar(ProgressBarBase):
|
|||
def is_disabled(self) -> bool:
|
||||
return not self.is_enabled
|
||||
|
||||
def disable(self) -> None:
|
||||
self._enabled = False
|
||||
|
||||
def enable(self) -> None:
|
||||
self._enabled = True
|
||||
|
||||
@property
|
||||
def sanity_check_description(self) -> str:
|
||||
return "Validation Sanity Check"
|
||||
|
@ -283,6 +278,19 @@ class RichProgressBar(ProgressBarBase):
|
|||
def predict_description(self) -> str:
|
||||
return "Predicting"
|
||||
|
||||
def _update_for_light_colab_theme(self) -> None:
|
||||
if _detect_light_colab_theme():
|
||||
attributes = ["description", "batch_progress", "metrics"]
|
||||
for attr in attributes:
|
||||
if getattr(self.theme, attr) == "white":
|
||||
setattr(self.theme, attr, "black")
|
||||
|
||||
def disable(self) -> None:
|
||||
self._enabled = False
|
||||
|
||||
def enable(self) -> None:
|
||||
self._enabled = True
|
||||
|
||||
def _init_progress(self, trainer):
|
||||
if self.is_enabled and (self.progress is None or self._progress_stopped):
|
||||
self._reset_progress_bar_ids()
|
||||
|
@ -476,3 +484,20 @@ class RichProgressBar(ProgressBarBase):
|
|||
CustomTimeColumn(style=self.theme.time),
|
||||
ProcessingSpeedColumn(style=self.theme.processing_speed),
|
||||
]
|
||||
|
||||
|
||||
def _detect_light_colab_theme() -> bool:
|
||||
"""Detect if it's light theme in Colab."""
|
||||
try:
|
||||
get_ipython # type: ignore
|
||||
except NameError:
|
||||
return False
|
||||
ipython = get_ipython() # noqa: F821
|
||||
if "google.colab" in str(ipython.__class__):
|
||||
try:
|
||||
from google.colab import output
|
||||
|
||||
return output.eval_js('document.documentElement.matches("[theme=light]")')
|
||||
except ModuleNotFoundError:
|
||||
return False
|
||||
return False
|
||||
|
|
|
@ -285,6 +285,20 @@ def test_rich_progress_bar_counter_with_val_check_interval(tmpdir):
|
|||
assert val_bar.total == 4
|
||||
|
||||
|
||||
@RunIf(rich=True)
|
||||
@mock.patch("pytorch_lightning.callbacks.progress.rich_progress._detect_light_colab_theme", return_value=True)
|
||||
def test_rich_progress_bar_colab_light_theme_update(*_):
|
||||
theme = RichProgressBar().theme
|
||||
assert theme.description == "black"
|
||||
assert theme.batch_progress == "black"
|
||||
assert theme.metrics == "black"
|
||||
|
||||
theme = RichProgressBar(theme=RichProgressBarTheme(description="blue", metrics="red")).theme
|
||||
assert theme.description == "blue"
|
||||
assert theme.batch_progress == "black"
|
||||
assert theme.metrics == "red"
|
||||
|
||||
|
||||
@RunIf(rich=True)
|
||||
def test_rich_progress_bar_metric_display_task_id(tmpdir):
|
||||
class CustomModel(BoringModel):
|
||||
|
|
Loading…
Reference in New Issue