71 lines
2.3 KiB
Python
71 lines
2.3 KiB
Python
import logging
|
|
import os
|
|
from unittest import mock
|
|
|
|
import pytest
|
|
from click.testing import CliRunner
|
|
from tests_examples_app.public import _PATH_EXAMPLES
|
|
|
|
from lightning_app import LightningApp
|
|
from lightning_app.cli.lightning_cli import run_app
|
|
from lightning_app.testing.helpers import _RunIf
|
|
from lightning_app.testing.testing import run_app_in_cloud, wait_for
|
|
|
|
|
|
class QuickStartApp(LightningApp):
    """``LightningApp`` subclass used to drive the Quick Start example from the tests.

    On construction the root's ``serve_work`` is flagged as parallel. The run loop reports
    completion as soon as the root's ``train_work`` exposes a truthy ``best_model_path``.
    """

    def __init__(self, *args, **kwargs):
        """Forward every argument to ``LightningApp`` and mark ``serve_work`` as parallel."""
        super().__init__(*args, **kwargs)
        self.root.serve_work._parallel = True

    def run_once(self):
        """Run one iteration of the parent loop.

        Returns:
            ``True`` when ``train_work.best_model_path`` is truthy, otherwise whatever the
            parent ``run_once`` returned.
        """
        # The parent iteration always runs first; only its return value may be overridden.
        parent_done = super().run_once()
        return True if self.root.train_work.best_model_path else parent_done
|
|
|
|
|
|
# TODO: Investigate why it doesn't work
@pytest.mark.skip(reason="test is skipped because CI was blocking all the PRs.")
@_RunIf(pl=True, skip_windows=True, skip_linux=True)
def test_quick_start_example(caplog, monkeypatch):
    """Run the Quick Start example through the CLI and expect a zero exit code."""
    # From here on, every ``logging.getLogger`` call hands back the root logger.
    root_logger = logging.getLogger()
    monkeypatch.setattr("logging.getLogger", mock.MagicMock(return_value=root_logger))

    app_script = os.path.join(_PATH_EXAMPLES, "lightning-quick-start", "app.py")
    cli_args = [app_script, "--blocking", "False", "--open-ui", "False"]

    # ``QuickStartApp`` stands in for ``lightning_app.LightningApp`` while the command runs.
    with caplog.at_level(logging.INFO), mock.patch("lightning_app.LightningApp", QuickStartApp):
        result = CliRunner().invoke(run_app, cli_args, catch_exceptions=False)
        assert result.exit_code == 0
|
|
|
|
|
|
@pytest.mark.cloud
def test_quick_start_example_cloud() -> None:
    """Run the Quick Start example in the cloud and poke at its UI through ``view_page``.

    Two callbacks are handed to ``wait_for`` in sequence: the first clicks the
    "Interactive demo" button, the second looks for a "Submit" button inside an iframe.
    """
    example_dir = os.path.join(_PATH_EXAMPLES, "lightning-quick-start")

    with run_app_in_cloud(example_dir) as (_, view_page, _, _):

        def click_gradio_demo(*_, **__):
            """Wait for the "Interactive demo" button, click it and report success."""
            demo_button = view_page.locator('button:has-text("Interactive demo")')
            demo_button.wait_for(timeout=5 * 1000)
            demo_button.click()
            return True

        wait_for(view_page, click_gradio_demo)

        def check_examples(*_, **__):
            """Return ``True`` once the iframe's "Submit" button yields any text content."""
            submit_button = view_page.frame_locator("iframe").locator('button:has-text("Submit")')
            submit_button.wait_for(timeout=10 * 1000)
            if submit_button.all_text_contents():
                return True
            # Nothing found yet: a falsy result, same as the implicit fall-through.
            return None

        wait_for(view_page, check_examples)
|