From 8dc36c3745038063326cb85c3ab5c6b6021b0dda Mon Sep 17 00:00:00 2001 From: Aki Nitta Date: Wed, 12 Jan 2022 12:55:51 +0900 Subject: [PATCH] Fix inconsistent exceptions raised with no `rich` installed (#11360) Co-authored-by: Carlos Mocholi --- CHANGELOG.md | 3 +++ pytorch_lightning/callbacks/progress/rich_progress.py | 3 +-- pytorch_lightning/callbacks/rich_model_summary.py | 2 +- tests/callbacks/test_rich_model_summary.py | 11 ++++++----- tests/callbacks/test_rich_progress_bar.py | 11 ++++++----- 5 files changed, 17 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 26392c4689..495b4f70d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -207,6 +207,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - `Trainer.logged_metrics` now always contains scalar tensors, even when a Python scalar was logged ([#11270](https://github.com/PyTorchLightning/pytorch-lightning/pull/11270)) +- Changed `MisconfigurationException` to `ModuleNotFoundError` when `rich` isn't available ([#11360](https://github.com/PyTorchLightning/pytorch-lightning/pull/11360)) + + ### Deprecated - Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/issues/10103)) diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index 46b5437934..570c6d7df6 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -17,7 +17,6 @@ from datetime import timedelta from typing import Any, Dict, Optional, Union from pytorch_lightning.callbacks.progress.base import ProgressBarBase -from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _RICH_AVAILABLE Task, Style = None, None @@ -231,7 +230,7 @@ class RichProgressBar(ProgressBarBase): console_kwargs: Optional[Dict[str, Any]] = None, ) -> None: if not _RICH_AVAILABLE: - raise MisconfigurationException( + raise ModuleNotFoundError( "`RichProgressBar` requires `rich` >= 10.2.2. Install it by running `pip install -U rich`." ) diff --git a/pytorch_lightning/callbacks/rich_model_summary.py b/pytorch_lightning/callbacks/rich_model_summary.py index cce4eb316a..14c078a273 100644 --- a/pytorch_lightning/callbacks/rich_model_summary.py +++ b/pytorch_lightning/callbacks/rich_model_summary.py @@ -61,7 +61,7 @@ class RichModelSummary(ModelSummary): def __init__(self, max_depth: int = 1) -> None: if not _RICH_AVAILABLE: raise ModuleNotFoundError( - "`RichProgressBar` requires `rich` to be installed. Install it by running `pip install -U rich`." + "`RichModelSummary` requires `rich` to be installed. Install it by running `pip install -U rich`." ) super().__init__(max_depth) diff --git a/tests/callbacks/test_rich_model_summary.py b/tests/callbacks/test_rich_model_summary.py index 5ab091bd01..c596557eed 100644 --- a/tests/callbacks/test_rich_model_summary.py +++ b/tests/callbacks/test_rich_model_summary.py @@ -19,7 +19,6 @@ import torch from pytorch_lightning import Trainer from pytorch_lightning.callbacks import RichModelSummary, RichProgressBar -from pytorch_lightning.utilities.imports import _RICH_AVAILABLE from pytorch_lightning.utilities.model_summary import summarize from tests.helpers import BoringModel from tests.helpers.runif import RunIf @@ -33,10 +32,12 @@ def test_rich_model_summary_callback(): assert isinstance(trainer.progress_bar_callback, RichProgressBar) -def test_rich_progress_bar_import_error(): - if not _RICH_AVAILABLE: - with pytest.raises(ImportError, match="`RichModelSummary` requires `rich` to be installed."): - Trainer(callbacks=RichModelSummary()) +def test_rich_progress_bar_import_error(monkeypatch): + import pytorch_lightning.callbacks.rich_model_summary as imports + + monkeypatch.setattr(imports, "_RICH_AVAILABLE", False) + with pytest.raises(ModuleNotFoundError, match="`RichModelSummary` requires `rich` to be installed."): + RichModelSummary() @RunIf(rich=True) diff --git a/tests/callbacks/test_rich_progress_bar.py b/tests/callbacks/test_rich_progress_bar.py index f2e75006f7..6b47e9558f 100644 --- a/tests/callbacks/test_rich_progress_bar.py +++ b/tests/callbacks/test_rich_progress_bar.py @@ -20,7 +20,6 @@ from torch.utils.data import DataLoader from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ProgressBarBase, RichProgressBar from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme -from pytorch_lightning.utilities.imports import _RICH_AVAILABLE from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset from tests.helpers.runif import RunIf @@ -83,10 +82,12 @@ def test_rich_progress_bar(progress_update, tmpdir, dataset): assert progress_update.call_count == 8 -def test_rich_progress_bar_import_error(): - if not _RICH_AVAILABLE: - with pytest.raises(ImportError, match="`RichProgressBar` requires `rich` >= 10.2.2."): - Trainer(callbacks=RichProgressBar()) +def test_rich_progress_bar_import_error(monkeypatch): + import pytorch_lightning.callbacks.progress.rich_progress as imports + + monkeypatch.setattr(imports, "_RICH_AVAILABLE", False) + with pytest.raises(ModuleNotFoundError, match="`RichProgressBar` requires `rich` >= 10.2.2."): + RichProgressBar() @RunIf(rich=True)