mirror of https://github.com/encode/starlette.git
MySQL driver support (#366)
* initial support for MySQL * Work towards testing both MySQL and Postgres drivers * Linting * STARLETTE_TEST_DATABASE -> STARLETTE_TEST_DATABASES * Parameterize the database tests * Include the MySQL Database URL * MySQL tests should create the database if it doesn't yet exist * Explict port of MySQL database URL * Debugging on Travis * Pass 'args' as a single argument * Fix query compilation for mysql * Lookup record values for MySQL * Coerce MySQL results to correct python_type * cursor.fetchrow should be cursor.fetchone * Fix call to cursor.fetchone() * Fix for fetchval() * Debugging * Nested transaction for MySQL should use SAVEPOINTS * Drop savepoint implementation * Fix SAVEPOINT implementation for MySQL * Coverage for MySQL * Linting * Tweak defaults for MySQL connections
This commit is contained in:
parent
82df72bdd7
commit
3270762c16
|
@ -7,10 +7,11 @@ python:
|
|||
- "3.7-dev"
|
||||
|
||||
env:
|
||||
- STARLETTE_TEST_DATABASE=postgresql://localhost/starlette
|
||||
- STARLETTE_TEST_DATABASES="postgresql://localhost/starlette, mysql://localhost/starlette"
|
||||
|
||||
services:
|
||||
- postgresql
|
||||
- mysql
|
||||
|
||||
install:
|
||||
- pip install -U -r requirements.txt
|
||||
|
|
|
@ -40,10 +40,13 @@ We provide a stand-alone **test script** to run tests in a reliable manner. Run
|
|||
./scripts/test
|
||||
```
|
||||
|
||||
By default, tests involving a database are excluded. To include them, set the `STARLETTE_TEST_DATABASE` environment variable to the URL of a PostgreSQL database, e.g.:
|
||||
By default, tests involving a database are excluded. To include them, set the `STARLETTE_TEST_DATABASES` environment variable. This should be a comma separated string of database URLs.
|
||||
|
||||
```bash
|
||||
export STARLETTE_TEST_DATABASE="postgresql://localhost/starlette"
|
||||
# Any of the following are valid for running the database tests...
|
||||
export STARLETTE_TEST_DATABASES="postgresql://localhost/starlette"
|
||||
export STARLETTE_TEST_DATABASES="mysql://localhost/starlette_test"
|
||||
export STARLETTE_TEST_DATABASES="postgresql://localhost/starlette, mysql://localhost/starlette_test"
|
||||
```
|
||||
|
||||
## Linting
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
Starlette includes optional database support. There is currently only a driver
|
||||
for Postgres databases, but MySQL and SQLite support is planned.
|
||||
for Postgres and MySQL databases, but SQLite support is planned.
|
||||
|
||||
Enabling the built-in database support requires `sqlalchemy`, and an appropriate database driver. Currently this means `asyncpg` is a requirement.
|
||||
Enabling the built-in database support requires `sqlalchemy`, and an appropriate database driver.
|
||||
|
||||
PostgreSQL: requires `asyncpg`
|
||||
MySQL: requires `aiomysql`
|
||||
|
||||
The database support is completely optional - you can either include the middleware or not, or you can build alternative kinds of backends instead. It does not
|
||||
include support for an ORM, but it does support using queries built using
|
||||
|
|
|
@ -8,8 +8,13 @@ pyyaml
|
|||
requests
|
||||
ujson
|
||||
|
||||
# Database backends
|
||||
# Async database drivers
|
||||
asyncpg
|
||||
aiomysql
|
||||
|
||||
# Sync database drivers for standard tooling around setup/teardown/migrations.
|
||||
psycopg2-binary
|
||||
pymysql
|
||||
|
||||
# Testing
|
||||
autoflake
|
||||
|
@ -20,7 +25,6 @@ mypy
|
|||
pytest
|
||||
pytest-cov
|
||||
sqlalchemy
|
||||
psycopg2-binary
|
||||
|
||||
# Documentation
|
||||
mkdocs
|
||||
|
|
|
@ -8,11 +8,11 @@ fi
|
|||
export VERSION_SCRIPT="import sys; print('%s.%s' % sys.version_info[0:2])"
|
||||
export PYTHON_VERSION=`python -c "$VERSION_SCRIPT"`
|
||||
|
||||
if [ -z "$STARLETTE_TEST_DATABASE" ] ; then
|
||||
echo "Variable STARLETTE_TEST_DATABASE is unset. Excluding database tests."
|
||||
if [ -z "$STARLETTE_TEST_DATABASES" ] ; then
|
||||
echo "Variable STARLETTE_TEST_DATABASES is unset. Excluding database tests."
|
||||
export IGNORE_MODULES="--ignore tests/test_database.py --cov-config tests/.ignore_database_tests"
|
||||
else
|
||||
echo "Variable STARLETTE_TEST_DATABASE is set"
|
||||
echo "Variable STARLETTE_TEST_DATABASES is set"
|
||||
export IGNORE_MODULES=""
|
||||
fi
|
||||
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
from starlette.database.core import (
|
||||
compile,
|
||||
transaction,
|
||||
DatabaseBackend,
|
||||
DatabaseSession,
|
||||
|
@ -7,10 +6,4 @@ from starlette.database.core import (
|
|||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"compile",
|
||||
"transaction",
|
||||
"DatabaseBackend",
|
||||
"DatabaseSession",
|
||||
"DatabaseTransaction",
|
||||
]
|
||||
__all__ = ["transaction", "DatabaseBackend", "DatabaseSession", "DatabaseTransaction"]
|
||||
|
|
|
@ -1,34 +1,12 @@
|
|||
import functools
|
||||
import logging
|
||||
import typing
|
||||
from types import TracebackType
|
||||
|
||||
from sqlalchemy.engine.interfaces import Dialect
|
||||
from sqlalchemy.sql import ClauseElement
|
||||
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
logger = logging.getLogger("starlette.database")
|
||||
|
||||
|
||||
def compile(query: ClauseElement, dialect: Dialect) -> typing.Tuple[str, list]:
|
||||
# query = execute_defaults(query) # default values for Insert/Update
|
||||
compiled = query.compile(dialect=dialect)
|
||||
compiled_params = sorted(compiled.params.items())
|
||||
|
||||
mapping = {key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1)}
|
||||
compiled_query = compiled.string % mapping
|
||||
|
||||
processors = compiled._bind_processors
|
||||
args = [
|
||||
processors[key](val) if key in processors else val
|
||||
for key, val in compiled_params
|
||||
]
|
||||
|
||||
logger.debug(compiled_query)
|
||||
return compiled_query, args
|
||||
|
||||
|
||||
def transaction(func: typing.Callable) -> typing.Callable:
|
||||
@functools.wraps(func)
|
||||
|
|
|
@ -0,0 +1,214 @@
|
|||
import getpass
|
||||
import logging
|
||||
import typing
|
||||
import uuid
|
||||
from types import TracebackType
|
||||
|
||||
import aiomysql
|
||||
from sqlalchemy.dialects.mysql import pymysql
|
||||
from sqlalchemy.engine.interfaces import Dialect
|
||||
from sqlalchemy.sql import ClauseElement
|
||||
|
||||
from starlette.database.core import (
|
||||
DatabaseBackend,
|
||||
DatabaseSession,
|
||||
DatabaseTransaction,
|
||||
)
|
||||
from starlette.datastructures import DatabaseURL
|
||||
|
||||
logger = logging.getLogger("starlette.database")
|
||||
|
||||
|
||||
class MysqlBackend(DatabaseBackend):
|
||||
def __init__(self, database_url: typing.Union[str, DatabaseURL]) -> None:
|
||||
self.database_url = DatabaseURL(database_url)
|
||||
self.dialect = self.get_dialect()
|
||||
self.pool = None
|
||||
|
||||
def get_dialect(self) -> Dialect:
|
||||
return pymysql.dialect(paramstyle="pyformat")
|
||||
|
||||
async def startup(self) -> None:
|
||||
db = self.database_url
|
||||
self.pool = await aiomysql.create_pool(
|
||||
host=db.hostname,
|
||||
port=db.port or 3306,
|
||||
user=db.username or getpass.getuser(),
|
||||
password=db.password,
|
||||
db=db.database,
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
assert self.pool is not None, "DatabaseBackend is not running"
|
||||
self.pool.close()
|
||||
self.pool = None
|
||||
|
||||
def session(self) -> "MysqlSession":
|
||||
assert self.pool is not None, "DatabaseBackend is not running"
|
||||
return MysqlSession(self.pool, self.dialect)
|
||||
|
||||
|
||||
class Record:
|
||||
def __init__(self, row: tuple, result_columns: tuple) -> None:
|
||||
self._row = row
|
||||
self._result_columns = result_columns
|
||||
self._column_map = {
|
||||
col[0]: (idx, col) for idx, col in enumerate(self._result_columns)
|
||||
}
|
||||
|
||||
def __getitem__(self, key: typing.Union[int, str]) -> typing.Any:
|
||||
if isinstance(key, int):
|
||||
idx = key
|
||||
col = self._result_columns[idx]
|
||||
else:
|
||||
idx, col = self._column_map[key]
|
||||
raw = self._row[idx]
|
||||
return col[-1].python_type(raw)
|
||||
|
||||
|
||||
class MysqlSession(DatabaseSession):
|
||||
def __init__(self, pool: aiomysql.pool.Pool, dialect: Dialect):
|
||||
self.pool = pool
|
||||
self.dialect = dialect
|
||||
self.conn = None
|
||||
self.connection_holders = 0
|
||||
self.has_root_transaction = False
|
||||
|
||||
def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]:
|
||||
compiled = query.compile(dialect=self.dialect)
|
||||
args = compiled.construct_params()
|
||||
logger.debug(compiled.string, args)
|
||||
return compiled.string, args, compiled._result_columns
|
||||
|
||||
async def fetchall(self, query: ClauseElement) -> typing.Any:
|
||||
query, args, result_columns = self._compile(query)
|
||||
|
||||
conn = await self.acquire_connection()
|
||||
cursor = await conn.cursor()
|
||||
try:
|
||||
await cursor.execute(query, args)
|
||||
rows = await cursor.fetchall()
|
||||
return [Record(row, result_columns) for row in rows]
|
||||
finally:
|
||||
await cursor.close()
|
||||
await self.release_connection()
|
||||
|
||||
async def fetchone(self, query: ClauseElement) -> typing.Any:
|
||||
query, args, result_columns = self._compile(query)
|
||||
|
||||
conn = await self.acquire_connection()
|
||||
cursor = await conn.cursor()
|
||||
try:
|
||||
await cursor.execute(query, args)
|
||||
row = await cursor.fetchone()
|
||||
return Record(row, result_columns)
|
||||
finally:
|
||||
await cursor.close()
|
||||
await self.release_connection()
|
||||
|
||||
async def execute(self, query: ClauseElement) -> None:
|
||||
query, args, result_columns = self._compile(query)
|
||||
|
||||
conn = await self.acquire_connection()
|
||||
cursor = await conn.cursor()
|
||||
try:
|
||||
await cursor.execute(query, args)
|
||||
finally:
|
||||
await cursor.close()
|
||||
await self.release_connection()
|
||||
|
||||
async def executemany(self, query: ClauseElement, values: list) -> None:
|
||||
conn = await self.acquire_connection()
|
||||
cursor = await conn.cursor()
|
||||
try:
|
||||
for item in values:
|
||||
single_query = query.values(item)
|
||||
single_query, args, result_columns = self._compile(single_query)
|
||||
await cursor.execute(single_query, args)
|
||||
finally:
|
||||
await cursor.close()
|
||||
await self.release_connection()
|
||||
|
||||
def transaction(self) -> DatabaseTransaction:
|
||||
return MysqlTransaction(self)
|
||||
|
||||
async def acquire_connection(self) -> aiomysql.Connection:
|
||||
"""
|
||||
Either acquire a connection from the pool, or return the
|
||||
existing connection. Must be followed by a corresponding
|
||||
call to `release_connection`.
|
||||
"""
|
||||
self.connection_holders += 1
|
||||
if self.conn is None:
|
||||
self.conn = await self.pool.acquire()
|
||||
return self.conn
|
||||
|
||||
async def release_connection(self) -> None:
|
||||
self.connection_holders -= 1
|
||||
if self.connection_holders == 0:
|
||||
await self.pool.release(self.conn)
|
||||
self.conn = None
|
||||
|
||||
|
||||
class MysqlTransaction(DatabaseTransaction):
|
||||
def __init__(self, session: MysqlSession):
|
||||
self.session = session
|
||||
self.is_root = False
|
||||
|
||||
async def __aenter__(self) -> None:
|
||||
await self.start()
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: typing.Type[BaseException] = None,
|
||||
exc_value: BaseException = None,
|
||||
traceback: TracebackType = None,
|
||||
) -> None:
|
||||
if exc_type is not None:
|
||||
await self.rollback()
|
||||
else:
|
||||
await self.commit()
|
||||
|
||||
async def start(self) -> None:
|
||||
if self.session.has_root_transaction is False:
|
||||
self.session.has_root_transaction = True
|
||||
self.is_root = True
|
||||
|
||||
self.conn = await self.session.acquire_connection()
|
||||
if self.is_root:
|
||||
await self.conn.begin()
|
||||
else:
|
||||
id = str(uuid.uuid4()).replace("-", "_")
|
||||
self.savepoint_name = f"STARLETTE_SAVEPOINT_{id}"
|
||||
cursor = await self.conn.cursor()
|
||||
try:
|
||||
await cursor.execute(f"SAVEPOINT {self.savepoint_name}")
|
||||
finally:
|
||||
await cursor.close()
|
||||
|
||||
async def commit(self) -> None:
|
||||
if self.is_root: # pragma: no cover
|
||||
# In test cases the root transaction is never committed,
|
||||
# since we *always* wrap the test case up in a transaction
|
||||
# and rollback to a clean state at the end.
|
||||
await self.conn.commit()
|
||||
self.session.has_root_transaction = False
|
||||
else:
|
||||
cursor = await self.conn.cursor()
|
||||
try:
|
||||
await cursor.execute(f"RELEASE SAVEPOINT {self.savepoint_name}")
|
||||
finally:
|
||||
await cursor.close()
|
||||
await self.session.release_connection()
|
||||
|
||||
async def rollback(self) -> None:
|
||||
if self.is_root:
|
||||
await self.conn.rollback()
|
||||
self.session.has_root_transaction = False
|
||||
else:
|
||||
cursor = await self.conn.cursor()
|
||||
try:
|
||||
await cursor.execute(f"ROLLBACK TO SAVEPOINT {self.savepoint_name}")
|
||||
finally:
|
||||
await cursor.close()
|
||||
await self.session.release_connection()
|
|
@ -1,3 +1,4 @@
|
|||
import logging
|
||||
import typing
|
||||
from types import TracebackType
|
||||
|
||||
|
@ -10,14 +11,15 @@ from starlette.database.core import (
|
|||
DatabaseBackend,
|
||||
DatabaseSession,
|
||||
DatabaseTransaction,
|
||||
compile,
|
||||
)
|
||||
from starlette.datastructures import DatabaseURL
|
||||
|
||||
logger = logging.getLogger("starlette.database")
|
||||
|
||||
|
||||
class PostgresBackend(DatabaseBackend):
|
||||
def __init__(self, database_url: typing.Union[str, DatabaseURL]) -> None:
|
||||
self.database_url = database_url
|
||||
self.database_url = DatabaseURL(database_url)
|
||||
self.dialect = self.get_dialect()
|
||||
self.pool = None
|
||||
|
||||
|
@ -53,8 +55,26 @@ class PostgresSession(DatabaseSession):
|
|||
self.conn = None
|
||||
self.connection_holders = 0
|
||||
|
||||
def _compile(self, query: ClauseElement) -> typing.Tuple[str, list]:
|
||||
compiled = query.compile(dialect=self.dialect)
|
||||
compiled_params = sorted(compiled.params.items())
|
||||
|
||||
mapping = {
|
||||
key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1)
|
||||
}
|
||||
compiled_query = compiled.string % mapping
|
||||
|
||||
processors = compiled._bind_processors
|
||||
args = [
|
||||
processors[key](val) if key in processors else val
|
||||
for key, val in compiled_params
|
||||
]
|
||||
|
||||
logger.debug(compiled_query, args)
|
||||
return compiled_query, args
|
||||
|
||||
async def fetchall(self, query: ClauseElement) -> typing.Any:
|
||||
query, args = compile(query, dialect=self.dialect)
|
||||
query, args = self._compile(query)
|
||||
|
||||
conn = await self.acquire_connection()
|
||||
try:
|
||||
|
@ -63,7 +83,7 @@ class PostgresSession(DatabaseSession):
|
|||
await self.release_connection()
|
||||
|
||||
async def fetchone(self, query: ClauseElement) -> typing.Any:
|
||||
query, args = compile(query, dialect=self.dialect)
|
||||
query, args = self._compile(query)
|
||||
|
||||
conn = await self.acquire_connection()
|
||||
try:
|
||||
|
@ -72,7 +92,7 @@ class PostgresSession(DatabaseSession):
|
|||
await self.release_connection()
|
||||
|
||||
async def execute(self, query: ClauseElement) -> None:
|
||||
query, args = compile(query, dialect=self.dialect)
|
||||
query, args = self._compile(query)
|
||||
|
||||
conn = await self.acquire_connection()
|
||||
try:
|
||||
|
@ -88,7 +108,7 @@ class PostgresSession(DatabaseSession):
|
|||
# using the same prepared statement.
|
||||
for item in values:
|
||||
single_query = query.values(item)
|
||||
single_query, args = compile(single_query, dialect=self.dialect)
|
||||
single_query, args = self._compile(single_query)
|
||||
await conn.execute(single_query, *args)
|
||||
finally:
|
||||
await self.release_connection()
|
||||
|
|
|
@ -135,6 +135,9 @@ class URL:
|
|||
|
||||
|
||||
class DatabaseURL(URL):
|
||||
def __init__(self, url: typing.Union[str, URL]):
|
||||
return super().__init__(str(url))
|
||||
|
||||
@property
|
||||
def database(self) -> str:
|
||||
return self.path.lstrip("/")
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
import typing
|
||||
|
||||
import asyncpg
|
||||
|
||||
from starlette.database.core import (
|
||||
DatabaseBackend,
|
||||
DatabaseSession,
|
||||
|
@ -29,12 +27,21 @@ class DatabaseMiddleware:
|
|||
) -> DatabaseBackend:
|
||||
if isinstance(database_url, str):
|
||||
database_url = DatabaseURL(database_url)
|
||||
assert (
|
||||
database_url.dialect == "postgresql"
|
||||
), "Currently only postgresql is supported."
|
||||
from starlette.database.postgres import PostgresBackend
|
||||
assert database_url.dialect in [
|
||||
"postgresql",
|
||||
"mysql",
|
||||
], "Currently only postgresql and mysql are supported."
|
||||
|
||||
return PostgresBackend(database_url)
|
||||
if database_url.dialect == "postgresql":
|
||||
from starlette.database.postgres import PostgresBackend
|
||||
|
||||
return PostgresBackend(database_url)
|
||||
|
||||
else:
|
||||
assert database_url.dialect == "mysql"
|
||||
from starlette.database.mysql import MysqlBackend
|
||||
|
||||
return MysqlBackend(database_url)
|
||||
|
||||
def __call__(self, scope: Scope) -> ASGIInstance:
|
||||
if scope["type"] == "lifespan":
|
||||
|
|
|
@ -5,14 +5,15 @@ import sqlalchemy
|
|||
|
||||
from starlette.applications import Starlette
|
||||
from starlette.database import transaction
|
||||
from starlette.datastructures import CommaSeparatedStrings, DatabaseURL
|
||||
from starlette.middleware.database import DatabaseMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
try:
|
||||
DATABASE_URL = os.environ["STARLETTE_TEST_DATABASE"]
|
||||
DATABASE_URLS = CommaSeparatedStrings(os.environ["STARLETTE_TEST_DATABASES"])
|
||||
except KeyError: # pragma: no cover
|
||||
pytest.skip("DATABASE_URL is not set", allow_module_level=True)
|
||||
pytest.skip("STARLETTE_TEST_DATABASES is not set", allow_module_level=True)
|
||||
|
||||
metadata = sqlalchemy.MetaData()
|
||||
|
||||
|
@ -20,71 +21,87 @@ notes = sqlalchemy.Table(
|
|||
"notes",
|
||||
metadata,
|
||||
sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
|
||||
sqlalchemy.Column("text", sqlalchemy.String),
|
||||
sqlalchemy.Column("text", sqlalchemy.String(length=100)),
|
||||
sqlalchemy.Column("completed", sqlalchemy.Boolean),
|
||||
)
|
||||
|
||||
app = Starlette()
|
||||
app.add_middleware(
|
||||
DatabaseMiddleware, database_url=DATABASE_URL, rollback_on_shutdown=True
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
|
||||
def create_test_database():
|
||||
engine = sqlalchemy.create_engine(DATABASE_URL)
|
||||
metadata.create_all(engine)
|
||||
def create_test_databases():
|
||||
engines = {}
|
||||
for url in DATABASE_URLS:
|
||||
db_url = DatabaseURL(url)
|
||||
if db_url.dialect == "mysql":
|
||||
# Use the 'pymysql' driver for creating the database & tables.
|
||||
url = str(db_url.replace(scheme="mysql+pymysql"))
|
||||
db_name = db_url.database
|
||||
db_url = db_url.replace(scheme="mysql+pymysql", database="")
|
||||
engine = sqlalchemy.create_engine(str(db_url))
|
||||
engine.execute("CREATE DATABASE IF NOT EXISTS " + db_name)
|
||||
|
||||
engines[url] = sqlalchemy.create_engine(url)
|
||||
metadata.create_all(engines[url])
|
||||
|
||||
yield
|
||||
engine.execute("DROP TABLE notes")
|
||||
|
||||
for engine in engines.values():
|
||||
engine.execute("DROP TABLE notes")
|
||||
|
||||
|
||||
@app.route("/notes", methods=["GET"])
|
||||
async def list_notes(request):
|
||||
query = notes.select()
|
||||
results = await request.database.fetchall(query)
|
||||
content = [
|
||||
{"text": result["text"], "completed": result["completed"]} for result in results
|
||||
]
|
||||
return JSONResponse(content)
|
||||
def get_app(database_url):
|
||||
app = Starlette()
|
||||
app.add_middleware(
|
||||
DatabaseMiddleware, database_url=database_url, rollback_on_shutdown=True
|
||||
)
|
||||
|
||||
@app.route("/notes", methods=["GET"])
|
||||
async def list_notes(request):
|
||||
query = notes.select()
|
||||
results = await request.database.fetchall(query)
|
||||
content = [
|
||||
{"text": result["text"], "completed": result["completed"]}
|
||||
for result in results
|
||||
]
|
||||
return JSONResponse(content)
|
||||
|
||||
@app.route("/notes", methods=["POST"])
|
||||
@transaction
|
||||
async def add_note(request):
|
||||
data = await request.json()
|
||||
query = notes.insert().values(text=data["text"], completed=data["completed"])
|
||||
await request.database.execute(query)
|
||||
if "raise_exc" in request.query_params:
|
||||
raise RuntimeError()
|
||||
return JSONResponse({"text": data["text"], "completed": data["completed"]})
|
||||
|
||||
@app.route("/notes/bulk_create", methods=["POST"])
|
||||
async def bulk_create_notes(request):
|
||||
data = await request.json()
|
||||
query = notes.insert()
|
||||
await request.database.executemany(query, data)
|
||||
return JSONResponse({"notes": data})
|
||||
|
||||
@app.route("/notes/{note_id:int}", methods=["GET"])
|
||||
async def read_note(request):
|
||||
note_id = request.path_params["note_id"]
|
||||
query = notes.select().where(notes.c.id == note_id)
|
||||
result = await request.database.fetchone(query)
|
||||
content = {"text": result["text"], "completed": result["completed"]}
|
||||
return JSONResponse(content)
|
||||
|
||||
@app.route("/notes/{note_id:int}/text", methods=["GET"])
|
||||
async def read_note_text(request):
|
||||
note_id = request.path_params["note_id"]
|
||||
query = sqlalchemy.select([notes.c.text]).where(notes.c.id == note_id)
|
||||
text = await request.database.fetchval(query)
|
||||
return JSONResponse(text)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@app.route("/notes", methods=["POST"])
|
||||
@transaction
|
||||
async def add_note(request):
|
||||
data = await request.json()
|
||||
query = notes.insert().values(text=data["text"], completed=data["completed"])
|
||||
await request.database.execute(query)
|
||||
if "raise_exc" in request.query_params:
|
||||
raise RuntimeError()
|
||||
return JSONResponse({"text": data["text"], "completed": data["completed"]})
|
||||
|
||||
|
||||
@app.route("/notes/bulk_create", methods=["POST"])
|
||||
async def bulk_create_notes(request):
|
||||
data = await request.json()
|
||||
query = notes.insert()
|
||||
await request.database.executemany(query, data)
|
||||
return JSONResponse({"notes": data})
|
||||
|
||||
|
||||
@app.route("/notes/{note_id:int}", methods=["GET"])
|
||||
async def read_note(request):
|
||||
note_id = request.path_params["note_id"]
|
||||
query = notes.select().where(notes.c.id == note_id)
|
||||
result = await request.database.fetchone(query)
|
||||
content = {"text": result["text"], "completed": result["completed"]}
|
||||
return JSONResponse(content)
|
||||
|
||||
|
||||
@app.route("/notes/{note_id:int}/text", methods=["GET"])
|
||||
async def read_note_text(request):
|
||||
note_id = request.path_params["note_id"]
|
||||
query = sqlalchemy.select([notes.c.text]).where(notes.c.id == note_id)
|
||||
text = await request.database.fetchval(query)
|
||||
return JSONResponse(text)
|
||||
|
||||
|
||||
def test_database():
|
||||
@pytest.mark.parametrize("database_url", DATABASE_URLS)
|
||||
def test_database(database_url):
|
||||
app = get_app(database_url)
|
||||
with TestClient(app) as client:
|
||||
response = client.post(
|
||||
"/notes", json={"text": "buy the milk", "completed": True}
|
||||
|
@ -119,8 +136,13 @@ def test_database():
|
|||
assert response.json() == "buy the milk"
|
||||
|
||||
|
||||
def test_database_executemany():
|
||||
@pytest.mark.parametrize("database_url", DATABASE_URLS)
|
||||
def test_database_executemany(database_url):
|
||||
app = get_app(database_url)
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/notes")
|
||||
print(response.json())
|
||||
|
||||
data = [
|
||||
{"text": "buy the milk", "completed": True},
|
||||
{"text": "walk the dog", "completed": False},
|
||||
|
@ -129,6 +151,7 @@ def test_database_executemany():
|
|||
assert response.status_code == 200
|
||||
|
||||
response = client.get("/notes")
|
||||
print(response.json())
|
||||
assert response.status_code == 200
|
||||
assert response.json() == [
|
||||
{"text": "buy the milk", "completed": True},
|
||||
|
@ -136,11 +159,12 @@ def test_database_executemany():
|
|||
]
|
||||
|
||||
|
||||
def test_database_isolated_during_test_cases():
|
||||
@pytest.mark.parametrize("database_url", DATABASE_URLS)
|
||||
def test_database_isolated_during_test_cases(database_url):
|
||||
"""
|
||||
Using `TestClient` as a context manager
|
||||
"""
|
||||
|
||||
app = get_app(database_url)
|
||||
with TestClient(app) as client:
|
||||
response = client.post(
|
||||
"/notes", json={"text": "just one note", "completed": True}
|
||||
|
|
Loading…
Reference in New Issue