Fix failing lightning cli entry point (#18821)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
This commit is contained in:
Adrian Wälchli 2023-10-25 02:51:11 +02:00 committed by GitHub
parent 245865d586
commit 9e75bc9572
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 32 additions and 1 deletions

View File

@ -1,5 +1,6 @@
"""Root package info."""
import logging
import sys
# explicitly don't set root logger's propagation and leave this to subpackages to manage
_logger = logging.getLogger(__name__)
@ -28,3 +29,19 @@ __all__ = [
"seed_everything",
"Fabric",
]
def _cli_entry_point() -> None:
from lightning_utilities.core.imports import ModuleAvailableCache, RequirementCache
if not (
ModuleAvailableCache("lightning.app")
if RequirementCache("lightning-utilities<0.10.0")
else RequirementCache(module="lightning.app") # type: ignore[call-arg]
):
print("The `lightning` command requires additional dependencies: `pip install lightning[app]`")
sys.exit(1)
from lightning.app.cli.lightning_cli import main
main()

View File

@ -115,7 +115,7 @@ def _setup_args() -> Dict[str, Any]:
"python_requires": ">=3.8", # todo: take the lowes based on all packages
"entry_points": {
"console_scripts": [
"lightning = lightning.app.cli.lightning_cli:main",
"lightning = lightning:_cli_entry_point",
],
},
"setup_requires": [],

View File

@ -13,6 +13,7 @@
# limitations under the License.
import contextlib
import os
import subprocess
from io import StringIO
from unittest import mock
from unittest.mock import Mock
@ -20,6 +21,7 @@ from unittest.mock import Mock
import pytest
import torch.distributed.run
from lightning.fabric.cli import _get_supported_strategies, _run_model
from lightning_utilities.core.imports import ModuleAvailableCache
from tests_fabric.helpers.runif import RunIf
@ -172,3 +174,15 @@ def test_cli_torchrun_num_processes_launched(_, devices, expected, monkeypatch,
fake_script,
]
)
@pytest.mark.skipif("lightning.fabric" == "lightning_fabric", reason="standalone package")
def test_cli_through_lightning_entry_point():
result = subprocess.run("lightning run model --help", capture_output=True, text=True, shell=True)
if not ModuleAvailableCache("lightning.app"):
message = "The `lightning` command requires additional dependencies"
assert message in result.stdout or message in result.stderr
assert result.returncode != 0
else:
message = "Usage: lightning run model [OPTIONS] SCRIPT [SCRIPT_ARGS]"
assert message in result.stdout or message in result.stderr