From 979d728563c14a103a608d9fe49022d9c6320e76 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 19 Oct 2022 20:52:12 +0100 Subject: [PATCH] Add Database Component (#14995) --- pyproject.toml | 1 + requirements/app/test.txt | 1 + src/lightning_app/CHANGELOG.md | 3 +- .../components/database/__init__.py | 4 + .../components/database/client.py | 78 ++++++ .../components/database/server.py | 176 +++++++++++++ .../components/database/utilities.py | 248 ++++++++++++++++++ src/lightning_app/utilities/imports.py | 4 + src/lightning_app/utilities/proxies.py | 24 +- src/lightning_app/utilities/types.py | 12 +- .../components/database/test_client_server.py | 162 ++++++++++++ tests/tests_app/core/test_lightning_app.py | 2 +- .../components/python/test_scripts.py | 2 +- 13 files changed, 710 insertions(+), 7 deletions(-) create mode 100644 src/lightning_app/components/database/__init__.py create mode 100644 src/lightning_app/components/database/client.py create mode 100644 src/lightning_app/components/database/server.py create mode 100644 src/lightning_app/components/database/utilities.py create mode 100644 tests/tests_app/components/database/test_client_server.py diff --git a/pyproject.toml b/pyproject.toml index 687a31b966..236a876cd5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ exclude = [ "src/lightning_app/cli/pl-app-template", "src/lightning_app/cli/react-ui-template", "src/lightning_app/cli/app-template", + "src/lightning_app/components/database", ] install_types = "True" non_interactive = "True" diff --git a/requirements/app/test.txt b/requirements/app/test.txt index 19471d79ce..6b3dc5f5f8 100644 --- a/requirements/app/test.txt +++ b/requirements/app/test.txt @@ -10,4 +10,5 @@ trio<0.22.0 pympler psutil setuptools<=59.5.0 +sqlmodel requests-mock diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md index 8862f65ac7..26f2dfe545 100644 --- a/src/lightning_app/CHANGELOG.md +++ b/src/lightning_app/CHANGELOG.md @@ -13,8 +13,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a `--secret` option to CLI to allow binding secrets to app environment variables when running in the cloud ([#14612](https://github.com/Lightning-AI/lightning/pull/14612)) - Added support for running the works without cloud compute in the default container ([#14819](https://github.com/Lightning-AI/lightning/pull/14819)) - Added an HTTPQueue as an optional replacement for the default redis queue ([#14978](https://github.com/Lightning-AI/lightning/pull/14978) -- Added a try / catch mechanism around request processing to avoid killing the flow ([#15187](https://github.com/Lightning-AI/lightning/pull/15187) - Added support for adding descriptions to commands either through a docstring or the `DESCRIPTION` attribute ([#15193](https://github.com/Lightning-AI/lightning/pull/15193) +- Added a try / catch mechanism around request processing to avoid killing the flow ([#15187](https://github.com/Lightning-AI/lightning/pull/15187) +- Added an Database Component ([#14995](https://github.com/Lightning-AI/lightning/pull/14995) ### Fixed diff --git a/src/lightning_app/components/database/__init__.py b/src/lightning_app/components/database/__init__.py new file mode 100644 index 0000000000..d973517ea9 --- /dev/null +++ b/src/lightning_app/components/database/__init__.py @@ -0,0 +1,4 @@ +from lightning_app.components.database.client import DatabaseClient +from lightning_app.components.database.server import Database + +__all__ = ["Database", "DatabaseClient"] diff --git a/src/lightning_app/components/database/client.py b/src/lightning_app/components/database/client.py new file mode 100644 index 0000000000..7460736ef6 --- /dev/null +++ b/src/lightning_app/components/database/client.py @@ -0,0 +1,78 @@ +from typing import Any, Dict, List, Optional, Type, TypeVar + +import requests +from requests import Session +from requests.adapters import HTTPAdapter +from urllib3.util.retry import Retry + +from lightning_app.components.database.utilities import _GeneralModel + +_CONNECTION_RETRY_TOTAL = 5 +_CONNECTION_RETRY_BACKOFF_FACTOR = 1 + + +def _configure_session() -> Session: + """Configures the session for GET and POST requests. + + It enables a generous retrial strategy that waits for the application server to connect. + """ + retry_strategy = Retry( + # wait time between retries increases exponentially according to: backoff_factor * (2 ** (retry - 1)) + total=_CONNECTION_RETRY_TOTAL, + backoff_factor=_CONNECTION_RETRY_BACKOFF_FACTOR, + status_forcelist=[429, 500, 502, 503, 504], + ) + adapter = HTTPAdapter(max_retries=retry_strategy) + http = requests.Session() + http.mount("https://", adapter) + http.mount("http://", adapter) + return http + + +T = TypeVar("T") + + +class DatabaseClient: + def __init__(self, db_url: str, token: Optional[str] = None, model: Optional[T] = None) -> None: + self.db_url = db_url + self.model = model + self.token = token or "" + self._session = None + + def select_all(self, model: Optional[Type[T]] = None) -> List[T]: + cls = model if model else self.model + resp = self.session.post( + self.db_url + "/select_all/", data=_GeneralModel.from_cls(cls, token=self.token).json() + ) + assert resp.status_code == 200 + return [cls(**data) for data in resp.json()] + + def insert(self, model: T) -> None: + resp = self.session.post( + self.db_url + "/insert/", + data=_GeneralModel.from_obj(model, token=self.token).json(), + ) + assert resp.status_code == 200 + + def update(self, model: T) -> None: + resp = self.session.post( + self.db_url + "/update/", + data=_GeneralModel.from_obj(model, token=self.token).json(), + ) + assert resp.status_code == 200 + + def delete(self, model: T) -> None: + resp = self.session.post( + self.db_url + "/delete/", + data=_GeneralModel.from_obj(model, token=self.token).json(), + ) + assert resp.status_code == 200 + + @property + def session(self): + if self._session is None: + self._session = _configure_session() + return self._session + + def to_dict(self) -> Dict[str, Any]: + return {"db_url": self.db_url, "model": self.model.__name__ if self.model else None} diff --git a/src/lightning_app/components/database/server.py b/src/lightning_app/components/database/server.py new file mode 100644 index 0000000000..e26a008496 --- /dev/null +++ b/src/lightning_app/components/database/server.py @@ -0,0 +1,176 @@ +import asyncio +import os +import sys +from typing import List, Optional, Type, Union + +import uvicorn +from fastapi import FastAPI +from uvicorn import run + +from lightning import BuildConfig, LightningWork +from lightning_app.components.database.utilities import _create_database, _Delete, _Insert, _SelectAll, _Update +from lightning_app.storage import Drive +from lightning_app.utilities.imports import _is_sqlmodel_available + +if _is_sqlmodel_available(): + from sqlmodel import SQLModel + + +# Required to avoid Uvicorn Server overriding Lightning App signal handlers. +# Discussions: https://github.com/encode/uvicorn/discussions/1708 +class DatabaseUvicornServer(uvicorn.Server): + + has_started_queue = None + + def run(self, sockets=None): + self.config.setup_event_loop() + loop = asyncio.get_event_loop() + asyncio.ensure_future(self.serve(sockets=sockets)) + loop.run_forever() + + def install_signal_handlers(self): + """Ignore Uvicorn Signal Handlers.""" + + +class Database(LightningWork): + def __init__( + self, + models: Union[Type["SQLModel"], List[Type["SQLModel"]]], + db_filename: str = "database.db", + debug: bool = False, + ) -> None: + """The Database Component enables to interact with an SQLite database to store some structured information + about your application. + + The provided models are SQLModel tables + + Arguments: + models: A SQLModel or a list of SQLModels table to be added to the database. + db_filename: The name of the SQLite database. + debug: Whether to run the database in debug mode. + + Example:: + + from typing import List + from sqlmodel import SQLModel, Field + from uuid import uuid4 + + from lightning import LightningFlow, LightningApp + from lightning_app.components.database import Database, DatabaseClient + + class CounterModel(SQLModel, table=True): + __table_args__ = {"extend_existing": True} + + id: int = Field(default=None, primary_key=True) + count: int + + + class Flow(LightningFlow): + + def __init__(self): + super().__init__() + self._private_token = uuid4().hex + self.db = Database(models=[CounterModel]) + self._client = None + self.counter = 0 + + def run(self): + self.db.run(token=self._private_token) + + if not self.db.alive(): + return + + if self.counter == 0: + self._client = DatabaseClient( + model=CounterModel, + db_url=self.db.url, + token=self._private_token, + ) + + rows = self._client.select_all() + + print(f"{self.counter}: {rows}") + + if not rows: + self._client.insert(CounterModel(count=0)) + else: + row: CounterModel = rows[0] + row.count += 1 + self._client.update(row) + + if self.counter >= 100: + row: CounterModel = rows[0] + self._client.delete(row) + self._exit() + + self.counter += 1 + + app = LightningApp(Flow()) + + If you want to use nested SQLModels, we provide a utility to do so as follows: + + Example:: + + from typing import List + from sqlmodel import SQLModel, Field + from sqlalchemy import Column + + from lightning_app.components.database.utilities import pydantic_column_type + + class KeyValuePair(SQLModel): + name: str + value: str + + class CounterModel(SQLModel, table=True): + __table_args__ = {"extend_existing": True} + + name: int = Field(default=None, primary_key=True) + + # RIGHT THERE ! You need to use Field and Column with the `pydantic_column_type` utility. + kv: List[KeyValuePair] = Field(..., sa_column=Column(pydantic_column_type(List[KeyValuePair]))) + """ + super().__init__(parallel=True, cloud_build_config=BuildConfig(["sqlmodel"])) + self.db_filename = db_filename + self.debug = debug + self._models = models if isinstance(models, list) else [models] + self.drive = None + + def run(self, token: Optional[str] = None) -> None: + """ + Arguments: + token: Token used to protect the database access. Ensure you don't expose it through the App State. + """ + self.drive = Drive("lit://database") + if self.drive.list(component_name=self.name): + self.drive.get(self.db_filename) + print("Retrieved the database from Drive.") + + app = FastAPI() + + _create_database(self.db_filename, self._models, self.debug) + models = {m.__name__: m for m in self._models} + app.post("/select_all/")(_SelectAll(models, token)) + app.post("/insert/")(_Insert(models, token)) + app.post("/update/")(_Update(models, token)) + app.post("/delete/")(_Delete(models, token)) + + sys.modules["uvicorn.main"].Server = DatabaseUvicornServer + + run(app, host=self.host, port=self.port, log_level="error") + + def alive(self) -> bool: + """Hack: Returns whether the server is alive.""" + return self.db_url != "" + + @property + def db_url(self) -> Optional[str]: + use_localhost = "LIGHTNING_APP_STATE_URL" not in os.environ + if use_localhost: + return self.url + if self.internal_ip != "": + return f"http://{self.internal_ip}:{self.port}" + return self.internal_ip + + def on_exit(self): + self.drive.put(self.db_filename) + print("Stored the database to the Drive.") diff --git a/src/lightning_app/components/database/utilities.py b/src/lightning_app/components/database/utilities.py new file mode 100644 index 0000000000..5405a4d114 --- /dev/null +++ b/src/lightning_app/components/database/utilities.py @@ -0,0 +1,248 @@ +import functools +import json +import pathlib +from typing import Any, Dict, Generic, List, Type, TypeVar + +from fastapi import Response, status +from fastapi.encoders import jsonable_encoder +from pydantic import BaseModel, parse_obj_as +from pydantic.main import ModelMetaclass + +from lightning_app.utilities.app_helpers import Logger +from lightning_app.utilities.imports import _is_sqlmodel_available + +if _is_sqlmodel_available(): + from sqlalchemy.inspection import inspect as sqlalchemy_inspect + from sqlmodel import JSON, select, Session, SQLModel, TypeDecorator + +logger = Logger(__name__) +engine = None + +T = TypeVar("T") + + +# Taken from https://github.com/tiangolo/sqlmodel/issues/63#issuecomment-1081555082 +def pydantic_column_type(pydantic_type: Any) -> Any: + """This function enables to support JSON types with SQLModel. + + Example:: + + from sqlmodel import SQLModel + from sqlalchemy import Column + + class TrialConfig(SQLModel, table=False): + ... + params: Dict[str, Union[Dict[str, float]] = Field(sa_column=Column(pydantic_column_type[Dict[str, float])) + """ + + class PydanticJSONType(TypeDecorator, Generic[T]): + impl = JSON() + + def __init__( + self, + json_encoder=json, + ): + self.json_encoder = json_encoder + super().__init__() + + def bind_processor(self, dialect): + impl_processor = self.impl.bind_processor(dialect) + dumps = self.json_encoder.dumps + if impl_processor: + + def process(value: T): + if value is not None: + if isinstance(pydantic_type, ModelMetaclass): + # This allows to assign non-InDB models and if they're + # compatible, they're directly parsed into the InDB + # representation, thus hiding the implementation in the + # background. However, the InDB model will still be returned + value_to_dump = pydantic_type.from_orm(value) + else: + value_to_dump = value + value = jsonable_encoder(value_to_dump) + return impl_processor(value) + + else: + + def process(value): + if isinstance(pydantic_type, ModelMetaclass): + # This allows to assign non-InDB models and if they're + # compatible, they're directly parsed into the InDB + # representation, thus hiding the implementation in the + # background. However, the InDB model will still be returned + value_to_dump = pydantic_type.from_orm(value) + else: + value_to_dump = value + value = dumps(jsonable_encoder(value_to_dump)) + return value + + return process + + def result_processor(self, dialect, coltype) -> T: + impl_processor = self.impl.result_processor(dialect, coltype) + if impl_processor: + + def process(value): + value = impl_processor(value) + if value is None: + return None + + data = value + # Explicitly use the generic directly, not type(T) + full_obj = parse_obj_as(pydantic_type, data) + return full_obj + + else: + + def process(value): + if value is None: + return None + + # Explicitly use the generic directly, not type(T) + full_obj = parse_obj_as(pydantic_type, value) + return full_obj + + return process + + def compare_values(self, x, y): + return x == y + + return PydanticJSONType + + +@functools.lru_cache +def get_primary_key(model_type: Type[SQLModel]) -> str: + primary_keys = sqlalchemy_inspect(model_type).primary_key + + if len(primary_keys) != 1: + raise ValueError(f"The model {model_type.__name__} should have a single primary key field.") + + return primary_keys[0].name + + +class _GeneralModel(BaseModel): + cls_name: str + data: str + token: str + + def convert_to_model(self, models: Dict[str, BaseModel]): + return models[self.cls_name].parse_raw(self.data) + + @classmethod + def from_obj(cls, obj, token): + return cls( + **{ + "cls_name": obj.__class__.__name__, + "data": obj.json(), + "token": token, + } + ) + + @classmethod + def from_cls(cls, obj_cls, token): + return cls( + **{ + "cls_name": obj_cls.__name__, + "data": "", + "token": token, + } + ) + + +class _SelectAll: + def __init__(self, models, token): + print(models, token) + self.models = models + self.token = token + + def __call__(self, data: Dict, response: Response): + if self.token and data["token"] != self.token: + response.status_code = status.HTTP_401_UNAUTHORIZED + return {"status": "failure", "reason": "Unauthorized request to the database."} + + with Session(engine) as session: + cls: Type[SQLModel] = self.models[data["cls_name"]] + statement = select(cls) + results = session.exec(statement) + return results.all() + + +class _Insert: + def __init__(self, models, token): + self.models = models + self.token = token + + def __call__(self, data: Dict, response: Response): + if self.token and data["token"] != self.token: + response.status_code = status.HTTP_401_UNAUTHORIZED + return {"status": "failure", "reason": "Unauthorized request to the database."} + + with Session(engine) as session: + ele: SQLModel = self.models[data["cls_name"]].parse_raw(data["data"]) + session.add(ele) + session.commit() + session.refresh(ele) + return ele + + +class _Update: + def __init__(self, models, token): + self.models = models + self.token = token + + def __call__(self, data: Dict, response: Response): + if self.token and data["token"] != self.token: + response.status_code = status.HTTP_401_UNAUTHORIZED + return {"status": "failure", "reason": "Unauthorized request to the database."} + + with Session(engine) as session: + update_data: SQLModel = self.models[data["cls_name"]].parse_raw(data["data"]) + primary_key = get_primary_key(update_data.__class__) + identifier = getattr(update_data.__class__, primary_key, None) + statement = select(update_data.__class__).where(identifier == getattr(update_data, primary_key)) + results = session.exec(statement) + result = results.one() + for k, v in vars(update_data).items(): + if k in ("id", "_sa_instance_state"): + continue + if getattr(result, k) != v: + setattr(result, k, v) + session.add(result) + session.commit() + session.refresh(result) + + +class _Delete: + def __init__(self, models, token): + self.models = models + self.token = token + + def __call__(self, data: Dict, response: Response): + if self.token and data["token"] != self.token: + response.status_code = status.HTTP_401_UNAUTHORIZED + return {"status": "failure", "reason": "Unauthorized request to the database."} + + with Session(engine) as session: + update_data: SQLModel = self.models[data["cls_name"]].parse_raw(data["data"]) + primary_key = get_primary_key(update_data.__class__) + identifier = getattr(update_data.__class__, primary_key, None) + statement = select(update_data.__class__).where(identifier == getattr(update_data, primary_key)) + results = session.exec(statement) + result = results.one() + session.delete(result) + session.commit() + + +def _create_database(db_filename: str, models: List[Type["SQLModel"]], echo: bool = False): + global engine + + from sqlmodel import create_engine, SQLModel + + engine = create_engine(f"sqlite:///{pathlib.Path(db_filename).resolve()}", echo=echo) + + logger.debug(f"Creating the following tables {models}") + try: + SQLModel.metadata.create_all(engine) + except Exception as e: + logger.debug(e) diff --git a/src/lightning_app/utilities/imports.py b/src/lightning_app/utilities/imports.py index 06f780edb6..c44cae515f 100644 --- a/src/lightning_app/utilities/imports.py +++ b/src/lightning_app/utilities/imports.py @@ -108,4 +108,8 @@ def _is_s3fs_available() -> bool: return module_available("s3fs") +def _is_sqlmodel_available() -> bool: + return module_available("sqlmodel") + + _CLOUD_TEST_RUN = bool(os.getenv("CLOUD", False)) diff --git a/src/lightning_app/utilities/proxies.py b/src/lightning_app/utilities/proxies.py index 00a8f00f7a..7712913437 100644 --- a/src/lightning_app/utilities/proxies.py +++ b/src/lightning_app/utilities/proxies.py @@ -35,7 +35,6 @@ if TYPE_CHECKING: from lightning_app import LightningWork from lightning_app.core.queues import BaseQueue - from lightning_app.utilities.app_helpers import Logger logger = Logger(__name__) @@ -99,7 +98,7 @@ class ProxyWorkRun: self._validate_call_args(args, kwargs) args, kwargs = self._process_call_args(args, kwargs) - call_hash = self.work._call_hash(self.work_run, args, kwargs) + call_hash = self.work._call_hash(self.work_run, *self._convert_hashable(args, kwargs)) entered = call_hash in self.work._calls returned = entered and "ret" in self.work._calls[call_hash] # TODO (tchaton): Handle spot instance retrieval differently from stopped work. @@ -177,6 +176,27 @@ class ProxyWorkRun: return apply_to_collection((args, kwargs), dtype=(Path, Drive), function=sanitize) + @staticmethod + def _convert_hashable(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: + """Processes all positional and keyword arguments before they get passed to the caller queue and sent to + the LightningWork. + + Currently, this method only applies sanitization to Hashable Objects. + + Args: + args: The tuple of positional arguments passed to the run method. + kwargs: The dictionary of named arguments passed to the run method. + + Returns: + The positional and keyword arguments in the same order they were passed in. + """ + from lightning_app.utilities.types import Hashable + + def sanitize(obj: Hashable) -> Union[Path, Dict]: + return obj.to_dict() + + return apply_to_collection((args, kwargs), dtype=Hashable, function=sanitize) + class WorkStateObserver(Thread): """This thread runs alongside LightningWork and periodically checks for state changes. If the state changed diff --git a/src/lightning_app/utilities/types.py b/src/lightning_app/utilities/types.py index a4c7a7d86a..32c77c1984 100644 --- a/src/lightning_app/utilities/types.py +++ b/src/lightning_app/utilities/types.py @@ -1,7 +1,15 @@ -from typing import Union +import typing as t + +from typing_extensions import Protocol, runtime_checkable from lightning_app import LightningFlow, LightningWork from lightning_app.structures import Dict, List -Component = Union[LightningFlow, LightningWork, Dict, List] +Component = t.Union[LightningFlow, LightningWork, Dict, List] ComponentTuple = (LightningFlow, LightningWork, Dict, List) + + +@runtime_checkable +class Hashable(Protocol): + def to_dict(self) -> t.Dict[str, t.Any]: + ... diff --git a/tests/tests_app/components/database/test_client_server.py b/tests/tests_app/components/database/test_client_server.py new file mode 100644 index 0000000000..7ab2382f1c --- /dev/null +++ b/tests/tests_app/components/database/test_client_server.py @@ -0,0 +1,162 @@ +import os +import sys +from pathlib import Path +from time import sleep +from typing import List, Optional +from uuid import uuid4 + +import pytest + +from lightning_app import LightningApp, LightningFlow, LightningWork +from lightning_app.components.database import Database, DatabaseClient +from lightning_app.components.database.utilities import _GeneralModel, pydantic_column_type +from lightning_app.runners import MultiProcessRuntime +from lightning_app.utilities.imports import _is_sqlmodel_available + +if _is_sqlmodel_available(): + from sqlalchemy import Column + from sqlmodel import Field, SQLModel + + class Secret(SQLModel): + name: str + value: str + + class TestConfig(SQLModel, table=True): + __table_args__ = {"extend_existing": True} + + id: Optional[int] = Field(default=None, primary_key=True) + name: str + secrets: List[Secret] = Field(..., sa_column=Column(pydantic_column_type(List[Secret]))) + + +class Work(LightningWork): + def __init__(self): + super().__init__(parallel=True) + self.done = False + + def run(self, client: DatabaseClient): + rows = client.select_all() + while len(rows) == 0: + print(rows) + sleep(0.1) + rows = client.select_all() + self.done = True + + +@pytest.mark.skipif(not _is_sqlmodel_available(), reason="sqlmodel is required for this test.") +def test_client_server(): + + database_path = Path("database.db").resolve() + if database_path.exists(): + os.remove(database_path) + + secrets = [Secret(name="example", value="secret")] + + general = _GeneralModel.from_obj(TestConfig(name="name", secrets=secrets), token="a") + assert general.cls_name == "TestConfig" + assert general.data == '{"id": null, "name": "name", "secrets": [{"name": "example", "value": "secret"}]}' + + class Flow(LightningFlow): + def __init__(self): + super().__init__() + self._token = str(uuid4()) + self.db = Database(models=[TestConfig]) + self._client = None + self.tracker = None + self.work = Work() + + def run(self): + self.db.run(token=self._token) + + if not self.db.alive(): + return + + if not self._client: + self._client = DatabaseClient(model=TestConfig, db_url=self.db.url, token=self._token) + + assert self._client + + self.work.run(self._client) + + if self.tracker is None: + self._client.insert(TestConfig(name="name", secrets=secrets)) + elem = self._client.select_all(TestConfig)[0] + assert elem.name == "name" + self.tracker = "update" + assert isinstance(elem.secrets[0], Secret) + assert elem.secrets[0].name == "example" + assert elem.secrets[0].value == "secret" + + elif self.tracker == "update": + elem = self._client.select_all(TestConfig)[0] + elem.name = "new_name" + self._client.update(elem) + + elem = self._client.select_all(TestConfig)[0] + assert elem.name == "new_name" + self.tracker = "delete" + + elif self.tracker == "delete" and self.work.done: + self.work.stop() + + elem = self._client.select_all(TestConfig)[0] + elem = self._client.delete(elem) + + assert not self._client.select_all(TestConfig) + self._client.insert(TestConfig(name="name", secrets=secrets)) + + assert self._client.select_all(TestConfig) + self._exit() + + app = LightningApp(Flow()) + MultiProcessRuntime(app, start_server=False).dispatch() + + database_path = Path("database.db").resolve() + if database_path.exists(): + os.remove(database_path) + + +@pytest.mark.skipif(sys.platform == "win32", reason="currently not supported for windows.") +@pytest.mark.skipif(not _is_sqlmodel_available(), reason="sqlmodel is required for this test.") +def test_work_database_restart(): + + id = str(uuid4()).split("-")[0] + + class Flow(LightningFlow): + def __init__(self, restart=False): + super().__init__() + self.db = Database(db_filename=id, models=[TestConfig]) + self._client = None + self.restart = restart + + def run(self): + self.db.run() + + if not self.db.alive(): + return + elif not self._client: + self._client = DatabaseClient(self.db.db_url, None, model=TestConfig) + + if not self.restart: + self._client.insert(TestConfig(name="echo", secrets=[Secret(name="example", value="secret")])) + self._exit() + else: + assert os.path.exists(id) + assert len(self._client.select_all()) == 1 + self._exit() + + app = LightningApp(Flow()) + MultiProcessRuntime(app).dispatch() + + # Note: Waiting for SIGTERM signal to be handled + sleep(2) + + os.remove(id) + + app = LightningApp(Flow(restart=True)) + MultiProcessRuntime(app).dispatch() + + # Note: Waiting for SIGTERM signal to be handled + sleep(2) + + os.remove(id) diff --git a/tests/tests_app/core/test_lightning_app.py b/tests/tests_app/core/test_lightning_app.py index 2f42643297..99db000ae3 100644 --- a/tests/tests_app/core/test_lightning_app.py +++ b/tests/tests_app/core/test_lightning_app.py @@ -423,7 +423,7 @@ class EmptyFlow(LightningFlow): "sleep_time, expect", [ (1, 0), - (0, 100), + (0, 20), ], ) def test_lightning_app_aggregation_speed(default_timeout, queue_type_cls: BaseQueue, sleep_time, expect): diff --git a/tests/tests_app_examples/components/python/test_scripts.py b/tests/tests_app_examples/components/python/test_scripts.py index 4a3084832a..0e25fefa28 100644 --- a/tests/tests_app_examples/components/python/test_scripts.py +++ b/tests/tests_app_examples/components/python/test_scripts.py @@ -12,8 +12,8 @@ from lightning_app.testing.helpers import run_script, RunIf @pytest.mark.parametrize( "file", [ - pytest.param("component_popen.py"), pytest.param("component_tracer.py"), + pytest.param("component_popen.py"), ], ) def test_scripts(file):