diff --git a/starlette/graphql.py b/starlette/graphql.py index df9aaf4d..f01ac819 100644 --- a/starlette/graphql.py +++ b/starlette/graphql.py @@ -73,7 +73,7 @@ class GraphQLApp: status_code=status.HTTP_400_BAD_REQUEST, ) - result = await self.execute(query, variables) + result = await self.execute(request, query, variables) error_data = ( [format_graphql_error(err) for err in result.errors] if result.errors @@ -85,7 +85,11 @@ class GraphQLApp: ) return JSONResponse(response_data, status_code=status_code) - async def execute(self, query, variables=None, operation_name=None): # type: ignore + async def execute( # type: ignore + self, request, query, variables=None, operation_name=None + ): + context = dict(request=request) + if self.is_async: return await self.schema.execute( query, @@ -93,6 +97,7 @@ class GraphQLApp: operation_name=operation_name, executor=self.executor, return_promise=True, + context=context, ) else: return await run_in_threadpool( @@ -100,6 +105,7 @@ class GraphQLApp: query, variables=variables, operation_name=operation_name, + context=context, ) async def handle_graphiql(self, request: Request) -> Response: diff --git a/tests/test_graphql.py b/tests/test_graphql.py index 676c8be8..f4345f60 100644 --- a/tests/test_graphql.py +++ b/tests/test_graphql.py @@ -2,16 +2,35 @@ import graphene from graphql.execution.executors.asyncio import AsyncioExecutor from starlette.applications import Starlette +from starlette.datastructures import Headers from starlette.graphql import GraphQLApp from starlette.testclient import TestClient +class FakeAuthMiddleware: + def __init__(self, app) -> None: + self.app = app + + def __call__(self, scope): + headers = Headers(scope=scope) + scope["user"] = "Jane" if headers.get("Authorization") == "Bearer 123" else None + return self.app(scope) + + class Query(graphene.ObjectType): hello = graphene.String(name=graphene.String(default_value="stranger")) + whoami = graphene.String() def resolve_hello(self, info, name): return "Hello " + name + def resolve_whoami(self, info): + return ( + "a mystery" + if info.context["request"]["user"] is None + else info.context["request"]["user"] + ) + schema = graphene.Schema(query=Query) app = GraphQLApp(schema=schema) @@ -91,6 +110,18 @@ def test_add_graphql_route(): assert response.json() == {"data": {"hello": "Hello stranger"}, "errors": None} +def test_graphql_context(): + app = Starlette() + app.add_middleware(FakeAuthMiddleware) + app.add_route("/", GraphQLApp(schema=schema)) + client = TestClient(app) + response = client.post( + "/", json={"query": "{ whoami }"}, headers={"Authorization": "Bearer 123"} + ) + assert response.status_code == 200 + assert response.json() == {"data": {"whoami": "Jane"}, "errors": None} + + class ASyncQuery(graphene.ObjectType): hello = graphene.String(name=graphene.String(default_value="stranger"))