diff --git a/CHANGELOG.md b/CHANGELOG.md index 33b8737006..7ef6553f8f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -621,6 +621,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Configure native Deepspeed schedulers with interval='step' ([#11788](https://github.com/PyTorchLightning/pytorch-lightning/pull/11788)) +- Update `RichProgressBarTheme` styles after detecting light theme on colab ([#10993](https://github.com/PyTorchLightning/pytorch-lightning/pull/10993)) + + ## [1.5.10] - 2022-02-08 ### Fixed diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index 8f8ba6c93f..446bf5f3f4 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -248,6 +248,7 @@ class RichProgressBar(ProgressBarBase): self._metric_component = None self._progress_stopped: bool = False self.theme = theme + self._update_for_light_colab_theme() @property def refresh_rate(self) -> float: @@ -261,12 +262,6 @@ class RichProgressBar(ProgressBarBase): def is_disabled(self) -> bool: return not self.is_enabled - def disable(self) -> None: - self._enabled = False - - def enable(self) -> None: - self._enabled = True - @property def sanity_check_description(self) -> str: return "Validation Sanity Check" @@ -283,6 +278,19 @@ class RichProgressBar(ProgressBarBase): def predict_description(self) -> str: return "Predicting" + def _update_for_light_colab_theme(self) -> None: + if _detect_light_colab_theme(): + attributes = ["description", "batch_progress", "metrics"] + for attr in attributes: + if getattr(self.theme, attr) == "white": + setattr(self.theme, attr, "black") + + def disable(self) -> None: + self._enabled = False + + def enable(self) -> None: + self._enabled = True + def _init_progress(self, trainer): if self.is_enabled and (self.progress is None or self._progress_stopped): self._reset_progress_bar_ids() @@ -476,3 +484,20 @@ class RichProgressBar(ProgressBarBase): CustomTimeColumn(style=self.theme.time), ProcessingSpeedColumn(style=self.theme.processing_speed), ] + + +def _detect_light_colab_theme() -> bool: + """Detect if it's light theme in Colab.""" + try: + get_ipython # type: ignore + except NameError: + return False + ipython = get_ipython() # noqa: F821 + if "google.colab" in str(ipython.__class__): + try: + from google.colab import output + + return output.eval_js('document.documentElement.matches("[theme=light]")') + except ModuleNotFoundError: + return False + return False diff --git a/tests/callbacks/test_rich_progress_bar.py b/tests/callbacks/test_rich_progress_bar.py index 9f7f8e4541..cfe32dc495 100644 --- a/tests/callbacks/test_rich_progress_bar.py +++ b/tests/callbacks/test_rich_progress_bar.py @@ -285,6 +285,20 @@ def test_rich_progress_bar_counter_with_val_check_interval(tmpdir): assert val_bar.total == 4 +@RunIf(rich=True) +@mock.patch("pytorch_lightning.callbacks.progress.rich_progress._detect_light_colab_theme", return_value=True) +def test_rich_progress_bar_colab_light_theme_update(*_): + theme = RichProgressBar().theme + assert theme.description == "black" + assert theme.batch_progress == "black" + assert theme.metrics == "black" + + theme = RichProgressBar(theme=RichProgressBarTheme(description="blue", metrics="red")).theme + assert theme.description == "blue" + assert theme.batch_progress == "black" + assert theme.metrics == "red" + + @RunIf(rich=True) def test_rich_progress_bar_metric_display_task_id(tmpdir): class CustomModel(BoringModel):