diff --git a/starlette/schemas.py b/starlette/schemas.py index c05e34f1..9c855bd3 100644 --- a/starlette/schemas.py +++ b/starlette/schemas.py @@ -1,6 +1,7 @@ import inspect import typing +from starlette.requests import Request from starlette.responses import Response from starlette.routing import BaseRoute, Route @@ -74,6 +75,11 @@ class BaseSchemaGenerator: docstring = func_or_method.__doc__ return yaml.safe_load(docstring) if docstring else {} + def OpenAPIResponse(self, request: Request) -> Response: + routes = request.app.routes + schema = self.get_schema(routes=routes) + return OpenAPIResponse(schema) + class SchemaGenerator(BaseSchemaGenerator): def __init__(self, base_schema: dict) -> None: diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 12501761..3f52e08f 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -3,11 +3,12 @@ from starlette.endpoints import HTTPEndpoint from starlette.schemas import OpenAPIResponse, SchemaGenerator from starlette.testclient import TestClient -app = Starlette() -app.schema_generator = SchemaGenerator( +schemas = SchemaGenerator( {"openapi": "3.0.0", "info": {"title": "Example API", "version": "1.0"}} ) +app = Starlette() + @app.websocket_route("/ws") def ws(session): @@ -64,11 +65,12 @@ class OrganisationsEndpoint(HTTPEndpoint): @app.route("/schema", methods=["GET"], include_in_schema=False) def schema(request): - return OpenAPIResponse(app.schema) + return schemas.OpenAPIResponse(request=request) def test_schema_generation(): - assert app.schema == { + schema = schemas.get_schema(routes=app.routes) + assert schema == { "openapi": "3.0.0", "info": {"title": "Example API", "version": "1.0"}, "paths": { @@ -109,6 +111,11 @@ def test_schema_generation(): } +def test_schema_generation_legacy(): + app.schema_generator = schemas + assert app.schema == schemas.get_schema(routes=app.routes) + + EXPECTED_SCHEMA = """ info: title: Example API