From 8c9cb0c133bb815ac40985f4b646a8a47b2e06f3 Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Wed, 29 Sep 2021 18:24:51 +0100 Subject: [PATCH] [3/n] add additional rich version check (#9757) --- pytorch_lightning/callbacks/progress/rich_progress.py | 2 +- pytorch_lightning/callbacks/rich_model_summary.py | 2 +- pytorch_lightning/utilities/imports.py | 9 +++++++-- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index 4319bd6fb5..4da35c7c7a 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -209,7 +209,7 @@ class RichProgressBar(ProgressBarBase): ) -> None: if not _RICH_AVAILABLE: raise ImportError( - "`RichProgressBar` requires `rich` to be installed. Install it by running `pip install rich`." + "`RichProgressBar` requires `rich` to be installed. Install it by running `pip install -U rich`." ) super().__init__() self._refresh_rate_per_second: int = refresh_rate_per_second diff --git a/pytorch_lightning/callbacks/rich_model_summary.py b/pytorch_lightning/callbacks/rich_model_summary.py index 2e55665c44..da2f8c66cc 100644 --- a/pytorch_lightning/callbacks/rich_model_summary.py +++ b/pytorch_lightning/callbacks/rich_model_summary.py @@ -61,7 +61,7 @@ class RichModelSummary(ModelSummary): def __init__(self, max_depth: int = 1) -> None: if not _RICH_AVAILABLE: raise ImportError( - "`RichModelSummary` requires `rich` to be installed. Install it by running `pip install rich`." + "`RichProgressBar` requires `rich` to be installed. Install it by running `pip install -U rich`." ) super().__init__(max_depth) diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 441a919710..aa59bc4537 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -19,6 +19,7 @@ import platform import sys from importlib.util import find_spec +import pkg_resources import torch from packaging.version import Version from pkg_resources import DistributionNotFound @@ -53,7 +54,11 @@ def _compare_version(package: str, op, version) -> bool: except (ModuleNotFoundError, DistributionNotFound): return False try: - pkg_version = Version(pkg.__version__) + if hasattr(pkg, "__version__"): + pkg_version = Version(pkg.__version__) + else: + # try pkg_resources to infer version + pkg_version = Version(pkg_resources.get_distribution(pkg).version) except TypeError: # this is mock by sphinx, so it shall return True ro generate all summaries return True @@ -84,7 +89,7 @@ _NEPTUNE_AVAILABLE = _module_available("neptune") _NEPTUNE_GREATER_EQUAL_0_9 = _NEPTUNE_AVAILABLE and _compare_version("neptune", operator.ge, "0.9.0") _OMEGACONF_AVAILABLE = _module_available("omegaconf") _POPTORCH_AVAILABLE = _module_available("poptorch") -_RICH_AVAILABLE = _module_available("rich") +_RICH_AVAILABLE = _module_available("rich") and _compare_version("rich", operator.ge, "10.2.2") _TORCH_CPU_AMP_AVAILABLE = _compare_version( "torch", operator.ge, "1.10.dev20210902" ) # todo: swap to 1.10.0 once released