# starlette/tests/test_database.py
#
# 135 lines
# 4.1 KiB
# Python

import os
import pytest
import sqlalchemy
from starlette.applications import Starlette
from starlette.database import transaction
from starlette.middleware.database import DatabaseMiddleware
from starlette.responses import JSONResponse
from starlette.testclient import TestClient
# Connection URL of the database the tests run against. Indexing os.environ
# directly means a missing STARLETTE_TEST_DATABASE variable raises KeyError
# as soon as this module is imported.
DATABASE_URL = os.environ["STARLETTE_TEST_DATABASE"]

# Schema registry; the `create_test_database` fixture calls
# `metadata.create_all(...)` to create every table registered here.
metadata = sqlalchemy.MetaData()

# The only table used by these tests: an integer primary key plus the two
# fields ("text", "completed") that the routes read and write.
notes = sqlalchemy.Table(
    "notes",
    metadata,
    sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
    sqlalchemy.Column("text", sqlalchemy.String),
    sqlalchemy.Column("completed", sqlalchemy.Boolean),
)

# Application under test. The middleware is what the handlers reach through
# `request.database`.
# NOTE(review): `rollback_on_shutdown=True` presumably discards everything
# written while the app was running once it shuts down -- that is what
# `test_database_isolated_during_test_cases` relies on; confirm against
# DatabaseMiddleware.
app = Starlette()
app.add_middleware(
    DatabaseMiddleware, database_url=DATABASE_URL, rollback_on_shutdown=True
)
@pytest.fixture(autouse=True, scope="module")
def create_test_database():
    """Create the tables registered on ``metadata``, then drop them again.

    The fixture is ``autouse`` with ``scope="module"``, so it runs once for
    this test module without being requested: the code before ``yield`` is
    the setup, the code after it is the teardown.
    """
    engine = sqlalchemy.create_engine(DATABASE_URL)
    metadata.create_all(engine)
    yield
    # Tear down through the same metadata that created the schema instead of
    # a hand-written "DROP TABLE notes": it stays correct if more tables are
    # registered later and does not depend on executing raw SQL strings on
    # the engine.
    metadata.drop_all(engine)
    # Release the pooled connections held by this throwaway engine.
    engine.dispose()
@app.route("/notes", methods=["GET"])
async def list_notes(request):
    """Respond with a JSON array holding the text/completed pair of each note.

    Rows come from an unfiltered ``notes.select()``; the ``id`` column is
    not included in the response.
    """
    rows = await request.database.fetchall(notes.select())
    exposed = ("text", "completed")
    payload = [{field: row[field] for field in exposed} for row in rows]
    return JSONResponse(payload)
@app.route("/notes", methods=["POST"])
@transaction
async def add_note(request):
    """Insert a note taken from the JSON request body and echo it back.

    The body must contain the keys ``"text"`` and ``"completed"``. When the
    query string contains ``raise_exc`` the handler raises ``RuntimeError``
    *after* the insert has been executed.
    NOTE(review): with ``@transaction`` this presumably rolls the insert
    back -- ``test_database`` expects that note not to be listed; confirm
    against ``starlette.database.transaction``.
    """
    body = await request.json()
    note = {"text": body["text"], "completed": body["completed"]}

    await request.database.execute(notes.insert().values(**note))

    # Deliberate failure switch used by the tests.
    if "raise_exc" in request.query_params:
        raise RuntimeError()

    return JSONResponse(note)
@app.route("/notes/{note_id:int}", methods=["GET"])
async def read_note(request):
    """Respond with the text/completed pair of the note whose id is in the path.

    NOTE(review): there is no handling for an id that matches no row; if
    ``fetchone`` yields nothing in that case, the subscripting below fails.
    """
    wanted = notes.c.id == request.path_params["note_id"]
    row = await request.database.fetchone(notes.select().where(wanted))
    return JSONResponse({"text": row["text"], "completed": row["completed"]})
@app.route("/notes/{note_id:int}/text", methods=["GET"])
async def read_note_text(request):
    """Respond with just the ``text`` column of the requested note, as JSON.

    Only ``notes.c.text`` is selected and the value is fetched with
    ``fetchval`` rather than as a full row.
    """
    wanted = notes.c.id == request.path_params["note_id"]
    text_only = sqlalchemy.select([notes.c.text]).where(wanted)
    return JSONResponse(await request.database.fetchval(text_only))
def test_database():
    """Exercise create, failing create, list, read and read-text in one session.

    Two notes are stored successfully; a third POST is made with
    ``raise_exc`` so the handler raises after inserting, and that note must
    not show up in the listing.
    """
    milk = {"text": "buy the milk", "completed": True}
    dog = {"text": "walk the dog", "completed": False}

    with TestClient(app) as client:
        created = client.post("/notes", json=milk)
        assert created.status_code == 200

        # The handler's RuntimeError propagates out of the test client.
        with pytest.raises(RuntimeError):
            client.post(
                "/notes",
                json={"text": "you wont see me", "completed": False},
                params={"raise_exc": "true"},
            )

        created = client.post("/notes", json=dog)
        assert created.status_code == 200

        # Only the two successful inserts are visible.
        listing = client.get("/notes")
        assert listing.status_code == 200
        assert listing.json() == [milk, dog]

        detail = client.get("/notes/1")
        assert detail.status_code == 200
        assert detail.json() == milk

        text_only = client.get("/notes/1/text")
        assert text_only.status_code == 200
        assert text_only.json() == milk["text"]
def test_database_isolated_during_test_cases():
    """
    Using `TestClient` as a context manager, each session starts from an
    empty `notes` table.

    The same note is posted in two consecutive client sessions; in both,
    the listing must contain exactly that one note, i.e. nothing written in
    the first session is visible in the second.
    """
    note = {"text": "just one note", "completed": True}

    # Two back-to-back sessions, each entering and leaving its own
    # `TestClient` context before the next one starts.
    for _ in range(2):
        with TestClient(app) as client:
            created = client.post("/notes", json=note)
            assert created.status_code == 200

            listing = client.get("/notes")
            assert listing.status_code == 200
            assert listing.json() == [note]