MySQL driver support (#366)

* initial support for MySQL

* Work towards testing both MySQL and Postgres drivers

* Linting

* STARLETTE_TEST_DATABASE -> STARLETTE_TEST_DATABASES

* Parameterize the database tests

* Include the MySQL Database URL

* MySQL tests should create the database if it doesn't yet exist

* Explicit port for MySQL database URL

* Debugging on Travis

* Pass 'args' as a single argument

* Fix query compilation for mysql

* Lookup record values for MySQL

* Coerce MySQL results to correct python_type

* cursor.fetchrow should be cursor.fetchone

* Fix call to cursor.fetchone()

* Fix for fetchval()

* Debugging

* Nested transaction for MySQL should use SAVEPOINTS

* Drop savepoint implementation

* Fix SAVEPOINT implementation for MySQL

* Coverage for MySQL

* Linting

* Tweak defaults for MySQL connections
Commit 3270762c16 (parent 82df72bdd7)
Tom Christie, 2019-01-30 14:21:58 +00:00, committed by GitHub
12 changed files with 363 additions and 113 deletions


@ -7,10 +7,11 @@ python:
- "3.7-dev"
env:
- STARLETTE_TEST_DATABASE=postgresql://localhost/starlette
- STARLETTE_TEST_DATABASES="postgresql://localhost/starlette, mysql://localhost/starlette"
services:
- postgresql
- mysql
install:
- pip install -U -r requirements.txt


@ -40,10 +40,13 @@ We provide a stand-alone **test script** to run tests in a reliable manner. Run
./scripts/test
```
By default, tests involving a database are excluded. To include them, set the `STARLETTE_TEST_DATABASE` environment variable to the URL of a PostgreSQL database, e.g.:
By default, tests involving a database are excluded. To include them, set the `STARLETTE_TEST_DATABASES` environment variable. This should be a comma separated string of database URLs.
```bash
export STARLETTE_TEST_DATABASE="postgresql://localhost/starlette"
# Any of the following are valid for running the database tests...
export STARLETTE_TEST_DATABASES="postgresql://localhost/starlette"
export STARLETTE_TEST_DATABASES="mysql://localhost/starlette_test"
export STARLETTE_TEST_DATABASES="postgresql://localhost/starlette, mysql://localhost/starlette_test"
```
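As a point of reference, the test suite later in this diff consumes this variable through Starlette's `CommaSeparatedStrings` datastructure. A minimal sketch of that parsing (illustration only):

```python
# Sketch only: mirrors how tests/test_database.py below reads the variable.
import os

from starlette.datastructures import CommaSeparatedStrings, DatabaseURL

urls = CommaSeparatedStrings(os.environ["STARLETTE_TEST_DATABASES"])
for url in urls:
    print(DatabaseURL(url).dialect)  # e.g. "postgresql", then "mysql"
```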
## Linting


@ -1,7 +1,10 @@
Starlette includes optional database support. There is currently only a driver
for Postgres databases, but MySQL and SQLite support is planned.
for Postgres and MySQL databases, but SQLite support is planned.
Enabling the built-in database support requires `sqlalchemy`, and an appropriate database driver. Currently this means `asyncpg` is a requirement.
Enabling the built-in database support requires `sqlalchemy`, and an appropriate database driver.
PostgreSQL: requires `asyncpg`
MySQL: requires `aiomysql`
The database support is completely optional - you can either include the middleware or not, or you can build alternative kinds of backends instead. It does not
include support for an ORM, but it does support using queries built using
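For illustration, a minimal sketch of the documented setup, wiring the middleware to a MySQL URL. The URL is a placeholder; the keyword arguments are the ones exercised by the tests later in this diff:

```python
# Sketch only: enabling the optional database support with the new MySQL driver.
# Requires `sqlalchemy` plus `aiomysql` (or `asyncpg` for a postgresql:// URL).
from starlette.applications import Starlette
from starlette.middleware.database import DatabaseMiddleware

app = Starlette()
app.add_middleware(
    DatabaseMiddleware,
    database_url="mysql://localhost/starlette",  # placeholder URL
    rollback_on_shutdown=False,  # the tests below set this to True for isolation
)
# Route handlers can then use request.database.fetchall(...), .execute(...), etc.
```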


@ -8,8 +8,13 @@ pyyaml
requests
ujson
# Database backends
# Async database drivers
asyncpg
aiomysql
# Sync database drivers for standard tooling around setup/teardown/migrations.
psycopg2-binary
pymysql
# Testing
autoflake
@ -20,7 +25,6 @@ mypy
pytest
pytest-cov
sqlalchemy
psycopg2-binary
# Documentation
mkdocs


@ -8,11 +8,11 @@ fi
export VERSION_SCRIPT="import sys; print('%s.%s' % sys.version_info[0:2])"
export PYTHON_VERSION=`python -c "$VERSION_SCRIPT"`
if [ -z "$STARLETTE_TEST_DATABASE" ] ; then
echo "Variable STARLETTE_TEST_DATABASE is unset. Excluding database tests."
if [ -z "$STARLETTE_TEST_DATABASES" ] ; then
echo "Variable STARLETTE_TEST_DATABASES is unset. Excluding database tests."
export IGNORE_MODULES="--ignore tests/test_database.py --cov-config tests/.ignore_database_tests"
else
echo "Variable STARLETTE_TEST_DATABASE is set"
echo "Variable STARLETTE_TEST_DATABASES is set"
export IGNORE_MODULES=""
fi


@ -1,5 +1,4 @@
from starlette.database.core import (
    compile,
    transaction,
    DatabaseBackend,
    DatabaseSession,
@ -7,10 +6,4 @@ from starlette.database.core import (
)
__all__ = [
    "compile",
    "transaction",
    "DatabaseBackend",
    "DatabaseSession",
    "DatabaseTransaction",
]
__all__ = ["transaction", "DatabaseBackend", "DatabaseSession", "DatabaseTransaction"]


@ -1,34 +1,12 @@
import functools
import logging
import typing
from types import TracebackType
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.sql import ClauseElement
from starlette.requests import Request
from starlette.responses import Response
logger = logging.getLogger("starlette.database")
def compile(query: ClauseElement, dialect: Dialect) -> typing.Tuple[str, list]:
    # query = execute_defaults(query) # default values for Insert/Update
    compiled = query.compile(dialect=dialect)
    compiled_params = sorted(compiled.params.items())
    mapping = {key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1)}
    compiled_query = compiled.string % mapping
    processors = compiled._bind_processors
    args = [
        processors[key](val) if key in processors else val
        for key, val in compiled_params
    ]
    logger.debug(compiled_query)
    return compiled_query, args
def transaction(func: typing.Callable) -> typing.Callable:
    @functools.wraps(func)

starlette/database/mysql.py (new file, 214 lines added)

@ -0,0 +1,214 @@
import getpass
import logging
import typing
import uuid
from types import TracebackType

import aiomysql
from sqlalchemy.dialects.mysql import pymysql
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.sql import ClauseElement

from starlette.database.core import (
    DatabaseBackend,
    DatabaseSession,
    DatabaseTransaction,
)
from starlette.datastructures import DatabaseURL

logger = logging.getLogger("starlette.database")


class MysqlBackend(DatabaseBackend):
    def __init__(self, database_url: typing.Union[str, DatabaseURL]) -> None:
        self.database_url = DatabaseURL(database_url)
        self.dialect = self.get_dialect()
        self.pool = None

    def get_dialect(self) -> Dialect:
        return pymysql.dialect(paramstyle="pyformat")

    async def startup(self) -> None:
        db = self.database_url
        self.pool = await aiomysql.create_pool(
            host=db.hostname,
            port=db.port or 3306,
            user=db.username or getpass.getuser(),
            password=db.password,
            db=db.database,
        )

    async def shutdown(self) -> None:
        assert self.pool is not None, "DatabaseBackend is not running"
        self.pool.close()
        self.pool = None

    def session(self) -> "MysqlSession":
        assert self.pool is not None, "DatabaseBackend is not running"
        return MysqlSession(self.pool, self.dialect)


class Record:
    def __init__(self, row: tuple, result_columns: tuple) -> None:
        self._row = row
        self._result_columns = result_columns
        self._column_map = {
            col[0]: (idx, col) for idx, col in enumerate(self._result_columns)
        }

    def __getitem__(self, key: typing.Union[int, str]) -> typing.Any:
        if isinstance(key, int):
            idx = key
            col = self._result_columns[idx]
        else:
            idx, col = self._column_map[key]
        raw = self._row[idx]
        return col[-1].python_type(raw)


class MysqlSession(DatabaseSession):
    def __init__(self, pool: aiomysql.pool.Pool, dialect: Dialect):
        self.pool = pool
        self.dialect = dialect
        self.conn = None
        self.connection_holders = 0
        self.has_root_transaction = False

    def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]:
        compiled = query.compile(dialect=self.dialect)
        args = compiled.construct_params()
        logger.debug(compiled.string, args)
        return compiled.string, args, compiled._result_columns

    async def fetchall(self, query: ClauseElement) -> typing.Any:
        query, args, result_columns = self._compile(query)
        conn = await self.acquire_connection()
        cursor = await conn.cursor()
        try:
            await cursor.execute(query, args)
            rows = await cursor.fetchall()
            return [Record(row, result_columns) for row in rows]
        finally:
            await cursor.close()
            await self.release_connection()

    async def fetchone(self, query: ClauseElement) -> typing.Any:
        query, args, result_columns = self._compile(query)
        conn = await self.acquire_connection()
        cursor = await conn.cursor()
        try:
            await cursor.execute(query, args)
            row = await cursor.fetchone()
            return Record(row, result_columns)
        finally:
            await cursor.close()
            await self.release_connection()

    async def execute(self, query: ClauseElement) -> None:
        query, args, result_columns = self._compile(query)
        conn = await self.acquire_connection()
        cursor = await conn.cursor()
        try:
            await cursor.execute(query, args)
        finally:
            await cursor.close()
            await self.release_connection()

    async def executemany(self, query: ClauseElement, values: list) -> None:
        conn = await self.acquire_connection()
        cursor = await conn.cursor()
        try:
            for item in values:
                single_query = query.values(item)
                single_query, args, result_columns = self._compile(single_query)
                await cursor.execute(single_query, args)
        finally:
            await cursor.close()
            await self.release_connection()

    def transaction(self) -> DatabaseTransaction:
        return MysqlTransaction(self)

    async def acquire_connection(self) -> aiomysql.Connection:
        """
        Either acquire a connection from the pool, or return the
        existing connection. Must be followed by a corresponding
        call to `release_connection`.
        """
        self.connection_holders += 1
        if self.conn is None:
            self.conn = await self.pool.acquire()
        return self.conn

    async def release_connection(self) -> None:
        self.connection_holders -= 1
        if self.connection_holders == 0:
            await self.pool.release(self.conn)
            self.conn = None


class MysqlTransaction(DatabaseTransaction):
    def __init__(self, session: MysqlSession):
        self.session = session
        self.is_root = False

    async def __aenter__(self) -> None:
        await self.start()

    async def __aexit__(
        self,
        exc_type: typing.Type[BaseException] = None,
        exc_value: BaseException = None,
        traceback: TracebackType = None,
    ) -> None:
        if exc_type is not None:
            await self.rollback()
        else:
            await self.commit()

    async def start(self) -> None:
        if self.session.has_root_transaction is False:
            self.session.has_root_transaction = True
            self.is_root = True

        self.conn = await self.session.acquire_connection()
        if self.is_root:
            await self.conn.begin()
        else:
            id = str(uuid.uuid4()).replace("-", "_")
            self.savepoint_name = f"STARLETTE_SAVEPOINT_{id}"
            cursor = await self.conn.cursor()
            try:
                await cursor.execute(f"SAVEPOINT {self.savepoint_name}")
            finally:
                await cursor.close()

    async def commit(self) -> None:
        if self.is_root:  # pragma: no cover
            # In test cases the root transaction is never committed,
            # since we *always* wrap the test case up in a transaction
            # and rollback to a clean state at the end.
            await self.conn.commit()
            self.session.has_root_transaction = False
        else:
            cursor = await self.conn.cursor()
            try:
                await cursor.execute(f"RELEASE SAVEPOINT {self.savepoint_name}")
            finally:
                await cursor.close()
        await self.session.release_connection()

    async def rollback(self) -> None:
        if self.is_root:
            await self.conn.rollback()
            self.session.has_root_transaction = False
        else:
            cursor = await self.conn.cursor()
            try:
                await cursor.execute(f"ROLLBACK TO SAVEPOINT {self.savepoint_name}")
            finally:
                await cursor.close()
        await self.session.release_connection()
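For orientation, a rough sketch of exercising this backend by hand, outside the middleware. It assumes the database and a `notes` table already exist; the URL and table are placeholders, and only methods defined above are used:

```python
# Sketch only: drive MysqlBackend directly, assuming the database/table exist.
import asyncio

import sqlalchemy

from starlette.database.mysql import MysqlBackend

metadata = sqlalchemy.MetaData()
notes = sqlalchemy.Table(
    "notes",
    metadata,
    sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
    sqlalchemy.Column("text", sqlalchemy.String(length=100)),
)


async def main():
    backend = MysqlBackend("mysql://localhost/starlette_test")  # placeholder URL
    await backend.startup()
    session = backend.session()
    async with session.transaction():
        # A further session.transaction() nested inside this block would issue a SAVEPOINT.
        await session.execute(notes.insert().values(text="hello"))
        records = await session.fetchall(notes.select())
        print([record["text"] for record in records])
    await backend.shutdown()


asyncio.get_event_loop().run_until_complete(main())
```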


@ -1,3 +1,4 @@
import logging
import typing
from types import TracebackType
@ -10,14 +11,15 @@ from starlette.database.core import (
    DatabaseBackend,
    DatabaseSession,
    DatabaseTransaction,
    compile,
)
from starlette.datastructures import DatabaseURL
logger = logging.getLogger("starlette.database")
class PostgresBackend(DatabaseBackend):
    def __init__(self, database_url: typing.Union[str, DatabaseURL]) -> None:
        self.database_url = database_url
        self.database_url = DatabaseURL(database_url)
        self.dialect = self.get_dialect()
        self.pool = None
@ -53,8 +55,26 @@ class PostgresSession(DatabaseSession):
        self.conn = None
        self.connection_holders = 0
    def _compile(self, query: ClauseElement) -> typing.Tuple[str, list]:
        compiled = query.compile(dialect=self.dialect)
        compiled_params = sorted(compiled.params.items())
        mapping = {
            key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1)
        }
        compiled_query = compiled.string % mapping
        processors = compiled._bind_processors
        args = [
            processors[key](val) if key in processors else val
            for key, val in compiled_params
        ]
        logger.debug(compiled_query, args)
        return compiled_query, args
    async def fetchall(self, query: ClauseElement) -> typing.Any:
        query, args = compile(query, dialect=self.dialect)
        query, args = self._compile(query)
        conn = await self.acquire_connection()
        try:
@ -63,7 +83,7 @@ class PostgresSession(DatabaseSession):
            await self.release_connection()
    async def fetchone(self, query: ClauseElement) -> typing.Any:
        query, args = compile(query, dialect=self.dialect)
        query, args = self._compile(query)
        conn = await self.acquire_connection()
        try:
@ -72,7 +92,7 @@ class PostgresSession(DatabaseSession):
            await self.release_connection()
    async def execute(self, query: ClauseElement) -> None:
        query, args = compile(query, dialect=self.dialect)
        query, args = self._compile(query)
        conn = await self.acquire_connection()
        try:
@ -88,7 +108,7 @@ class PostgresSession(DatabaseSession):
            # using the same prepared statement.
            for item in values:
                single_query = query.values(item)
                single_query, args = compile(single_query, dialect=self.dialect)
                single_query, args = self._compile(single_query)
                await conn.execute(single_query, *args)
        finally:
            await self.release_connection()


@ -135,6 +135,9 @@ class URL:
class DatabaseURL(URL):
    def __init__(self, url: typing.Union[str, URL]):
        return super().__init__(str(url))
    @property
    def database(self) -> str:
        return self.path.lstrip("/")
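A quick illustration of the new `database` property alongside the existing URL components used by the MySQL backend (the URL is a placeholder):

```python
# Sketch only: DatabaseURL.database strips the leading "/" from the URL path.
from starlette.datastructures import DatabaseURL

url = DatabaseURL("mysql://localhost:3306/starlette_test")
assert url.dialect == "mysql"
assert url.port == 3306
assert url.database == "starlette_test"
```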


@ -1,7 +1,5 @@
import typing
import asyncpg
from starlette.database.core import (
    DatabaseBackend,
    DatabaseSession,
@ -29,12 +27,21 @@ class DatabaseMiddleware:
    ) -> DatabaseBackend:
        if isinstance(database_url, str):
            database_url = DatabaseURL(database_url)
        assert (
            database_url.dialect == "postgresql"
        ), "Currently only postgresql is supported."
        from starlette.database.postgres import PostgresBackend
        assert database_url.dialect in [
            "postgresql",
            "mysql",
        ], "Currently only postgresql and mysql are supported."
        return PostgresBackend(database_url)
        if database_url.dialect == "postgresql":
            from starlette.database.postgres import PostgresBackend
            return PostgresBackend(database_url)
        else:
            assert database_url.dialect == "mysql"
            from starlette.database.mysql import MysqlBackend
            return MysqlBackend(database_url)
    def __call__(self, scope: Scope) -> ASGIInstance:
        if scope["type"] == "lifespan":


@ -5,14 +5,15 @@ import sqlalchemy
from starlette.applications import Starlette
from starlette.database import transaction
from starlette.datastructures import CommaSeparatedStrings, DatabaseURL
from starlette.middleware.database import DatabaseMiddleware
from starlette.responses import JSONResponse
from starlette.testclient import TestClient
try:
DATABASE_URL = os.environ["STARLETTE_TEST_DATABASE"]
DATABASE_URLS = CommaSeparatedStrings(os.environ["STARLETTE_TEST_DATABASES"])
except KeyError: # pragma: no cover
pytest.skip("DATABASE_URL is not set", allow_module_level=True)
pytest.skip("STARLETTE_TEST_DATABASES is not set", allow_module_level=True)
metadata = sqlalchemy.MetaData()
@ -20,71 +21,87 @@ notes = sqlalchemy.Table(
"notes",
metadata,
sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
sqlalchemy.Column("text", sqlalchemy.String),
sqlalchemy.Column("text", sqlalchemy.String(length=100)),
sqlalchemy.Column("completed", sqlalchemy.Boolean),
)
app = Starlette()
app.add_middleware(
DatabaseMiddleware, database_url=DATABASE_URL, rollback_on_shutdown=True
)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
engine = sqlalchemy.create_engine(DATABASE_URL)
metadata.create_all(engine)
def create_test_databases():
engines = {}
for url in DATABASE_URLS:
db_url = DatabaseURL(url)
if db_url.dialect == "mysql":
#  Use the 'pymysql' driver for creating the database & tables.
url = str(db_url.replace(scheme="mysql+pymysql"))
db_name = db_url.database
db_url = db_url.replace(scheme="mysql+pymysql", database="")
engine = sqlalchemy.create_engine(str(db_url))
engine.execute("CREATE DATABASE IF NOT EXISTS " + db_name)
engines[url] = sqlalchemy.create_engine(url)
metadata.create_all(engines[url])
yield
engine.execute("DROP TABLE notes")
for engine in engines.values():
engine.execute("DROP TABLE notes")
@app.route("/notes", methods=["GET"])
async def list_notes(request):
query = notes.select()
results = await request.database.fetchall(query)
content = [
{"text": result["text"], "completed": result["completed"]} for result in results
]
return JSONResponse(content)
def get_app(database_url):
app = Starlette()
app.add_middleware(
DatabaseMiddleware, database_url=database_url, rollback_on_shutdown=True
)
@app.route("/notes", methods=["GET"])
async def list_notes(request):
query = notes.select()
results = await request.database.fetchall(query)
content = [
{"text": result["text"], "completed": result["completed"]}
for result in results
]
return JSONResponse(content)
@app.route("/notes", methods=["POST"])
@transaction
async def add_note(request):
data = await request.json()
query = notes.insert().values(text=data["text"], completed=data["completed"])
await request.database.execute(query)
if "raise_exc" in request.query_params:
raise RuntimeError()
return JSONResponse({"text": data["text"], "completed": data["completed"]})
@app.route("/notes/bulk_create", methods=["POST"])
async def bulk_create_notes(request):
data = await request.json()
query = notes.insert()
await request.database.executemany(query, data)
return JSONResponse({"notes": data})
@app.route("/notes/{note_id:int}", methods=["GET"])
async def read_note(request):
note_id = request.path_params["note_id"]
query = notes.select().where(notes.c.id == note_id)
result = await request.database.fetchone(query)
content = {"text": result["text"], "completed": result["completed"]}
return JSONResponse(content)
@app.route("/notes/{note_id:int}/text", methods=["GET"])
async def read_note_text(request):
note_id = request.path_params["note_id"]
query = sqlalchemy.select([notes.c.text]).where(notes.c.id == note_id)
text = await request.database.fetchval(query)
return JSONResponse(text)
return app
@app.route("/notes", methods=["POST"])
@transaction
async def add_note(request):
data = await request.json()
query = notes.insert().values(text=data["text"], completed=data["completed"])
await request.database.execute(query)
if "raise_exc" in request.query_params:
raise RuntimeError()
return JSONResponse({"text": data["text"], "completed": data["completed"]})
@app.route("/notes/bulk_create", methods=["POST"])
async def bulk_create_notes(request):
data = await request.json()
query = notes.insert()
await request.database.executemany(query, data)
return JSONResponse({"notes": data})
@app.route("/notes/{note_id:int}", methods=["GET"])
async def read_note(request):
note_id = request.path_params["note_id"]
query = notes.select().where(notes.c.id == note_id)
result = await request.database.fetchone(query)
content = {"text": result["text"], "completed": result["completed"]}
return JSONResponse(content)
@app.route("/notes/{note_id:int}/text", methods=["GET"])
async def read_note_text(request):
note_id = request.path_params["note_id"]
query = sqlalchemy.select([notes.c.text]).where(notes.c.id == note_id)
text = await request.database.fetchval(query)
return JSONResponse(text)
def test_database():
@pytest.mark.parametrize("database_url", DATABASE_URLS)
def test_database(database_url):
app = get_app(database_url)
with TestClient(app) as client:
response = client.post(
"/notes", json={"text": "buy the milk", "completed": True}
@ -119,8 +136,13 @@ def test_database():
assert response.json() == "buy the milk"
def test_database_executemany():
@pytest.mark.parametrize("database_url", DATABASE_URLS)
def test_database_executemany(database_url):
app = get_app(database_url)
with TestClient(app) as client:
response = client.get("/notes")
print(response.json())
data = [
{"text": "buy the milk", "completed": True},
{"text": "walk the dog", "completed": False},
@ -129,6 +151,7 @@ def test_database_executemany():
assert response.status_code == 200
response = client.get("/notes")
print(response.json())
assert response.status_code == 200
assert response.json() == [
{"text": "buy the milk", "completed": True},
@ -136,11 +159,12 @@ def test_database_executemany():
]
def test_database_isolated_during_test_cases():
@pytest.mark.parametrize("database_url", DATABASE_URLS)
def test_database_isolated_during_test_cases(database_url):
"""
Using `TestClient` as a context manager
"""
app = get_app(database_url)
with TestClient(app) as client:
response = client.post(
"/notes", json={"text": "just one note", "completed": True}