Add Database Component (#14995)

This commit is contained in:
thomas chaton 2022-10-19 20:52:12 +01:00 committed by GitHub
parent d0b092fda8
commit 979d728563
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 710 additions and 7 deletions

View File

@ -35,6 +35,7 @@ exclude = [
"src/lightning_app/cli/pl-app-template",
"src/lightning_app/cli/react-ui-template",
"src/lightning_app/cli/app-template",
"src/lightning_app/components/database",
]
install_types = "True"
non_interactive = "True"

View File

@ -10,4 +10,5 @@ trio<0.22.0
pympler
psutil
setuptools<=59.5.0
sqlmodel
requests-mock

View File

@ -13,8 +13,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a `--secret` option to CLI to allow binding secrets to app environment variables when running in the cloud ([#14612](https://github.com/Lightning-AI/lightning/pull/14612))
- Added support for running the works without cloud compute in the default container ([#14819](https://github.com/Lightning-AI/lightning/pull/14819))
- Added an HTTPQueue as an optional replacement for the default redis queue ([#14978](https://github.com/Lightning-AI/lightning/pull/14978))
- Added a try / catch mechanism around request processing to avoid killing the flow ([#15187](https://github.com/Lightning-AI/lightning/pull/15187))
- Added support for adding descriptions to commands either through a docstring or the `DESCRIPTION` attribute ([#15193](https://github.com/Lightning-AI/lightning/pull/15193))
- Added a Database Component ([#14995](https://github.com/Lightning-AI/lightning/pull/14995))
### Fixed

View File

@ -0,0 +1,4 @@
from lightning_app.components.database.client import DatabaseClient
from lightning_app.components.database.server import Database
__all__ = ["Database", "DatabaseClient"]

View File

@ -0,0 +1,78 @@
from typing import Any, Dict, List, Optional, Type, TypeVar
import requests
from requests import Session
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
from lightning_app.components.database.utilities import _GeneralModel
_CONNECTION_RETRY_TOTAL = 5
_CONNECTION_RETRY_BACKOFF_FACTOR = 1
def _configure_session() -> Session:
    """Build a ``requests`` session with a retry policy for GET and POST calls.

    The generous exponential-backoff strategy lets the client wait for the
    application server to finish starting up before giving up.
    """
    retries = Retry(
        # wait time between retries increases exponentially according to: backoff_factor * (2 ** (retry - 1))
        total=_CONNECTION_RETRY_TOTAL,
        backoff_factor=_CONNECTION_RETRY_BACKOFF_FACTOR,
        status_forcelist=[429, 500, 502, 503, 504],
    )
    session = requests.Session()
    retry_adapter = HTTPAdapter(max_retries=retries)
    for prefix in ("https://", "http://"):
        session.mount(prefix, retry_adapter)
    return session
T = TypeVar("T")


class DatabaseClient:
    """Lightweight HTTP client for the ``Database`` LightningWork server.

    Arguments:
        db_url: Base URL of the database server.
        token: Optional token protecting the database endpoints.
        model: Default SQLModel class used when a method is called without an
            explicit ``model`` argument.
    """

    def __init__(self, db_url: str, token: Optional[str] = None, model: Optional[T] = None) -> None:
        self.db_url = db_url
        self.model = model
        self.token = token or ""
        # Created lazily so the client can be constructed (and pickled) cheaply.
        self._session = None

    @staticmethod
    def _validate_response(resp) -> None:
        """Raise on non-200 responses.

        An explicit raise (instead of ``assert``) ensures failures still
        surface when Python runs with ``-O``, which strips assertions.
        """
        if resp.status_code != 200:
            raise RuntimeError(f"Database request failed with status {resp.status_code}.")

    def select_all(self, model: Optional[Type[T]] = None) -> List[T]:
        """Return every row of ``model`` (or the default model) as instances."""
        cls = model if model else self.model
        if cls is None:
            raise ValueError("Provide a `model` argument or construct the client with a default model.")
        resp = self.session.post(
            self.db_url + "/select_all/", data=_GeneralModel.from_cls(cls, token=self.token).json()
        )
        self._validate_response(resp)
        return [cls(**data) for data in resp.json()]

    def insert(self, model: T) -> None:
        """Insert ``model`` as a new row."""
        resp = self.session.post(
            self.db_url + "/insert/",
            data=_GeneralModel.from_obj(model, token=self.token).json(),
        )
        self._validate_response(resp)

    def update(self, model: T) -> None:
        """Update the stored row matching ``model``'s primary key with its current values."""
        resp = self.session.post(
            self.db_url + "/update/",
            data=_GeneralModel.from_obj(model, token=self.token).json(),
        )
        self._validate_response(resp)

    def delete(self, model: T) -> None:
        """Delete the stored row matching ``model``'s primary key."""
        resp = self.session.post(
            self.db_url + "/delete/",
            data=_GeneralModel.from_obj(model, token=self.token).json(),
        )
        self._validate_response(resp)

    @property
    def session(self):
        # Built on first use; includes a retry strategy so early requests
        # survive the server still starting up.
        if self._session is None:
            self._session = _configure_session()
        return self._session

    def to_dict(self) -> Dict[str, Any]:
        """Serializable summary of the client configuration."""
        return {"db_url": self.db_url, "model": self.model.__name__ if self.model else None}

View File

@ -0,0 +1,176 @@
import asyncio
import os
import sys
from typing import List, Optional, Type, Union
import uvicorn
from fastapi import FastAPI
from uvicorn import run
from lightning import BuildConfig, LightningWork
from lightning_app.components.database.utilities import _create_database, _Delete, _Insert, _SelectAll, _Update
from lightning_app.storage import Drive
from lightning_app.utilities.imports import _is_sqlmodel_available
if _is_sqlmodel_available():
from sqlmodel import SQLModel
# Required to avoid Uvicorn Server overriding Lightning App signal handlers.
# Discussions: https://github.com/encode/uvicorn/discussions/1708
class DatabaseUvicornServer(uvicorn.Server):
    # NOTE(review): not referenced within this class — presumably a hook for
    # callers to observe startup; confirm before removing.
    has_started_queue = None

    def run(self, sockets=None):
        """Serve on the current event loop and block forever.

        Schedules ``serve`` as a task instead of letting uvicorn drive (and
        tear down) its own loop, so the hosting process stays in control.
        """
        self.config.setup_event_loop()
        loop = asyncio.get_event_loop()
        asyncio.ensure_future(self.serve(sockets=sockets))
        loop.run_forever()

    def install_signal_handlers(self):
        """Ignore Uvicorn Signal Handlers."""
class Database(LightningWork):
    def __init__(
        self,
        models: Union[Type["SQLModel"], List[Type["SQLModel"]]],
        db_filename: str = "database.db",
        debug: bool = False,
    ) -> None:
        """The Database Component enables to interact with an SQLite database to store some structured information
        about your application.

        The provided models are SQLModel tables.

        Arguments:
            models: A SQLModel or a list of SQLModels table to be added to the database.
            db_filename: The name of the SQLite database.
            debug: Whether to run the database in debug mode.

        Example::

            from typing import List
            from sqlmodel import SQLModel, Field
            from uuid import uuid4

            from lightning import LightningFlow, LightningApp
            from lightning_app.components.database import Database, DatabaseClient

            class CounterModel(SQLModel, table=True):
                __table_args__ = {"extend_existing": True}

                id: int = Field(default=None, primary_key=True)
                count: int


            class Flow(LightningFlow):

                def __init__(self):
                    super().__init__()
                    self._private_token = uuid4().hex
                    self.db = Database(models=[CounterModel])
                    self._client = None
                    self.counter = 0

                def run(self):
                    self.db.run(token=self._private_token)

                    if not self.db.alive():
                        return

                    if self.counter == 0:
                        self._client = DatabaseClient(
                            model=CounterModel,
                            db_url=self.db.url,
                            token=self._private_token,
                        )

                    rows = self._client.select_all()
                    print(f"{self.counter}: {rows}")
                    if not rows:
                        self._client.insert(CounterModel(count=0))
                    else:
                        row: CounterModel = rows[0]
                        row.count += 1
                        self._client.update(row)

                    if self.counter >= 100:
                        row: CounterModel = rows[0]
                        self._client.delete(row)
                        self._exit()

                    self.counter += 1


            app = LightningApp(Flow())

        If you want to use nested SQLModels, we provide a utility to do so as follows:

        Example::

            from typing import List
            from sqlmodel import SQLModel, Field
            from sqlalchemy import Column

            from lightning_app.components.database.utilities import pydantic_column_type

            class KeyValuePair(SQLModel):
                name: str
                value: str

            class CounterModel(SQLModel, table=True):
                __table_args__ = {"extend_existing": True}

                name: int = Field(default=None, primary_key=True)

                # RIGHT THERE ! You need to use Field and Column with the `pydantic_column_type` utility.
                kv: List[KeyValuePair] = Field(..., sa_column=Column(pydantic_column_type(List[KeyValuePair])))
        """
        super().__init__(parallel=True, cloud_build_config=BuildConfig(["sqlmodel"]))
        self.db_filename = db_filename
        self.debug = debug
        self._models = models if isinstance(models, list) else [models]
        # Bound lazily in ``run`` once the work is actually executing.
        self.drive = None

    def run(self, token: Optional[str] = None) -> None:
        """Start the FastAPI server exposing the database endpoints.

        Arguments:
            token: Token used to protect the database access. Ensure you don't expose it through the App State.
        """
        self.drive = Drive("lit://database")
        # On restart, restore the previously persisted SQLite file from the Drive.
        if self.drive.list(component_name=self.name):
            self.drive.get(self.db_filename)
            print("Retrieved the database from Drive.")

        app = FastAPI()

        _create_database(self.db_filename, self._models, self.debug)
        models = {m.__name__: m for m in self._models}
        app.post("/select_all/")(_SelectAll(models, token))
        app.post("/insert/")(_Insert(models, token))
        app.post("/update/")(_Update(models, token))
        app.post("/delete/")(_Delete(models, token))

        # Swap in the server subclass that leaves Lightning's signal handlers alone.
        sys.modules["uvicorn.main"].Server = DatabaseUvicornServer

        run(app, host=self.host, port=self.port, log_level="error")

    def alive(self) -> bool:
        """Hack: Returns whether the server is alive."""
        return self.db_url != ""

    @property
    def db_url(self) -> Optional[str]:
        # Locally (no app-state URL in the environment) the work's own URL is reachable.
        use_localhost = "LIGHTNING_APP_STATE_URL" not in os.environ
        if use_localhost:
            return self.url
        if self.internal_ip != "":
            return f"http://{self.internal_ip}:{self.port}"
        # Empty string until the internal IP is assigned — makes ``alive()`` False.
        return self.internal_ip

    def on_exit(self):
        # Persist the SQLite file so it can be restored on the next start.
        # NOTE(review): assumes ``run`` was called at least once (``self.drive``
        # is None otherwise) — confirm the framework guarantees that ordering.
        self.drive.put(self.db_filename)
        print("Stored the database to the Drive.")

View File

@ -0,0 +1,248 @@
import functools
import json
import pathlib
from typing import Any, Dict, Generic, List, Type, TypeVar
from fastapi import Response, status
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel, parse_obj_as
from pydantic.main import ModelMetaclass
from lightning_app.utilities.app_helpers import Logger
from lightning_app.utilities.imports import _is_sqlmodel_available
if _is_sqlmodel_available():
from sqlalchemy.inspection import inspect as sqlalchemy_inspect
from sqlmodel import JSON, select, Session, SQLModel, TypeDecorator
logger = Logger(__name__)
engine = None
T = TypeVar("T")
# Taken from https://github.com/tiangolo/sqlmodel/issues/63#issuecomment-1081555082
def pydantic_column_type(pydantic_type: Any) -> Any:
    """This function enables to support JSON types with SQLModel.

    Returns a SQLAlchemy ``TypeDecorator`` subclass that serializes the given
    pydantic type to JSON on write and parses it back on read.

    Example::

        from sqlmodel import SQLModel
        from sqlalchemy import Column

        class TrialConfig(SQLModel, table=False):
            ...
            params: Dict[str, float] = Field(sa_column=Column(pydantic_column_type(Dict[str, float])))
    """

    class PydanticJSONType(TypeDecorator, Generic[T]):
        impl = JSON()

        def __init__(
            self,
            json_encoder=json,
        ):
            self.json_encoder = json_encoder
            super().__init__()

        def bind_processor(self, dialect):
            # Serializes Python values into JSON before they reach the DB driver.
            impl_processor = self.impl.bind_processor(dialect)
            dumps = self.json_encoder.dumps
            if impl_processor:

                def process(value: T):
                    if value is not None:
                        if isinstance(pydantic_type, ModelMetaclass):
                            # This allows to assign non-InDB models and if they're
                            # compatible, they're directly parsed into the InDB
                            # representation, thus hiding the implementation in the
                            # background. However, the InDB model will still be returned
                            value_to_dump = pydantic_type.from_orm(value)
                        else:
                            value_to_dump = value
                        value = jsonable_encoder(value_to_dump)
                    return impl_processor(value)

            else:

                def process(value):
                    if isinstance(pydantic_type, ModelMetaclass):
                        # This allows to assign non-InDB models and if they're
                        # compatible, they're directly parsed into the InDB
                        # representation, thus hiding the implementation in the
                        # background. However, the InDB model will still be returned
                        value_to_dump = pydantic_type.from_orm(value)
                    else:
                        value_to_dump = value
                    value = dumps(jsonable_encoder(value_to_dump))
                    return value

            return process

        def result_processor(self, dialect, coltype) -> T:
            # Parses JSON coming back from the DB into the declared pydantic type.
            impl_processor = self.impl.result_processor(dialect, coltype)
            if impl_processor:

                def process(value):
                    value = impl_processor(value)
                    if value is None:
                        return None
                    data = value
                    # Explicitly use the generic directly, not type(T)
                    full_obj = parse_obj_as(pydantic_type, data)
                    return full_obj

            else:

                def process(value):
                    if value is None:
                        return None
                    # Explicitly use the generic directly, not type(T)
                    full_obj = parse_obj_as(pydantic_type, value)
                    return full_obj

            return process

        def compare_values(self, x, y):
            return x == y

    return PydanticJSONType
@functools.lru_cache
def get_primary_key(model_type: Type[SQLModel]) -> str:
    """Return the name of the single primary-key column of ``model_type``."""
    keys = sqlalchemy_inspect(model_type).primary_key
    if len(keys) == 1:
        return keys[0].name
    raise ValueError(f"The model {model_type.__name__} should have a single primary key field.")
class _GeneralModel(BaseModel):
    """Envelope shipping a model class name plus an optional serialized row over HTTP."""

    cls_name: str
    data: str
    token: str

    def convert_to_model(self, models: Dict[str, BaseModel]):
        """Re-hydrate ``data`` into the model class registered under ``cls_name``."""
        return models[self.cls_name].parse_raw(self.data)

    @classmethod
    def from_obj(cls, obj, token):
        """Wrap a model instance, serializing its content to JSON."""
        return cls(cls_name=obj.__class__.__name__, data=obj.json(), token=token)

    @classmethod
    def from_cls(cls, obj_cls, token):
        """Wrap a model class (no row payload), e.g. for select-all requests."""
        return cls(cls_name=obj_cls.__name__, data="", token=token)
class _SelectAll:
    """FastAPI endpoint callable returning every row of the requested model."""

    def __init__(self, models, token):
        # NOTE: intentionally no logging here — ``token`` protects the database
        # endpoints and must not be printed (the previous debug print leaked it).
        self.models = models
        self.token = token

    def __call__(self, data: Dict, response: Response):
        # Reject requests whose token doesn't match the configured one.
        if self.token and data["token"] != self.token:
            response.status_code = status.HTTP_401_UNAUTHORIZED
            return {"status": "failure", "reason": "Unauthorized request to the database."}
        with Session(engine) as session:
            cls: Type[SQLModel] = self.models[data["cls_name"]]
            statement = select(cls)
            results = session.exec(statement)
            return results.all()
class _Insert:
    """FastAPI endpoint callable adding one row to the database."""

    def __init__(self, models, token):
        self.models = models
        self.token = token

    def __call__(self, data: Dict, response: Response):
        # Reject requests whose token doesn't match the configured one.
        if self.token and data["token"] != self.token:
            response.status_code = status.HTTP_401_UNAUTHORIZED
            return {"status": "failure", "reason": "Unauthorized request to the database."}
        model_cls = self.models[data["cls_name"]]
        row: SQLModel = model_cls.parse_raw(data["data"])
        with Session(engine) as session:
            session.add(row)
            session.commit()
            # Refresh so DB-assigned fields (e.g. the primary key) are returned.
            session.refresh(row)
            return row
class _Update:
    """FastAPI endpoint callable updating the row matching the payload's primary key."""

    def __init__(self, models, token):
        self.models = models
        self.token = token

    def __call__(self, data: Dict, response: Response):
        # Reject requests whose token doesn't match the configured one.
        if self.token and data["token"] != self.token:
            response.status_code = status.HTTP_401_UNAUTHORIZED
            return {"status": "failure", "reason": "Unauthorized request to the database."}
        with Session(engine) as session:
            update_data: SQLModel = self.models[data["cls_name"]].parse_raw(data["data"])
            primary_key = get_primary_key(update_data.__class__)
            # Column attribute of the primary key, used to match the stored row.
            identifier = getattr(update_data.__class__, primary_key, None)
            statement = select(update_data.__class__).where(identifier == getattr(update_data, primary_key))
            results = session.exec(statement)
            # ``one()`` raises if zero or multiple rows match the primary key.
            result = results.one()
            # Copy the changed fields onto the fetched row, skipping the primary
            # key and SQLAlchemy's internal bookkeeping attribute.
            for k, v in vars(update_data).items():
                if k in ("id", "_sa_instance_state"):
                    continue
                if getattr(result, k) != v:
                    setattr(result, k, v)
            session.add(result)
            session.commit()
            session.refresh(result)
class _Delete:
    """FastAPI endpoint callable removing the row matching the payload's primary key."""

    def __init__(self, models, token):
        self.models = models
        self.token = token

    def __call__(self, data: Dict, response: Response):
        # Reject requests whose token doesn't match the configured one.
        if self.token and data["token"] != self.token:
            response.status_code = status.HTTP_401_UNAUTHORIZED
            return {"status": "failure", "reason": "Unauthorized request to the database."}
        with Session(engine) as session:
            payload: SQLModel = self.models[data["cls_name"]].parse_raw(data["data"])
            model_cls = payload.__class__
            pk_name = get_primary_key(model_cls)
            pk_column = getattr(model_cls, pk_name, None)
            query = select(model_cls).where(pk_column == getattr(payload, pk_name))
            row = session.exec(query).one()
            session.delete(row)
            session.commit()
def _create_database(db_filename: str, models: List[Type["SQLModel"]], echo: bool = False):
    """Create the SQLite engine and the tables for ``models``.

    The engine is stored in the module-level ``engine`` variable so the
    endpoint callables can open sessions against it.
    """
    global engine

    from sqlmodel import create_engine, SQLModel

    db_path = pathlib.Path(db_filename).resolve()
    engine = create_engine(f"sqlite:///{db_path}", echo=echo)

    logger.debug(f"Creating the following tables {models}")
    try:
        SQLModel.metadata.create_all(engine)
    except Exception as e:
        # Best-effort: creation can fail if the tables already exist.
        logger.debug(e)

View File

@ -108,4 +108,8 @@ def _is_s3fs_available() -> bool:
return module_available("s3fs")
def _is_sqlmodel_available() -> bool:
    """Return whether the optional ``sqlmodel`` package is importable."""
    return module_available("sqlmodel")
_CLOUD_TEST_RUN = bool(os.getenv("CLOUD", False))

View File

@ -35,7 +35,6 @@ if TYPE_CHECKING:
from lightning_app import LightningWork
from lightning_app.core.queues import BaseQueue
from lightning_app.utilities.app_helpers import Logger
logger = Logger(__name__)
@ -99,7 +98,7 @@ class ProxyWorkRun:
self._validate_call_args(args, kwargs)
args, kwargs = self._process_call_args(args, kwargs)
call_hash = self.work._call_hash(self.work_run, args, kwargs)
call_hash = self.work._call_hash(self.work_run, *self._convert_hashable(args, kwargs))
entered = call_hash in self.work._calls
returned = entered and "ret" in self.work._calls[call_hash]
# TODO (tchaton): Handle spot instance retrieval differently from stopped work.
@ -177,6 +176,27 @@ class ProxyWorkRun:
return apply_to_collection((args, kwargs), dtype=(Path, Drive), function=sanitize)
    @staticmethod
    def _convert_hashable(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
        """Convert ``Hashable`` objects among the run arguments into plain dicts.

        This runs before the arguments are sent through the caller queue to the
        LightningWork, so only serializable representations travel across.

        Args:
            args: The tuple of positional arguments passed to the run method.
            kwargs: The dictionary of named arguments passed to the run method.

        Returns:
            The positional and keyword arguments in the same order they were passed in.
        """
        from lightning_app.utilities.types import Hashable

        def sanitize(obj: Hashable) -> Dict:
            # ``to_dict`` is the method the Hashable protocol guarantees.
            return obj.to_dict()

        return apply_to_collection((args, kwargs), dtype=Hashable, function=sanitize)
class WorkStateObserver(Thread):
"""This thread runs alongside LightningWork and periodically checks for state changes. If the state changed

View File

@ -1,7 +1,15 @@
from typing import Union
import typing as t
from typing_extensions import Protocol, runtime_checkable
from lightning_app import LightningFlow, LightningWork
from lightning_app.structures import Dict, List
Component = Union[LightningFlow, LightningWork, Dict, List]
Component = t.Union[LightningFlow, LightningWork, Dict, List]
ComponentTuple = (LightningFlow, LightningWork, Dict, List)
@runtime_checkable
class Hashable(Protocol):
    """Structural type for objects that can serialize themselves to a dict via ``to_dict``."""

    def to_dict(self) -> t.Dict[str, t.Any]:
        ...

View File

@ -0,0 +1,162 @@
import os
import sys
from pathlib import Path
from time import sleep
from typing import List, Optional
from uuid import uuid4
import pytest
from lightning_app import LightningApp, LightningFlow, LightningWork
from lightning_app.components.database import Database, DatabaseClient
from lightning_app.components.database.utilities import _GeneralModel, pydantic_column_type
from lightning_app.runners import MultiProcessRuntime
from lightning_app.utilities.imports import _is_sqlmodel_available
if _is_sqlmodel_available():
    from sqlalchemy import Column
    from sqlmodel import Field, SQLModel

    class Secret(SQLModel):
        # Non-table model: stored as JSON inside ``TestConfig.secrets``.
        name: str
        value: str

    class TestConfig(SQLModel, table=True):
        # ``extend_existing`` allows re-declaring the table across test runs.
        __table_args__ = {"extend_existing": True}

        id: Optional[int] = Field(default=None, primary_key=True)
        name: str
        # Nested models are persisted through the JSON column type helper.
        secrets: List[Secret] = Field(..., sa_column=Column(pydantic_column_type(List[Secret])))
class Work(LightningWork):
    """Busy-polls the database through the client until at least one row exists."""

    def __init__(self):
        super().__init__(parallel=True)
        # Flipped once a row has been observed; checked by the driving Flow.
        self.done = False

    def run(self, client: DatabaseClient):
        rows = client.select_all()
        while len(rows) == 0:
            print(rows)
            sleep(0.1)
            rows = client.select_all()
        self.done = True
@pytest.mark.skipif(not _is_sqlmodel_available(), reason="sqlmodel is required for this test.")
def test_client_server():
    """End-to-end: a Flow drives insert, update and delete through the DatabaseClient."""
    database_path = Path("database.db").resolve()
    if database_path.exists():
        os.remove(database_path)

    secrets = [Secret(name="example", value="secret")]

    # Sanity-check the wire format used between the client and the server.
    general = _GeneralModel.from_obj(TestConfig(name="name", secrets=secrets), token="a")
    assert general.cls_name == "TestConfig"
    assert general.data == '{"id": null, "name": "name", "secrets": [{"name": "example", "value": "secret"}]}'

    class Flow(LightningFlow):
        def __init__(self):
            super().__init__()
            self._token = str(uuid4())
            self.db = Database(models=[TestConfig])
            self._client = None
            # State machine: None -> "update" -> "delete".
            self.tracker = None
            self.work = Work()

        def run(self):
            self.db.run(token=self._token)

            if not self.db.alive():
                return

            if not self._client:
                self._client = DatabaseClient(model=TestConfig, db_url=self.db.url, token=self._token)

            assert self._client

            self.work.run(self._client)

            if self.tracker is None:
                self._client.insert(TestConfig(name="name", secrets=secrets))
                elem = self._client.select_all(TestConfig)[0]
                assert elem.name == "name"
                self.tracker = "update"
                # Nested models must be rebuilt from their JSON column.
                assert isinstance(elem.secrets[0], Secret)
                assert elem.secrets[0].name == "example"
                assert elem.secrets[0].value == "secret"

            elif self.tracker == "update":
                elem = self._client.select_all(TestConfig)[0]
                elem.name = "new_name"
                self._client.update(elem)

                elem = self._client.select_all(TestConfig)[0]
                assert elem.name == "new_name"
                self.tracker = "delete"

            elif self.tracker == "delete" and self.work.done:
                self.work.stop()

                elem = self._client.select_all(TestConfig)[0]
                elem = self._client.delete(elem)

                assert not self._client.select_all(TestConfig)
                self._client.insert(TestConfig(name="name", secrets=secrets))

                assert self._client.select_all(TestConfig)
                self._exit()

    app = LightningApp(Flow())
    MultiProcessRuntime(app, start_server=False).dispatch()

    database_path = Path("database.db").resolve()
    if database_path.exists():
        os.remove(database_path)
@pytest.mark.skipif(sys.platform == "win32", reason="currently not supported for windows.")
@pytest.mark.skipif(not _is_sqlmodel_available(), reason="sqlmodel is required for this test.")
def test_work_database_restart():
    """The database file is persisted to the Drive and restored after a restart."""
    # Unique database filename so concurrent test runs don't collide.
    id = str(uuid4()).split("-")[0]

    class Flow(LightningFlow):
        def __init__(self, restart=False):
            super().__init__()
            self.db = Database(db_filename=id, models=[TestConfig])
            self._client = None
            self.restart = restart

        def run(self):
            self.db.run()
            if not self.db.alive():
                return
            elif not self._client:
                self._client = DatabaseClient(self.db.db_url, None, model=TestConfig)

            if not self.restart:
                self._client.insert(TestConfig(name="echo", secrets=[Secret(name="example", value="secret")]))
                self._exit()
            else:
                # The row inserted before the restart must have been restored.
                assert os.path.exists(id)
                assert len(self._client.select_all()) == 1
                self._exit()

    app = LightningApp(Flow())
    MultiProcessRuntime(app).dispatch()

    # Note: Waiting for SIGTERM signal to be handled
    sleep(2)
    os.remove(id)

    app = LightningApp(Flow(restart=True))
    MultiProcessRuntime(app).dispatch()

    # Note: Waiting for SIGTERM signal to be handled
    sleep(2)
    os.remove(id)

View File

@ -423,7 +423,7 @@ class EmptyFlow(LightningFlow):
"sleep_time, expect",
[
(1, 0),
(0, 100),
(0, 20),
],
)
def test_lightning_app_aggregation_speed(default_timeout, queue_type_cls: BaseQueue, sleep_time, expect):

View File

@ -12,8 +12,8 @@ from lightning_app.testing.helpers import run_script, RunIf
@pytest.mark.parametrize(
"file",
[
pytest.param("component_popen.py"),
pytest.param("component_tracer.py"),
pytest.param("component_popen.py"),
],
)
def test_scripts(file):