starlette/tests/test_database.py

187 lines
6.3 KiB
Python
Raw Normal View History

import os
import pytest
import sqlalchemy
from starlette.applications import Starlette
from starlette.database import transaction
from starlette.datastructures import CommaSeparatedStrings, DatabaseURL
from starlette.middleware.database import DatabaseMiddleware
from starlette.responses import JSONResponse
from starlette.testclient import TestClient
# The list of databases to test against is read from the environment and
# parsed with `CommaSeparatedStrings`; when the variable is missing the whole
# module is skipped.
try:
    _configured_urls = os.environ["STARLETTE_TEST_DATABASES"]
except KeyError:  # pragma: no cover
    pytest.skip("STARLETTE_TEST_DATABASES is not set", allow_module_level=True)
else:
    DATABASE_URLS = CommaSeparatedStrings(_configured_urls)
# Schema used by the tests in this module: a single `notes` table registered
# on a module-level `MetaData` object.
metadata = sqlalchemy.MetaData()

notes = sqlalchemy.Table(
    "notes",
    metadata,
    sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
    sqlalchemy.Column("text", sqlalchemy.String(100)),
    sqlalchemy.Column("completed", sqlalchemy.Boolean),
)
@pytest.fixture(autouse=True, scope="module")
def create_test_databases():
    """
    Create the tables from `metadata` in every URL listed in `DATABASE_URLS`
    before the tests in this module run, and drop the `notes` table again
    afterwards.

    For URLs whose dialect is "mysql", the database itself is created first
    (if it does not already exist) and the 'pymysql' driver is used for the
    setup work.

    Every engine created here is disposed once it is no longer needed, so
    that no pooled connections are left open after the module finishes.
    """
    engines = {}
    for url in DATABASE_URLS:
        db_url = DatabaseURL(url)
        if db_url.dialect == "mysql":
            # Use the 'pymysql' driver for creating the database & tables.
            url = str(db_url.replace(scheme="mysql+pymysql"))
            db_name = db_url.database
            # Connect with an empty database name, because the target
            # database may not exist yet.
            server_url = db_url.replace(scheme="mysql+pymysql", database="")
            server_engine = sqlalchemy.create_engine(str(server_url))
            try:
                server_engine.execute("CREATE DATABASE IF NOT EXISTS " + db_name)
            finally:
                # This engine exists only for the statement above; release
                # its connection pool straight away.
                server_engine.dispose()

        engines[url] = sqlalchemy.create_engine(url)
        metadata.create_all(engines[url])

    yield

    for engine in engines.values():
        try:
            engine.execute("DROP TABLE notes")
        finally:
            # Close pooled connections even if the DROP fails.
            engine.dispose()
def get_app(database_url):
    """
    Build a Starlette application with `DatabaseMiddleware` attached to
    `database_url` (installed with `rollback_on_shutdown=True`) and a small
    set of `/notes` endpoints used by the tests.
    """
    app = Starlette()
    app.add_middleware(
        DatabaseMiddleware, database_url=database_url, rollback_on_shutdown=True
    )

    @app.route("/notes", methods=["GET"])
    async def list_notes(request):
        # Return the text/completed pair of every row in the table.
        rows = await request.database.fetchall(notes.select())
        return JSONResponse(
            [{"text": row["text"], "completed": row["completed"]} for row in rows]
        )

    @app.route("/notes", methods=["POST"])
    @transaction
    async def add_note(request):
        payload = await request.json()
        text, completed = payload["text"], payload["completed"]
        insert_stmt = notes.insert().values(text=text, completed=completed)
        await request.database.execute(insert_stmt)
        # Lets a test trigger a failure after the insert has been executed.
        if "raise_exc" in request.query_params:
            raise RuntimeError()
        return JSONResponse({"text": text, "completed": completed})

    @app.route("/notes/bulk_create", methods=["POST"])
    async def bulk_create_notes(request):
        payload = await request.json()
        await request.database.executemany(notes.insert(), payload)
        return JSONResponse({"notes": payload})

    @app.route("/notes/{note_id:int}", methods=["GET"])
    async def read_note(request):
        wanted_id = request.path_params["note_id"]
        row = await request.database.fetchone(
            notes.select().where(notes.c.id == wanted_id)
        )
        return JSONResponse({"text": row["text"], "completed": row["completed"]})

    @app.route("/notes/{note_id:int}/text", methods=["GET"])
    async def read_note_text(request):
        wanted_id = request.path_params["note_id"]
        text_query = sqlalchemy.select([notes.c.text]).where(notes.c.id == wanted_id)
        note_text = await request.database.fetchval(text_query)
        return JSONResponse(note_text)

    return app
@pytest.mark.parametrize("database_url", DATABASE_URLS)
def test_database(database_url):
    """
    Post two notes (plus one whose request raises), then check the list,
    single-note, and single-value endpoints.
    """
    milk_note = {"text": "buy the milk", "completed": True}
    dog_note = {"text": "walk the dog", "completed": False}
    failing_note = {"text": "you wont see me", "completed": False}

    app = get_app(database_url)

    with TestClient(app) as client:
        created = client.post("/notes", json=milk_note)
        assert created.status_code == 200

        # The handler raises after executing its insert; this note is
        # expected to be absent from the listing asserted further down.
        with pytest.raises(RuntimeError):
            client.post("/notes", json=failing_note, params={"raise_exc": "true"})

        created = client.post("/notes", json=dog_note)
        assert created.status_code == 200

        listing = client.get("/notes")
        assert listing.status_code == 200
        assert listing.json() == [
            {"text": "buy the milk", "completed": True},
            {"text": "walk the dog", "completed": False},
        ]

        single = client.get("/notes/1")
        assert single.status_code == 200
        assert single.json() == {"text": "buy the milk", "completed": True}

        text_only = client.get("/notes/1/text")
        assert text_only.status_code == 200
        assert text_only.json() == "buy the milk"
@pytest.mark.parametrize("database_url", DATABASE_URLS)
def test_database_executemany(database_url):
    """
    Bulk-insert two notes through `/notes/bulk_create` and check that the
    listing afterwards contains exactly those two notes.
    """
    app = get_app(database_url)

    with TestClient(app) as client:
        # Precondition: the table starts out empty. The final assertion
        # below can only hold in that case, so state it explicitly rather
        # than printing the listing for manual inspection.
        response = client.get("/notes")
        assert response.status_code == 200
        assert response.json() == []

        data = [
            {"text": "buy the milk", "completed": True},
            {"text": "walk the dog", "completed": False},
        ]
        response = client.post("/notes/bulk_create", json=data)
        assert response.status_code == 200

        response = client.get("/notes")
        assert response.status_code == 200
        assert response.json() == [
            {"text": "buy the milk", "completed": True},
            {"text": "walk the dog", "completed": False},
        ]
@pytest.mark.parametrize("database_url", DATABASE_URLS)
def test_database_isolated_during_test_cases(database_url):
    """
    Run the same insert-then-list sequence in two successive `TestClient`
    context managers on one app.

    Each pass expects the listing to contain only the single note it just
    created, so nothing written during the first pass may still be visible
    in the second.
    """
    app = get_app(database_url)
    expected_listing = [{"text": "just one note", "completed": True}]

    for _ in range(2):
        with TestClient(app) as client:
            response = client.post(
                "/notes", json={"text": "just one note", "completed": True}
            )
            assert response.status_code == 200

            response = client.get("/notes")
            assert response.status_code == 200
            assert response.json() == expected_listing