Add Database Component (#14995)
This commit is contained in:
parent
d0b092fda8
commit
979d728563
|
@ -35,6 +35,7 @@ exclude = [
|
|||
"src/lightning_app/cli/pl-app-template",
|
||||
"src/lightning_app/cli/react-ui-template",
|
||||
"src/lightning_app/cli/app-template",
|
||||
"src/lightning_app/components/database",
|
||||
]
|
||||
install_types = "True"
|
||||
non_interactive = "True"
|
||||
|
|
|
@ -10,4 +10,5 @@ trio<0.22.0
|
|||
pympler
|
||||
psutil
|
||||
setuptools<=59.5.0
|
||||
sqlmodel
|
||||
requests-mock
|
||||
|
|
|
@ -13,8 +13,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Added a `--secret` option to CLI to allow binding secrets to app environment variables when running in the cloud ([#14612](https://github.com/Lightning-AI/lightning/pull/14612))
|
||||
- Added support for running the works without cloud compute in the default container ([#14819](https://github.com/Lightning-AI/lightning/pull/14819))
|
||||
- Added an HTTPQueue as an optional replacement for the default redis queue ([#14978](https://github.com/Lightning-AI/lightning/pull/14978))
|
||||
- Added a try / catch mechanism around request processing to avoid killing the flow ([#15187](https://github.com/Lightning-AI/lightning/pull/15187))
|
||||
- Added support for adding descriptions to commands either through a docstring or the `DESCRIPTION` attribute ([#15193](https://github.com/Lightning-AI/lightning/pull/15193))
|
||||
- Added a try / catch mechanism around request processing to avoid killing the flow ([#15187](https://github.com/Lightning-AI/lightning/pull/15187))
|
||||
- Added a Database Component ([#14995](https://github.com/Lightning-AI/lightning/pull/14995))
|
||||
|
||||
|
||||
### Fixed
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
from lightning_app.components.database.client import DatabaseClient
|
||||
from lightning_app.components.database.server import Database
|
||||
|
||||
__all__ = ["Database", "DatabaseClient"]
|
|
@ -0,0 +1,78 @@
|
|||
from typing import Any, Dict, List, Optional, Type, TypeVar
|
||||
|
||||
import requests
|
||||
from requests import Session
|
||||
from requests.adapters import HTTPAdapter
|
||||
from urllib3.util.retry import Retry
|
||||
|
||||
from lightning_app.components.database.utilities import _GeneralModel
|
||||
|
||||
_CONNECTION_RETRY_TOTAL = 5
|
||||
_CONNECTION_RETRY_BACKOFF_FACTOR = 1
|
||||
|
||||
|
||||
def _configure_session() -> Session:
    """Build a ``requests`` session with a generous retry policy.

    The session retries transient failures so callers can wait for the
    application server to come up instead of failing immediately.
    """
    # Delay between attempts grows exponentially: backoff_factor * (2 ** (retry - 1))
    retries = Retry(
        total=_CONNECTION_RETRY_TOTAL,
        backoff_factor=_CONNECTION_RETRY_BACKOFF_FACTOR,
        status_forcelist=[429, 500, 502, 503, 504],
    )
    session = requests.Session()
    retry_adapter = HTTPAdapter(max_retries=retries)
    for scheme in ("https://", "http://"):
        session.mount(scheme, retry_adapter)
    return session
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class DatabaseClient:
    """Thin HTTP client for the ``Database`` component server.

    Sends JSON-encoded SQLModel payloads to the server endpoints
    (``/select_all/``, ``/insert/``, ``/update/``, ``/delete/``).

    Arguments:
        db_url: Base URL of the database server.
        token: Optional token protecting database access; sent with every request.
        model: Default SQLModel class used when ``select_all`` is called without one.
    """

    def __init__(self, db_url: str, token: Optional[str] = None, model: Optional[Type[T]] = None) -> None:
        self.db_url = db_url
        self.model = model
        self.token = token or ""
        # Created lazily on first use (see the `session` property).
        self._session = None

    def select_all(self, model: Optional[Type[T]] = None) -> List[T]:
        """Return every row of ``model`` (or the default model) as model instances."""
        cls = model if model else self.model
        if cls is None:
            raise ValueError("A model must be provided, either at construction time or to `select_all`.")
        resp = self._post("/select_all/", _GeneralModel.from_cls(cls, token=self.token).json())
        return [cls(**data) for data in resp.json()]

    def insert(self, model: T) -> None:
        """Insert ``model`` as a new row."""
        self._post("/insert/", _GeneralModel.from_obj(model, token=self.token).json())

    def update(self, model: T) -> None:
        """Update the stored row matching ``model``'s primary key with its values."""
        self._post("/update/", _GeneralModel.from_obj(model, token=self.token).json())

    def delete(self, model: T) -> None:
        """Delete the stored row matching ``model``'s primary key."""
        self._post("/delete/", _GeneralModel.from_obj(model, token=self.token).json())

    def _post(self, route: str, data: str):
        """POST ``data`` to ``route`` and fail loudly on any non-200 response.

        Note: the original code used `assert resp.status_code == 200`, which is
        silently stripped under `python -O`; an explicit raise is kept instead.
        """
        resp = self.session.post(self.db_url + route, data=data)
        if resp.status_code != 200:
            raise RuntimeError(f"Request to {self.db_url + route} failed with status {resp.status_code}.")
        return resp

    @property
    def session(self) -> Session:
        # Lazily configured with a generous retry policy so callers can wait
        # for the database server to come up.
        if self._session is None:
            self._session = _configure_session()
        return self._session

    def to_dict(self) -> Dict[str, Any]:
        """Serializable description of this client (implements the `Hashable` protocol)."""
        return {"db_url": self.db_url, "model": self.model.__name__ if self.model else None}
|
|
@ -0,0 +1,176 @@
|
|||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Optional, Type, Union
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from uvicorn import run
|
||||
|
||||
from lightning import BuildConfig, LightningWork
|
||||
from lightning_app.components.database.utilities import _create_database, _Delete, _Insert, _SelectAll, _Update
|
||||
from lightning_app.storage import Drive
|
||||
from lightning_app.utilities.imports import _is_sqlmodel_available
|
||||
|
||||
if _is_sqlmodel_available():
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
|
||||
# Required to avoid Uvicorn Server overriding Lightning App signal handlers.
# Discussions: https://github.com/encode/uvicorn/discussions/1708
class DatabaseUvicornServer(uvicorn.Server):
    """Uvicorn server that neither installs its own signal handlers nor owns the loop shutdown."""

    has_started_queue = None

    def run(self, sockets=None):
        # Schedule `serve` as a background task on a forever-running loop instead
        # of uvicorn's default blocking behavior, so the surrounding LightningWork
        # keeps control of process lifecycle and signals.
        self.config.setup_event_loop()
        loop = asyncio.get_event_loop()
        asyncio.ensure_future(self.serve(sockets=sockets))
        loop.run_forever()

    def install_signal_handlers(self):
        """Ignore Uvicorn Signal Handlers."""
|
||||
|
||||
|
||||
class Database(LightningWork):
    def __init__(
        self,
        models: Union[Type["SQLModel"], List[Type["SQLModel"]]],
        db_filename: str = "database.db",
        debug: bool = False,
    ) -> None:
        """The Database Component enables interacting with an SQLite database to store some structured
        information about your application.

        The provided models are SQLModel tables.

        Arguments:
            models: A SQLModel or a list of SQLModel tables to be added to the database.
            db_filename: The name of the SQLite database file.
            debug: Whether to run the database in debug mode (echoes SQL statements).

        Example::

            from typing import List
            from sqlmodel import SQLModel, Field
            from uuid import uuid4

            from lightning import LightningFlow, LightningApp
            from lightning_app.components.database import Database, DatabaseClient

            class CounterModel(SQLModel, table=True):
                __table_args__ = {"extend_existing": True}

                id: int = Field(default=None, primary_key=True)
                count: int


            class Flow(LightningFlow):

                def __init__(self):
                    super().__init__()
                    self._private_token = uuid4().hex
                    self.db = Database(models=[CounterModel])
                    self._client = None
                    self.counter = 0

                def run(self):
                    self.db.run(token=self._private_token)

                    if not self.db.alive():
                        return

                    if self.counter == 0:
                        self._client = DatabaseClient(
                            model=CounterModel,
                            db_url=self.db.url,
                            token=self._private_token,
                        )

                    rows = self._client.select_all()

                    print(f"{self.counter}: {rows}")

                    if not rows:
                        self._client.insert(CounterModel(count=0))
                    else:
                        row: CounterModel = rows[0]
                        row.count += 1
                        self._client.update(row)

                    if self.counter >= 100:
                        row: CounterModel = rows[0]
                        self._client.delete(row)
                        self._exit()

                    self.counter += 1

            app = LightningApp(Flow())

        If you want to use nested SQLModels, we provide a utility to do so as follows:

        Example::

            from typing import List
            from sqlmodel import SQLModel, Field
            from sqlalchemy import Column

            from lightning_app.components.database.utilities import pydantic_column_type

            class KeyValuePair(SQLModel):
                name: str
                value: str

            class CounterModel(SQLModel, table=True):
                __table_args__ = {"extend_existing": True}

                name: int = Field(default=None, primary_key=True)

                # RIGHT THERE ! You need to use Field and Column with the `pydantic_column_type` utility.
                kv: List[KeyValuePair] = Field(..., sa_column=Column(pydantic_column_type(List[KeyValuePair])))
        """
        super().__init__(parallel=True, cloud_build_config=BuildConfig(["sqlmodel"]))
        self.db_filename = db_filename
        self.debug = debug
        # Normalize to a list so a single model can be passed as well.
        self._models = models if isinstance(models, list) else [models]
        # Assigned in `run()`; used to persist/restore the SQLite file across restarts.
        self.drive = None

    def run(self, token: Optional[str] = None) -> None:
        """Start the database server (blocking call).

        Arguments:
            token: Token used to protect the database access. Ensure you don't expose it through the App State.
        """
        # Restore a previously stored database file, if any, from the shared Drive.
        self.drive = Drive("lit://database")
        if self.drive.list(component_name=self.name):
            self.drive.get(self.db_filename)
            print("Retrieved the database from Drive.")

        app = FastAPI()

        _create_database(self.db_filename, self._models, self.debug)
        models = {m.__name__: m for m in self._models}
        # Register one callable endpoint per CRUD operation; each validates `token`.
        app.post("/select_all/")(_SelectAll(models, token))
        app.post("/insert/")(_Insert(models, token))
        app.post("/update/")(_Update(models, token))
        app.post("/delete/")(_Delete(models, token))

        # Swap in a Server subclass that does not override Lightning's signal handlers.
        sys.modules["uvicorn.main"].Server = DatabaseUvicornServer

        run(app, host=self.host, port=self.port, log_level="error")

    def alive(self) -> bool:
        """Hack: Returns whether the server is alive."""
        return self.db_url != ""

    @property
    def db_url(self) -> Optional[str]:
        # When running locally (no cloud app-state URL), the work's own URL is reachable.
        use_localhost = "LIGHTNING_APP_STATE_URL" not in os.environ
        if use_localhost:
            return self.url
        if self.internal_ip != "":
            return f"http://{self.internal_ip}:{self.port}"
        # Empty string until the cloud assigns an internal IP — `alive()` relies on this.
        return self.internal_ip

    def on_exit(self):
        # Persist the database file to the Drive so it survives restarts.
        self.drive.put(self.db_filename)
        print("Stored the database to the Drive.")
|
|
@ -0,0 +1,248 @@
|
|||
import functools
|
||||
import json
|
||||
import pathlib
|
||||
from typing import Any, Dict, Generic, List, Type, TypeVar
|
||||
|
||||
from fastapi import Response, status
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from pydantic import BaseModel, parse_obj_as
|
||||
from pydantic.main import ModelMetaclass
|
||||
|
||||
from lightning_app.utilities.app_helpers import Logger
|
||||
from lightning_app.utilities.imports import _is_sqlmodel_available
|
||||
|
||||
if _is_sqlmodel_available():
|
||||
from sqlalchemy.inspection import inspect as sqlalchemy_inspect
|
||||
from sqlmodel import JSON, select, Session, SQLModel, TypeDecorator
|
||||
|
||||
logger = Logger(__name__)
|
||||
engine = None
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
# Taken from https://github.com/tiangolo/sqlmodel/issues/63#issuecomment-1081555082
|
||||
# Taken from https://github.com/tiangolo/sqlmodel/issues/63#issuecomment-1081555082
def pydantic_column_type(pydantic_type: Any) -> Any:
    """This function enables to support JSON types with SQLModel.

    It returns a SQLAlchemy ``TypeDecorator`` class that serializes values of
    ``pydantic_type`` to JSON on write and parses them back on read.

    Example::

        from sqlmodel import SQLModel
        from sqlalchemy import Column

        class TrialConfig(SQLModel, table=False):
            ...
            params: Dict[str, Dict[str, float]] = Field(sa_column=Column(pydantic_column_type(Dict[str, Dict[str, float]])))
    """

    class PydanticJSONType(TypeDecorator, Generic[T]):
        # Values are stored in the database as a JSON column.
        impl = JSON()

        def __init__(
            self,
            json_encoder=json,
        ):
            # Any object exposing `dumps` (defaults to the stdlib `json` module).
            self.json_encoder = json_encoder
            super().__init__()

        def bind_processor(self, dialect):
            # Python value -> database value (applied on INSERT/UPDATE).
            impl_processor = self.impl.bind_processor(dialect)
            dumps = self.json_encoder.dumps
            if impl_processor:

                def process(value: T):
                    if value is not None:
                        if isinstance(pydantic_type, ModelMetaclass):
                            # This allows to assign non-InDB models and if they're
                            # compatible, they're directly parsed into the InDB
                            # representation, thus hiding the implementation in the
                            # background. However, the InDB model will still be returned
                            value_to_dump = pydantic_type.from_orm(value)
                        else:
                            value_to_dump = value
                        value = jsonable_encoder(value_to_dump)
                    return impl_processor(value)

            else:

                def process(value):
                    if isinstance(pydantic_type, ModelMetaclass):
                        # This allows to assign non-InDB models and if they're
                        # compatible, they're directly parsed into the InDB
                        # representation, thus hiding the implementation in the
                        # background. However, the InDB model will still be returned
                        value_to_dump = pydantic_type.from_orm(value)
                    else:
                        value_to_dump = value
                    value = dumps(jsonable_encoder(value_to_dump))
                    return value

            return process

        def result_processor(self, dialect, coltype) -> T:
            # Database value -> Python value (applied on SELECT).
            impl_processor = self.impl.result_processor(dialect, coltype)
            if impl_processor:

                def process(value):
                    value = impl_processor(value)
                    if value is None:
                        return None

                    data = value
                    # Explicitly use the generic directly, not type(T)
                    full_obj = parse_obj_as(pydantic_type, data)
                    return full_obj

            else:

                def process(value):
                    if value is None:
                        return None

                    # Explicitly use the generic directly, not type(T)
                    full_obj = parse_obj_as(pydantic_type, value)
                    return full_obj

            return process

        def compare_values(self, x, y):
            # Plain equality is enough for SQLAlchemy to detect changed values.
            return x == y

    return PydanticJSONType
|
||||
|
||||
|
||||
@functools.lru_cache
def get_primary_key(model_type: Type[SQLModel]) -> str:
    """Return the name of the single primary-key column of ``model_type``.

    Raises:
        ValueError: If the table does not declare exactly one primary-key field.
    """
    keys = sqlalchemy_inspect(model_type).primary_key
    if len(keys) == 1:
        return keys[0].name
    raise ValueError(f"The model {model_type.__name__} should have a single primary key field.")
|
||||
|
||||
|
||||
class _GeneralModel(BaseModel):
    """Wire format exchanged between ``DatabaseClient`` and the server endpoints."""

    # Name of the SQLModel class the payload refers to.
    cls_name: str
    # JSON-serialized row data ("" for class-only requests such as select_all).
    data: str
    # Access token checked by the server.
    token: str

    def convert_to_model(self, models: Dict[str, BaseModel]):
        """Re-hydrate ``data`` into the model class registered under ``cls_name``."""
        return models[self.cls_name].parse_raw(self.data)

    @classmethod
    def from_obj(cls, obj, token):
        """Build a payload carrying a serialized model instance."""
        return cls(cls_name=obj.__class__.__name__, data=obj.json(), token=token)

    @classmethod
    def from_cls(cls, obj_cls, token):
        """Build a payload carrying only a model class name (no row data)."""
        return cls(cls_name=obj_cls.__name__, data="", token=token)
|
||||
|
||||
|
||||
class _SelectAll:
    """FastAPI endpoint callable returning every row of the requested model."""

    def __init__(self, models, token):
        # Note: the previous debug `print(models, token)` was removed — it
        # leaked the access token protecting the database to stdout.
        self.models = models
        self.token = token

    def __call__(self, data: Dict, response: Response):
        # Reject requests carrying the wrong token (when one is configured).
        if self.token and data["token"] != self.token:
            response.status_code = status.HTTP_401_UNAUTHORIZED
            return {"status": "failure", "reason": "Unauthorized request to the database."}

        with Session(engine) as session:
            cls: Type[SQLModel] = self.models[data["cls_name"]]
            statement = select(cls)
            results = session.exec(statement)
            return results.all()
|
||||
|
||||
|
||||
class _Insert:
    """FastAPI endpoint callable inserting one row and returning it refreshed."""

    def __init__(self, models, token):
        self.models = models
        self.token = token

    def __call__(self, data: Dict, response: Response):
        # Reject requests carrying the wrong token (when one is configured).
        if self.token and data["token"] != self.token:
            response.status_code = status.HTTP_401_UNAUTHORIZED
            return {"status": "failure", "reason": "Unauthorized request to the database."}

        model_cls = self.models[data["cls_name"]]
        row: SQLModel = model_cls.parse_raw(data["data"])
        with Session(engine) as session:
            session.add(row)
            session.commit()
            # Refresh so server-generated fields (e.g. the primary key) are populated.
            session.refresh(row)
            return row
|
||||
|
||||
|
||||
class _Update:
    """FastAPI endpoint callable updating the row matching the payload's primary key."""

    def __init__(self, models, token):
        self.models = models
        self.token = token

    def __call__(self, data: Dict, response: Response):
        # Reject requests carrying the wrong token (when one is configured).
        if self.token and data["token"] != self.token:
            response.status_code = status.HTTP_401_UNAUTHORIZED
            return {"status": "failure", "reason": "Unauthorized request to the database."}

        with Session(engine) as session:
            update_data: SQLModel = self.models[data["cls_name"]].parse_raw(data["data"])
            primary_key = get_primary_key(update_data.__class__)
            # Column attribute used to locate the existing row by its primary key.
            identifier = getattr(update_data.__class__, primary_key, None)
            statement = select(update_data.__class__).where(identifier == getattr(update_data, primary_key))
            results = session.exec(statement)
            # `.one()` raises if zero or multiple rows match the primary key.
            result = results.one()
            # Copy changed fields onto the stored row, skipping the key and
            # SQLAlchemy's internal instance state.
            for k, v in vars(update_data).items():
                if k in ("id", "_sa_instance_state"):
                    continue
                if getattr(result, k) != v:
                    setattr(result, k, v)
            session.add(result)
            session.commit()
            session.refresh(result)
|
||||
|
||||
|
||||
class _Delete:
    """FastAPI endpoint callable deleting the row matching the payload's primary key."""

    def __init__(self, models, token):
        self.models = models
        self.token = token

    def __call__(self, data: Dict, response: Response):
        # Reject requests carrying the wrong token (when one is configured).
        if self.token and data["token"] != self.token:
            response.status_code = status.HTTP_401_UNAUTHORIZED
            return {"status": "failure", "reason": "Unauthorized request to the database."}

        with Session(engine) as session:
            update_data: SQLModel = self.models[data["cls_name"]].parse_raw(data["data"])
            primary_key = get_primary_key(update_data.__class__)
            # Column attribute used to locate the existing row by its primary key.
            identifier = getattr(update_data.__class__, primary_key, None)
            statement = select(update_data.__class__).where(identifier == getattr(update_data, primary_key))
            results = session.exec(statement)
            # `.one()` raises if zero or multiple rows match the primary key.
            result = results.one()
            session.delete(result)
            session.commit()
|
||||
|
||||
|
||||
def _create_database(db_filename: str, models: List[Type["SQLModel"]], echo: bool = False):
    """Open (or create) the SQLite database and its tables, binding the module-level engine."""
    global engine

    from sqlmodel import create_engine, SQLModel

    db_path = pathlib.Path(db_filename).resolve()
    engine = create_engine(f"sqlite:///{db_path}", echo=echo)

    logger.debug(f"Creating the following tables {models}")
    try:
        SQLModel.metadata.create_all(engine)
    except Exception as e:
        # Best effort: creation failures (e.g. tables already exist) are only logged.
        logger.debug(e)
|
|
@ -108,4 +108,8 @@ def _is_s3fs_available() -> bool:
|
|||
return module_available("s3fs")
|
||||
|
||||
|
||||
def _is_sqlmodel_available() -> bool:
    """Return whether the optional ``sqlmodel`` package can be imported."""
    return module_available("sqlmodel")
|
||||
|
||||
|
||||
_CLOUD_TEST_RUN = bool(os.getenv("CLOUD", False))
|
||||
|
|
|
@ -35,7 +35,6 @@ if TYPE_CHECKING:
|
|||
from lightning_app import LightningWork
|
||||
from lightning_app.core.queues import BaseQueue
|
||||
|
||||
|
||||
from lightning_app.utilities.app_helpers import Logger
|
||||
|
||||
logger = Logger(__name__)
|
||||
|
@ -99,7 +98,7 @@ class ProxyWorkRun:
|
|||
self._validate_call_args(args, kwargs)
|
||||
args, kwargs = self._process_call_args(args, kwargs)
|
||||
|
||||
call_hash = self.work._call_hash(self.work_run, args, kwargs)
|
||||
call_hash = self.work._call_hash(self.work_run, *self._convert_hashable(args, kwargs))
|
||||
entered = call_hash in self.work._calls
|
||||
returned = entered and "ret" in self.work._calls[call_hash]
|
||||
# TODO (tchaton): Handle spot instance retrieval differently from stopped work.
|
||||
|
@ -177,6 +176,27 @@ class ProxyWorkRun:
|
|||
|
||||
return apply_to_collection((args, kwargs), dtype=(Path, Drive), function=sanitize)
|
||||
|
||||
@staticmethod
|
||||
def _convert_hashable(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
|
||||
"""Processes all positional and keyword arguments before they get passed to the caller queue and sent to
|
||||
the LightningWork.
|
||||
|
||||
Currently, this method only applies sanitization to Hashable Objects.
|
||||
|
||||
Args:
|
||||
args: The tuple of positional arguments passed to the run method.
|
||||
kwargs: The dictionary of named arguments passed to the run method.
|
||||
|
||||
Returns:
|
||||
The positional and keyword arguments in the same order they were passed in.
|
||||
"""
|
||||
from lightning_app.utilities.types import Hashable
|
||||
|
||||
def sanitize(obj: Hashable) -> Union[Path, Dict]:
|
||||
return obj.to_dict()
|
||||
|
||||
return apply_to_collection((args, kwargs), dtype=Hashable, function=sanitize)
|
||||
|
||||
|
||||
class WorkStateObserver(Thread):
|
||||
"""This thread runs alongside LightningWork and periodically checks for state changes. If the state changed
|
||||
|
|
|
@ -1,7 +1,15 @@
|
|||
from typing import Union
|
||||
import typing as t
|
||||
|
||||
from typing_extensions import Protocol, runtime_checkable
|
||||
|
||||
from lightning_app import LightningFlow, LightningWork
|
||||
from lightning_app.structures import Dict, List
|
||||
|
||||
Component = Union[LightningFlow, LightningWork, Dict, List]
|
||||
Component = t.Union[LightningFlow, LightningWork, Dict, List]
|
||||
ComponentTuple = (LightningFlow, LightningWork, Dict, List)
|
||||
|
||||
|
||||
@runtime_checkable
class Hashable(Protocol):
    """Structural type for objects that can serialize themselves to a plain dict."""

    def to_dict(self) -> t.Dict[str, t.Any]:
        ...
|
||||
|
|
|
@ -0,0 +1,162 @@
|
|||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from time import sleep
|
||||
from typing import List, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from lightning_app import LightningApp, LightningFlow, LightningWork
|
||||
from lightning_app.components.database import Database, DatabaseClient
|
||||
from lightning_app.components.database.utilities import _GeneralModel, pydantic_column_type
|
||||
from lightning_app.runners import MultiProcessRuntime
|
||||
from lightning_app.utilities.imports import _is_sqlmodel_available
|
||||
|
||||
if _is_sqlmodel_available():
    from sqlalchemy import Column
    from sqlmodel import Field, SQLModel

    class Secret(SQLModel):
        # Plain (non-table) model stored as a JSON column inside TestConfig.
        name: str
        value: str

    class TestConfig(SQLModel, table=True):
        # Allow re-definition when this module is imported repeatedly across tests.
        __table_args__ = {"extend_existing": True}

        id: Optional[int] = Field(default=None, primary_key=True)
        name: str
        # Nested models require the `pydantic_column_type` JSON column helper.
        secrets: List[Secret] = Field(..., sa_column=Column(pydantic_column_type(List[Secret])))
|
||||
|
||||
|
||||
class Work(LightningWork):
    """Test helper: blocks until the database contains at least one row, then flags completion."""

    def __init__(self):
        super().__init__(parallel=True)
        self.done = False

    def run(self, client: DatabaseClient):
        # Poll until another component has inserted a row.
        rows = client.select_all()
        while len(rows) == 0:
            print(rows)
            sleep(0.1)
            rows = client.select_all()
        self.done = True
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _is_sqlmodel_available(), reason="sqlmodel is required for this test.")
def test_client_server():
    """End-to-end insert/update/delete through the Database component and its client."""

    # Start from a clean database file.
    database_path = Path("database.db").resolve()
    if database_path.exists():
        os.remove(database_path)

    secrets = [Secret(name="example", value="secret")]

    # Sanity-check the wire format produced for a model instance.
    general = _GeneralModel.from_obj(TestConfig(name="name", secrets=secrets), token="a")
    assert general.cls_name == "TestConfig"
    assert general.data == '{"id": null, "name": "name", "secrets": [{"name": "example", "value": "secret"}]}'

    class Flow(LightningFlow):
        def __init__(self):
            super().__init__()
            self._token = str(uuid4())
            self.db = Database(models=[TestConfig])
            self._client = None
            # State machine: None -> "update" -> "delete" across flow iterations.
            self.tracker = None
            self.work = Work()

        def run(self):
            self.db.run(token=self._token)

            if not self.db.alive():
                return

            if not self._client:
                self._client = DatabaseClient(model=TestConfig, db_url=self.db.url, token=self._token)

            assert self._client

            self.work.run(self._client)

            if self.tracker is None:
                self._client.insert(TestConfig(name="name", secrets=secrets))
                elem = self._client.select_all(TestConfig)[0]
                assert elem.name == "name"
                self.tracker = "update"
                # Nested models must round-trip through the JSON column.
                assert isinstance(elem.secrets[0], Secret)
                assert elem.secrets[0].name == "example"
                assert elem.secrets[0].value == "secret"

            elif self.tracker == "update":
                elem = self._client.select_all(TestConfig)[0]
                elem.name = "new_name"
                self._client.update(elem)

                elem = self._client.select_all(TestConfig)[0]
                assert elem.name == "new_name"
                self.tracker = "delete"

            elif self.tracker == "delete" and self.work.done:
                self.work.stop()

                elem = self._client.select_all(TestConfig)[0]
                elem = self._client.delete(elem)

                assert not self._client.select_all(TestConfig)
                self._client.insert(TestConfig(name="name", secrets=secrets))

                assert self._client.select_all(TestConfig)
                self._exit()

    app = LightningApp(Flow())
    MultiProcessRuntime(app, start_server=False).dispatch()

    # Clean up the database file created by the run.
    database_path = Path("database.db").resolve()
    if database_path.exists():
        os.remove(database_path)
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="currently not supported for windows.")
@pytest.mark.skipif(not _is_sqlmodel_available(), reason="sqlmodel is required for this test.")
def test_work_database_restart():
    """The database file persists through the Drive across two app runs."""

    # Unique filename per test run; NOTE(review): `id` shadows the builtin.
    id = str(uuid4()).split("-")[0]

    class Flow(LightningFlow):
        def __init__(self, restart=False):
            super().__init__()
            self.db = Database(db_filename=id, models=[TestConfig])
            self._client = None
            self.restart = restart

        def run(self):
            self.db.run()

            if not self.db.alive():
                return
            elif not self._client:
                self._client = DatabaseClient(self.db.db_url, None, model=TestConfig)

            if not self.restart:
                # First run: write one row, then shut down.
                self._client.insert(TestConfig(name="echo", secrets=[Secret(name="example", value="secret")]))
                self._exit()
            else:
                # Second run: the row must have been restored from the Drive.
                assert os.path.exists(id)
                assert len(self._client.select_all()) == 1
                self._exit()

    app = LightningApp(Flow())
    MultiProcessRuntime(app).dispatch()

    # Note: Waiting for SIGTERM signal to be handled
    sleep(2)

    os.remove(id)

    app = LightningApp(Flow(restart=True))
    MultiProcessRuntime(app).dispatch()

    # Note: Waiting for SIGTERM signal to be handled
    sleep(2)

    os.remove(id)
|
|
@ -423,7 +423,7 @@ class EmptyFlow(LightningFlow):
|
|||
"sleep_time, expect",
|
||||
[
|
||||
(1, 0),
|
||||
(0, 100),
|
||||
(0, 20),
|
||||
],
|
||||
)
|
||||
def test_lightning_app_aggregation_speed(default_timeout, queue_type_cls: BaseQueue, sleep_time, expect):
|
||||
|
|
|
@ -12,8 +12,8 @@ from lightning_app.testing.helpers import run_script, RunIf
|
|||
@pytest.mark.parametrize(
|
||||
"file",
|
||||
[
|
||||
pytest.param("component_popen.py"),
|
||||
pytest.param("component_tracer.py"),
|
||||
pytest.param("component_popen.py"),
|
||||
],
|
||||
)
|
||||
def test_scripts(file):
|
||||
|
|
Loading…
Reference in New Issue