diff --git a/starlette/__init__.py b/starlette/__init__.py index f5b77301..a25765c3 100644 --- a/starlette/__init__.py +++ b/starlette/__init__.py @@ -1 +1 @@ -__version__ = "0.9.7" +__version__ = "0.9.8" diff --git a/starlette/database/core.py b/starlette/database/core.py index f2d01418..bf53bc69 100644 --- a/starlette/database/core.py +++ b/starlette/database/core.py @@ -61,7 +61,10 @@ class DatabaseSession: row = await self.fetchone(query) return row[index] - async def execute(self, query: ClauseElement) -> typing.Any: + async def execute(self, query: ClauseElement) -> None: + raise NotImplementedError() # pragma: no cover + + async def executemany(self, query: ClauseElement, values: list) -> None: raise NotImplementedError() # pragma: no cover def transaction(self) -> "DatabaseTransaction": diff --git a/starlette/database/postgres.py b/starlette/database/postgres.py index 63b141e8..e2e0530b 100644 --- a/starlette/database/postgres.py +++ b/starlette/database/postgres.py @@ -71,12 +71,25 @@ class PostgresSession(DatabaseSession): finally: await self.release_connection() - async def execute(self, query: ClauseElement) -> typing.Any: + async def execute(self, query: ClauseElement) -> None: query, args = compile(query, dialect=self.dialect) conn = await self.acquire_connection() try: - return await conn.execute(query, *args) + await conn.execute(query, *args) + finally: + await self.release_connection() + + async def executemany(self, query: ClauseElement, values: list) -> None: + conn = await self.acquire_connection() + try: + # asyncpg uses prepared statements under the hood, so we just + # loop through multiple executes here, which should all end up + # using the same prepared statement. + for item in values: + single_query = query.values(item) + single_query, args = compile(single_query, dialect=self.dialect) + await conn.execute(single_query, *args) finally: await self.release_connection() diff --git a/tests/test_database.py b/tests/test_database.py index 6bdd9cab..803772a0 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -56,6 +56,14 @@ async def add_note(request): return JSONResponse({"text": data["text"], "completed": data["completed"]}) +@app.route("/notes/bulk_create", methods=["POST"]) +async def bulk_create_notes(request): + data = await request.json() + query = notes.insert() + await request.database.executemany(query, data) + return JSONResponse({"notes": data}) + + @app.route("/notes/{note_id:int}", methods=["GET"]) async def read_note(request): note_id = request.path_params["note_id"] @@ -108,6 +116,23 @@ def test_database(): assert response.json() == "buy the milk" +def test_database_executemany(): + with TestClient(app) as client: + data = [ + {"text": "buy the milk", "completed": True}, + {"text": "walk the dog", "completed": False}, + ] + response = client.post("/notes/bulk_create", json=data) + assert response.status_code == 200 + + response = client.get("/notes") + assert response.status_code == 200 + assert response.json() == [ + {"text": "buy the milk", "completed": True}, + {"text": "walk the dog", "completed": False}, + ] + + def test_database_isolated_during_test_cases(): """ Using `TestClient` as a context manager