mirror of https://github.com/encode/starlette.git
Add executemany (#259)
This commit is contained in:
parent
256b6245f7
commit
1e1f3bab46
|
@ -1 +1 @@
|
|||
__version__ = "0.9.7"
|
||||
__version__ = "0.9.8"
|
||||
|
|
|
@ -61,7 +61,10 @@ class DatabaseSession:
|
|||
row = await self.fetchone(query)
|
||||
return row[index]
|
||||
|
||||
async def execute(self, query: ClauseElement) -> typing.Any:
|
||||
async def execute(self, query: ClauseElement) -> None:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
async def executemany(self, query: ClauseElement, values: list) -> None:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
def transaction(self) -> "DatabaseTransaction":
|
||||
|
|
|
@ -71,12 +71,25 @@ class PostgresSession(DatabaseSession):
|
|||
finally:
|
||||
await self.release_connection()
|
||||
|
||||
async def execute(self, query: ClauseElement) -> typing.Any:
|
||||
async def execute(self, query: ClauseElement) -> None:
|
||||
query, args = compile(query, dialect=self.dialect)
|
||||
|
||||
conn = await self.acquire_connection()
|
||||
try:
|
||||
return await conn.execute(query, *args)
|
||||
await conn.execute(query, *args)
|
||||
finally:
|
||||
await self.release_connection()
|
||||
|
||||
async def executemany(self, query: ClauseElement, values: list) -> None:
|
||||
conn = await self.acquire_connection()
|
||||
try:
|
||||
# asyncpg uses prepared statements under the hood, so we just
|
||||
# loop through multiple executes here, which should all end up
|
||||
# using the same prepared statement.
|
||||
for item in values:
|
||||
single_query = query.values(item)
|
||||
single_query, args = compile(single_query, dialect=self.dialect)
|
||||
await conn.execute(single_query, *args)
|
||||
finally:
|
||||
await self.release_connection()
|
||||
|
||||
|
|
|
@ -56,6 +56,14 @@ async def add_note(request):
|
|||
return JSONResponse({"text": data["text"], "completed": data["completed"]})
|
||||
|
||||
|
||||
@app.route("/notes/bulk_create", methods=["POST"])
|
||||
async def bulk_create_notes(request):
|
||||
data = await request.json()
|
||||
query = notes.insert()
|
||||
await request.database.executemany(query, data)
|
||||
return JSONResponse({"notes": data})
|
||||
|
||||
|
||||
@app.route("/notes/{note_id:int}", methods=["GET"])
|
||||
async def read_note(request):
|
||||
note_id = request.path_params["note_id"]
|
||||
|
@ -108,6 +116,23 @@ def test_database():
|
|||
assert response.json() == "buy the milk"
|
||||
|
||||
|
||||
def test_database_executemany():
|
||||
with TestClient(app) as client:
|
||||
data = [
|
||||
{"text": "buy the milk", "completed": True},
|
||||
{"text": "walk the dog", "completed": False},
|
||||
]
|
||||
response = client.post("/notes/bulk_create", json=data)
|
||||
assert response.status_code == 200
|
||||
|
||||
response = client.get("/notes")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == [
|
||||
{"text": "buy the milk", "completed": True},
|
||||
{"text": "walk the dog", "completed": False},
|
||||
]
|
||||
|
||||
|
||||
def test_database_isolated_during_test_cases():
|
||||
"""
|
||||
Using `TestClient` as a context manager
|
||||
|
|
Loading…
Reference in New Issue