diff --git a/.travis.yml b/.travis.yml index 060b3eea..1a6b9d2e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,10 +7,11 @@ python: - "3.7-dev" env: - - STARLETTE_TEST_DATABASE=postgresql://localhost/starlette + - STARLETTE_TEST_DATABASES="postgresql://localhost/starlette, mysql://localhost/starlette" services: - postgresql + - mysql install: - pip install -U -r requirements.txt diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 42824762..6acbc2c8 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -40,10 +40,13 @@ We provide a stand-alone **test script** to run tests in a reliable manner. Run ./scripts/test ``` -By default, tests involving a database are excluded. To include them, set the `STARLETTE_TEST_DATABASE` environment variable to the URL of a PostgreSQL database, e.g.: +By default, tests involving a database are excluded. To include them, set the `STARLETTE_TEST_DATABASES` environment variable. This should be a comma separated string of database URLs. ```bash -export STARLETTE_TEST_DATABASE="postgresql://localhost/starlette" +# Any of the following are valid for running the database tests... +export STARLETTE_TEST_DATABASES="postgresql://localhost/starlette" +export STARLETTE_TEST_DATABASES="mysql://localhost/starlette_test" +export STARLETTE_TEST_DATABASES="postgresql://localhost/starlette, mysql://localhost/starlette_test" ``` ## Linting diff --git a/docs/database.md b/docs/database.md index 9f37eb70..96d6def6 100644 --- a/docs/database.md +++ b/docs/database.md @@ -1,7 +1,10 @@ Starlette includes optional database support. There is currently only a driver -for Postgres databases, but MySQL and SQLite support is planned. +for Postgres and MySQL databases, but SQLite support is planned. -Enabling the built-in database support requires `sqlalchemy`, and an appropriate database driver. Currently this means `asyncpg` is a requirement. +Enabling the built-in database support requires `sqlalchemy`, and an appropriate database driver. + +PostgreSQL: requires `asyncpg` +MySQL: requires `aiomysql` The database support is completely optional - you can either include the middleware or not, or you can build alternative kinds of backends instead. It does not include support for an ORM, but it does support using queries built using diff --git a/requirements.txt b/requirements.txt index 0d72fb77..9f46eda5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,8 +8,13 @@ pyyaml requests ujson -# Database backends +# Async database drivers asyncpg +aiomysql + +# Sync database drivers for standard tooling around setup/teardown/migrations. +psycopg2-binary +pymysql # Testing autoflake @@ -20,7 +25,6 @@ mypy pytest pytest-cov sqlalchemy -psycopg2-binary # Documentation mkdocs diff --git a/scripts/test b/scripts/test index 78ded30d..52b0056c 100755 --- a/scripts/test +++ b/scripts/test @@ -8,11 +8,11 @@ fi export VERSION_SCRIPT="import sys; print('%s.%s' % sys.version_info[0:2])" export PYTHON_VERSION=`python -c "$VERSION_SCRIPT"` -if [ -z "$STARLETTE_TEST_DATABASE" ] ; then - echo "Variable STARLETTE_TEST_DATABASE is unset. Excluding database tests." +if [ -z "$STARLETTE_TEST_DATABASES" ] ; then + echo "Variable STARLETTE_TEST_DATABASES is unset. Excluding database tests." export IGNORE_MODULES="--ignore tests/test_database.py --cov-config tests/.ignore_database_tests" else - echo "Variable STARLETTE_TEST_DATABASE is set" + echo "Variable STARLETTE_TEST_DATABASES is set" export IGNORE_MODULES="" fi diff --git a/starlette/database/__init__.py b/starlette/database/__init__.py index caac583a..21620089 100644 --- a/starlette/database/__init__.py +++ b/starlette/database/__init__.py @@ -1,5 +1,4 @@ from starlette.database.core import ( - compile, transaction, DatabaseBackend, DatabaseSession, @@ -7,10 +6,4 @@ from starlette.database.core import ( ) -__all__ = [ - "compile", - "transaction", - "DatabaseBackend", - "DatabaseSession", - "DatabaseTransaction", -] +__all__ = ["transaction", "DatabaseBackend", "DatabaseSession", "DatabaseTransaction"] diff --git a/starlette/database/core.py b/starlette/database/core.py index bf53bc69..7cfc8e83 100644 --- a/starlette/database/core.py +++ b/starlette/database/core.py @@ -1,34 +1,12 @@ import functools -import logging import typing from types import TracebackType -from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.sql import ClauseElement from starlette.requests import Request from starlette.responses import Response -logger = logging.getLogger("starlette.database") - - -def compile(query: ClauseElement, dialect: Dialect) -> typing.Tuple[str, list]: - # query = execute_defaults(query) # default values for Insert/Update - compiled = query.compile(dialect=dialect) - compiled_params = sorted(compiled.params.items()) - - mapping = {key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1)} - compiled_query = compiled.string % mapping - - processors = compiled._bind_processors - args = [ - processors[key](val) if key in processors else val - for key, val in compiled_params - ] - - logger.debug(compiled_query) - return compiled_query, args - def transaction(func: typing.Callable) -> typing.Callable: @functools.wraps(func) diff --git a/starlette/database/mysql.py b/starlette/database/mysql.py new file mode 100644 index 00000000..c9a8ac43 --- /dev/null +++ b/starlette/database/mysql.py @@ -0,0 +1,214 @@ +import getpass +import logging +import typing +import uuid +from types import TracebackType + +import aiomysql +from sqlalchemy.dialects.mysql import pymysql +from sqlalchemy.engine.interfaces import Dialect +from sqlalchemy.sql import ClauseElement + +from starlette.database.core import ( + DatabaseBackend, + DatabaseSession, + DatabaseTransaction, +) +from starlette.datastructures import DatabaseURL + +logger = logging.getLogger("starlette.database") + + +class MysqlBackend(DatabaseBackend): + def __init__(self, database_url: typing.Union[str, DatabaseURL]) -> None: + self.database_url = DatabaseURL(database_url) + self.dialect = self.get_dialect() + self.pool = None + + def get_dialect(self) -> Dialect: + return pymysql.dialect(paramstyle="pyformat") + + async def startup(self) -> None: + db = self.database_url + self.pool = await aiomysql.create_pool( + host=db.hostname, + port=db.port or 3306, + user=db.username or getpass.getuser(), + password=db.password, + db=db.database, + ) + + async def shutdown(self) -> None: + assert self.pool is not None, "DatabaseBackend is not running" + self.pool.close() + self.pool = None + + def session(self) -> "MysqlSession": + assert self.pool is not None, "DatabaseBackend is not running" + return MysqlSession(self.pool, self.dialect) + + +class Record: + def __init__(self, row: tuple, result_columns: tuple) -> None: + self._row = row + self._result_columns = result_columns + self._column_map = { + col[0]: (idx, col) for idx, col in enumerate(self._result_columns) + } + + def __getitem__(self, key: typing.Union[int, str]) -> typing.Any: + if isinstance(key, int): + idx = key + col = self._result_columns[idx] + else: + idx, col = self._column_map[key] + raw = self._row[idx] + return col[-1].python_type(raw) + + +class MysqlSession(DatabaseSession): + def __init__(self, pool: aiomysql.pool.Pool, dialect: Dialect): + self.pool = pool + self.dialect = dialect + self.conn = None + self.connection_holders = 0 + self.has_root_transaction = False + + def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]: + compiled = query.compile(dialect=self.dialect) + args = compiled.construct_params() + logger.debug(compiled.string, args) + return compiled.string, args, compiled._result_columns + + async def fetchall(self, query: ClauseElement) -> typing.Any: + query, args, result_columns = self._compile(query) + + conn = await self.acquire_connection() + cursor = await conn.cursor() + try: + await cursor.execute(query, args) + rows = await cursor.fetchall() + return [Record(row, result_columns) for row in rows] + finally: + await cursor.close() + await self.release_connection() + + async def fetchone(self, query: ClauseElement) -> typing.Any: + query, args, result_columns = self._compile(query) + + conn = await self.acquire_connection() + cursor = await conn.cursor() + try: + await cursor.execute(query, args) + row = await cursor.fetchone() + return Record(row, result_columns) + finally: + await cursor.close() + await self.release_connection() + + async def execute(self, query: ClauseElement) -> None: + query, args, result_columns = self._compile(query) + + conn = await self.acquire_connection() + cursor = await conn.cursor() + try: + await cursor.execute(query, args) + finally: + await cursor.close() + await self.release_connection() + + async def executemany(self, query: ClauseElement, values: list) -> None: + conn = await self.acquire_connection() + cursor = await conn.cursor() + try: + for item in values: + single_query = query.values(item) + single_query, args, result_columns = self._compile(single_query) + await cursor.execute(single_query, args) + finally: + await cursor.close() + await self.release_connection() + + def transaction(self) -> DatabaseTransaction: + return MysqlTransaction(self) + + async def acquire_connection(self) -> aiomysql.Connection: + """ + Either acquire a connection from the pool, or return the + existing connection. Must be followed by a corresponding + call to `release_connection`. + """ + self.connection_holders += 1 + if self.conn is None: + self.conn = await self.pool.acquire() + return self.conn + + async def release_connection(self) -> None: + self.connection_holders -= 1 + if self.connection_holders == 0: + await self.pool.release(self.conn) + self.conn = None + + +class MysqlTransaction(DatabaseTransaction): + def __init__(self, session: MysqlSession): + self.session = session + self.is_root = False + + async def __aenter__(self) -> None: + await self.start() + + async def __aexit__( + self, + exc_type: typing.Type[BaseException] = None, + exc_value: BaseException = None, + traceback: TracebackType = None, + ) -> None: + if exc_type is not None: + await self.rollback() + else: + await self.commit() + + async def start(self) -> None: + if self.session.has_root_transaction is False: + self.session.has_root_transaction = True + self.is_root = True + + self.conn = await self.session.acquire_connection() + if self.is_root: + await self.conn.begin() + else: + id = str(uuid.uuid4()).replace("-", "_") + self.savepoint_name = f"STARLETTE_SAVEPOINT_{id}" + cursor = await self.conn.cursor() + try: + await cursor.execute(f"SAVEPOINT {self.savepoint_name}") + finally: + await cursor.close() + + async def commit(self) -> None: + if self.is_root: # pragma: no cover + # In test cases the root transaction is never committed, + # since we *always* wrap the test case up in a transaction + # and rollback to a clean state at the end. + await self.conn.commit() + self.session.has_root_transaction = False + else: + cursor = await self.conn.cursor() + try: + await cursor.execute(f"RELEASE SAVEPOINT {self.savepoint_name}") + finally: + await cursor.close() + await self.session.release_connection() + + async def rollback(self) -> None: + if self.is_root: + await self.conn.rollback() + self.session.has_root_transaction = False + else: + cursor = await self.conn.cursor() + try: + await cursor.execute(f"ROLLBACK TO SAVEPOINT {self.savepoint_name}") + finally: + await cursor.close() + await self.session.release_connection() diff --git a/starlette/database/postgres.py b/starlette/database/postgres.py index e2e0530b..de9e4e41 100644 --- a/starlette/database/postgres.py +++ b/starlette/database/postgres.py @@ -1,3 +1,4 @@ +import logging import typing from types import TracebackType @@ -10,14 +11,15 @@ from starlette.database.core import ( DatabaseBackend, DatabaseSession, DatabaseTransaction, - compile, ) from starlette.datastructures import DatabaseURL +logger = logging.getLogger("starlette.database") + class PostgresBackend(DatabaseBackend): def __init__(self, database_url: typing.Union[str, DatabaseURL]) -> None: - self.database_url = database_url + self.database_url = DatabaseURL(database_url) self.dialect = self.get_dialect() self.pool = None @@ -53,8 +55,26 @@ class PostgresSession(DatabaseSession): self.conn = None self.connection_holders = 0 + def _compile(self, query: ClauseElement) -> typing.Tuple[str, list]: + compiled = query.compile(dialect=self.dialect) + compiled_params = sorted(compiled.params.items()) + + mapping = { + key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1) + } + compiled_query = compiled.string % mapping + + processors = compiled._bind_processors + args = [ + processors[key](val) if key in processors else val + for key, val in compiled_params + ] + + logger.debug(compiled_query, args) + return compiled_query, args + async def fetchall(self, query: ClauseElement) -> typing.Any: - query, args = compile(query, dialect=self.dialect) + query, args = self._compile(query) conn = await self.acquire_connection() try: @@ -63,7 +83,7 @@ class PostgresSession(DatabaseSession): await self.release_connection() async def fetchone(self, query: ClauseElement) -> typing.Any: - query, args = compile(query, dialect=self.dialect) + query, args = self._compile(query) conn = await self.acquire_connection() try: @@ -72,7 +92,7 @@ class PostgresSession(DatabaseSession): await self.release_connection() async def execute(self, query: ClauseElement) -> None: - query, args = compile(query, dialect=self.dialect) + query, args = self._compile(query) conn = await self.acquire_connection() try: @@ -88,7 +108,7 @@ class PostgresSession(DatabaseSession): # using the same prepared statement. for item in values: single_query = query.values(item) - single_query, args = compile(single_query, dialect=self.dialect) + single_query, args = self._compile(single_query) await conn.execute(single_query, *args) finally: await self.release_connection() diff --git a/starlette/datastructures.py b/starlette/datastructures.py index 26b0486b..141de4f4 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -135,6 +135,9 @@ class URL: class DatabaseURL(URL): + def __init__(self, url: typing.Union[str, URL]): + return super().__init__(str(url)) + @property def database(self) -> str: return self.path.lstrip("/") diff --git a/starlette/middleware/database.py b/starlette/middleware/database.py index 006cd558..4fe3acf5 100644 --- a/starlette/middleware/database.py +++ b/starlette/middleware/database.py @@ -1,7 +1,5 @@ import typing -import asyncpg - from starlette.database.core import ( DatabaseBackend, DatabaseSession, @@ -29,12 +27,21 @@ class DatabaseMiddleware: ) -> DatabaseBackend: if isinstance(database_url, str): database_url = DatabaseURL(database_url) - assert ( - database_url.dialect == "postgresql" - ), "Currently only postgresql is supported." - from starlette.database.postgres import PostgresBackend + assert database_url.dialect in [ + "postgresql", + "mysql", + ], "Currently only postgresql and mysql are supported." - return PostgresBackend(database_url) + if database_url.dialect == "postgresql": + from starlette.database.postgres import PostgresBackend + + return PostgresBackend(database_url) + + else: + assert database_url.dialect == "mysql" + from starlette.database.mysql import MysqlBackend + + return MysqlBackend(database_url) def __call__(self, scope: Scope) -> ASGIInstance: if scope["type"] == "lifespan": diff --git a/tests/test_database.py b/tests/test_database.py index c287c8ab..cda1e0a0 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -5,14 +5,15 @@ import sqlalchemy from starlette.applications import Starlette from starlette.database import transaction +from starlette.datastructures import CommaSeparatedStrings, DatabaseURL from starlette.middleware.database import DatabaseMiddleware from starlette.responses import JSONResponse from starlette.testclient import TestClient try: - DATABASE_URL = os.environ["STARLETTE_TEST_DATABASE"] + DATABASE_URLS = CommaSeparatedStrings(os.environ["STARLETTE_TEST_DATABASES"]) except KeyError: # pragma: no cover - pytest.skip("DATABASE_URL is not set", allow_module_level=True) + pytest.skip("STARLETTE_TEST_DATABASES is not set", allow_module_level=True) metadata = sqlalchemy.MetaData() @@ -20,71 +21,87 @@ notes = sqlalchemy.Table( "notes", metadata, sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True), - sqlalchemy.Column("text", sqlalchemy.String), + sqlalchemy.Column("text", sqlalchemy.String(length=100)), sqlalchemy.Column("completed", sqlalchemy.Boolean), ) -app = Starlette() -app.add_middleware( - DatabaseMiddleware, database_url=DATABASE_URL, rollback_on_shutdown=True -) - @pytest.fixture(autouse=True, scope="module") -def create_test_database(): - engine = sqlalchemy.create_engine(DATABASE_URL) - metadata.create_all(engine) +def create_test_databases(): + engines = {} + for url in DATABASE_URLS: + db_url = DatabaseURL(url) + if db_url.dialect == "mysql": + #  Use the 'pymysql' driver for creating the database & tables. + url = str(db_url.replace(scheme="mysql+pymysql")) + db_name = db_url.database + db_url = db_url.replace(scheme="mysql+pymysql", database="") + engine = sqlalchemy.create_engine(str(db_url)) + engine.execute("CREATE DATABASE IF NOT EXISTS " + db_name) + + engines[url] = sqlalchemy.create_engine(url) + metadata.create_all(engines[url]) + yield - engine.execute("DROP TABLE notes") + + for engine in engines.values(): + engine.execute("DROP TABLE notes") -@app.route("/notes", methods=["GET"]) -async def list_notes(request): - query = notes.select() - results = await request.database.fetchall(query) - content = [ - {"text": result["text"], "completed": result["completed"]} for result in results - ] - return JSONResponse(content) +def get_app(database_url): + app = Starlette() + app.add_middleware( + DatabaseMiddleware, database_url=database_url, rollback_on_shutdown=True + ) + + @app.route("/notes", methods=["GET"]) + async def list_notes(request): + query = notes.select() + results = await request.database.fetchall(query) + content = [ + {"text": result["text"], "completed": result["completed"]} + for result in results + ] + return JSONResponse(content) + + @app.route("/notes", methods=["POST"]) + @transaction + async def add_note(request): + data = await request.json() + query = notes.insert().values(text=data["text"], completed=data["completed"]) + await request.database.execute(query) + if "raise_exc" in request.query_params: + raise RuntimeError() + return JSONResponse({"text": data["text"], "completed": data["completed"]}) + + @app.route("/notes/bulk_create", methods=["POST"]) + async def bulk_create_notes(request): + data = await request.json() + query = notes.insert() + await request.database.executemany(query, data) + return JSONResponse({"notes": data}) + + @app.route("/notes/{note_id:int}", methods=["GET"]) + async def read_note(request): + note_id = request.path_params["note_id"] + query = notes.select().where(notes.c.id == note_id) + result = await request.database.fetchone(query) + content = {"text": result["text"], "completed": result["completed"]} + return JSONResponse(content) + + @app.route("/notes/{note_id:int}/text", methods=["GET"]) + async def read_note_text(request): + note_id = request.path_params["note_id"] + query = sqlalchemy.select([notes.c.text]).where(notes.c.id == note_id) + text = await request.database.fetchval(query) + return JSONResponse(text) + + return app -@app.route("/notes", methods=["POST"]) -@transaction -async def add_note(request): - data = await request.json() - query = notes.insert().values(text=data["text"], completed=data["completed"]) - await request.database.execute(query) - if "raise_exc" in request.query_params: - raise RuntimeError() - return JSONResponse({"text": data["text"], "completed": data["completed"]}) - - -@app.route("/notes/bulk_create", methods=["POST"]) -async def bulk_create_notes(request): - data = await request.json() - query = notes.insert() - await request.database.executemany(query, data) - return JSONResponse({"notes": data}) - - -@app.route("/notes/{note_id:int}", methods=["GET"]) -async def read_note(request): - note_id = request.path_params["note_id"] - query = notes.select().where(notes.c.id == note_id) - result = await request.database.fetchone(query) - content = {"text": result["text"], "completed": result["completed"]} - return JSONResponse(content) - - -@app.route("/notes/{note_id:int}/text", methods=["GET"]) -async def read_note_text(request): - note_id = request.path_params["note_id"] - query = sqlalchemy.select([notes.c.text]).where(notes.c.id == note_id) - text = await request.database.fetchval(query) - return JSONResponse(text) - - -def test_database(): +@pytest.mark.parametrize("database_url", DATABASE_URLS) +def test_database(database_url): + app = get_app(database_url) with TestClient(app) as client: response = client.post( "/notes", json={"text": "buy the milk", "completed": True} @@ -119,8 +136,13 @@ def test_database(): assert response.json() == "buy the milk" -def test_database_executemany(): +@pytest.mark.parametrize("database_url", DATABASE_URLS) +def test_database_executemany(database_url): + app = get_app(database_url) with TestClient(app) as client: + response = client.get("/notes") + print(response.json()) + data = [ {"text": "buy the milk", "completed": True}, {"text": "walk the dog", "completed": False}, @@ -129,6 +151,7 @@ def test_database_executemany(): assert response.status_code == 200 response = client.get("/notes") + print(response.json()) assert response.status_code == 200 assert response.json() == [ {"text": "buy the milk", "completed": True}, @@ -136,11 +159,12 @@ def test_database_executemany(): ] -def test_database_isolated_during_test_cases(): +@pytest.mark.parametrize("database_url", DATABASE_URLS) +def test_database_isolated_during_test_cases(database_url): """ Using `TestClient` as a context manager """ - + app = get_app(database_url) with TestClient(app) as client: response = client.post( "/notes", json={"text": "just one note", "completed": True}