From 9e75bc9572152401a20d420f618260eaf4b69e73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 25 Oct 2023 02:51:11 +0200 Subject: [PATCH] Fix failing lightning cli entry point (#18821) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- src/lightning/__init__.py | 17 +++++++++++++++++ src/lightning/__setup__.py | 2 +- tests/tests_fabric/test_cli.py | 14 ++++++++++++++ 3 files changed, 32 insertions(+), 1 deletion(-) diff --git a/src/lightning/__init__.py b/src/lightning/__init__.py index 8a7fc68e72..4399e0348c 100644 --- a/src/lightning/__init__.py +++ b/src/lightning/__init__.py @@ -1,5 +1,6 @@ """Root package info.""" import logging +import sys # explicitly don't set root logger's propagation and leave this to subpackages to manage _logger = logging.getLogger(__name__) @@ -28,3 +29,19 @@ __all__ = [ "seed_everything", "Fabric", ] + + +def _cli_entry_point() -> None: + from lightning_utilities.core.imports import ModuleAvailableCache, RequirementCache + + if not ( + ModuleAvailableCache("lightning.app") + if RequirementCache("lightning-utilities<0.10.0") + else RequirementCache(module="lightning.app") # type: ignore[call-arg] + ): + print("The `lightning` command requires additional dependencies: `pip install lightning[app]`") + sys.exit(1) + + from lightning.app.cli.lightning_cli import main + + main() diff --git a/src/lightning/__setup__.py b/src/lightning/__setup__.py index 734fc73567..93b53dd254 100644 --- a/src/lightning/__setup__.py +++ b/src/lightning/__setup__.py @@ -115,7 +115,7 @@ def _setup_args() -> Dict[str, Any]: "python_requires": ">=3.8", # todo: take the lowes based on all packages "entry_points": { "console_scripts": [ - "lightning = lightning.app.cli.lightning_cli:main", + "lightning = lightning:_cli_entry_point", ], }, "setup_requires": [], diff --git a/tests/tests_fabric/test_cli.py b/tests/tests_fabric/test_cli.py index 2ff8ecad4f..4b2defafbc 100644 --- a/tests/tests_fabric/test_cli.py +++ b/tests/tests_fabric/test_cli.py @@ -13,6 +13,7 @@ # limitations under the License. import contextlib import os +import subprocess from io import StringIO from unittest import mock from unittest.mock import Mock @@ -20,6 +21,7 @@ from unittest.mock import Mock import pytest import torch.distributed.run from lightning.fabric.cli import _get_supported_strategies, _run_model +from lightning_utilities.core.imports import ModuleAvailableCache from tests_fabric.helpers.runif import RunIf @@ -172,3 +174,15 @@ def test_cli_torchrun_num_processes_launched(_, devices, expected, monkeypatch, fake_script, ] ) + + +@pytest.mark.skipif("lightning.fabric" == "lightning_fabric", reason="standalone package") +def test_cli_through_lightning_entry_point(): + result = subprocess.run("lightning run model --help", capture_output=True, text=True, shell=True) + if not ModuleAvailableCache("lightning.app"): + message = "The `lightning` command requires additional dependencies" + assert message in result.stdout or message in result.stderr + assert result.returncode != 0 + else: + message = "Usage: lightning run model [OPTIONS] SCRIPT [SCRIPT_ARGS]" + assert message in result.stdout or message in result.stderr