Fix failing lightning cli entry point (#18821)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
This commit is contained in:
parent
245865d586
commit
9e75bc9572
|
@ -1,5 +1,6 @@
|
|||
"""Root package info."""
|
||||
import logging
|
||||
import sys
|
||||
|
||||
# explicitly don't set root logger's propagation and leave this to subpackages to manage
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
@ -28,3 +29,19 @@ __all__ = [
|
|||
"seed_everything",
|
||||
"Fabric",
|
||||
]
|
||||
|
||||
|
||||
def _cli_entry_point() -> None:
|
||||
from lightning_utilities.core.imports import ModuleAvailableCache, RequirementCache
|
||||
|
||||
if not (
|
||||
ModuleAvailableCache("lightning.app")
|
||||
if RequirementCache("lightning-utilities<0.10.0")
|
||||
else RequirementCache(module="lightning.app") # type: ignore[call-arg]
|
||||
):
|
||||
print("The `lightning` command requires additional dependencies: `pip install lightning[app]`")
|
||||
sys.exit(1)
|
||||
|
||||
from lightning.app.cli.lightning_cli import main
|
||||
|
||||
main()
|
||||
|
|
|
@ -115,7 +115,7 @@ def _setup_args() -> Dict[str, Any]:
|
|||
"python_requires": ">=3.8", # todo: take the lowes based on all packages
|
||||
"entry_points": {
|
||||
"console_scripts": [
|
||||
"lightning = lightning.app.cli.lightning_cli:main",
|
||||
"lightning = lightning:_cli_entry_point",
|
||||
],
|
||||
},
|
||||
"setup_requires": [],
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
import contextlib
|
||||
import os
|
||||
import subprocess
|
||||
from io import StringIO
|
||||
from unittest import mock
|
||||
from unittest.mock import Mock
|
||||
|
@ -20,6 +21,7 @@ from unittest.mock import Mock
|
|||
import pytest
|
||||
import torch.distributed.run
|
||||
from lightning.fabric.cli import _get_supported_strategies, _run_model
|
||||
from lightning_utilities.core.imports import ModuleAvailableCache
|
||||
|
||||
from tests_fabric.helpers.runif import RunIf
|
||||
|
||||
|
@ -172,3 +174,15 @@ def test_cli_torchrun_num_processes_launched(_, devices, expected, monkeypatch,
|
|||
fake_script,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif("lightning.fabric" == "lightning_fabric", reason="standalone package")
|
||||
def test_cli_through_lightning_entry_point():
|
||||
result = subprocess.run("lightning run model --help", capture_output=True, text=True, shell=True)
|
||||
if not ModuleAvailableCache("lightning.app"):
|
||||
message = "The `lightning` command requires additional dependencies"
|
||||
assert message in result.stdout or message in result.stderr
|
||||
assert result.returncode != 0
|
||||
else:
|
||||
message = "Usage: lightning run model [OPTIONS] SCRIPT [SCRIPT_ARGS]"
|
||||
assert message in result.stdout or message in result.stderr
|
||||
|
|
Loading…
Reference in New Issue