From 590c77e03df0be328a3805d4c06d5151ae23a9a3 Mon Sep 17 00:00:00 2001 From: Jose Eduardo Date: Tue, 19 Feb 2019 12:49:30 +0000 Subject: [PATCH] Improvements to schema generation (#336) * Include mounted paths in schemas (part of #172) * Remove unnecessary indirection * Refactor: cleaner interface, return a dict always --- starlette/schemas.py | 48 +++++++++++++++++++++++++++----- tests/test_schemas.py | 65 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+), 7 deletions(-) diff --git a/starlette/schemas.py b/starlette/schemas.py index 9c855bd3..6d4b119f 100644 --- a/starlette/schemas.py +++ b/starlette/schemas.py @@ -3,7 +3,7 @@ import typing from starlette.requests import Request from starlette.responses import Response -from starlette.routing import BaseRoute, Route +from starlette.routing import BaseRoute, Mount, Route try: import yaml @@ -48,10 +48,22 @@ class BaseSchemaGenerator: endpoints_info: list = [] for route in routes: - if not isinstance(route, Route) or not route.include_in_schema: + if isinstance(route, Mount): + routes = route.routes or [] + sub_endpoints = [ + EndpointInfo( + path="".join((route.path, sub_endpoint.path)), + http_method=sub_endpoint.http_method, + func=sub_endpoint.func, + ) + for sub_endpoint in self.get_endpoints(routes) + ] + endpoints_info.extend(sub_endpoints) + + elif not isinstance(route, Route) or not route.include_in_schema: continue - if inspect.isfunction(route.endpoint) or inspect.ismethod(route.endpoint): + elif inspect.isfunction(route.endpoint) or inspect.ismethod(route.endpoint): for method in route.methods or ["GET"]: if method == "HEAD": continue @@ -66,6 +78,7 @@ class BaseSchemaGenerator: endpoints_info.append( EndpointInfo(route.path, method.lower(), func) ) + return endpoints_info def parse_docstring(self, func_or_method: typing.Callable) -> dict: @@ -73,7 +86,22 @@ class BaseSchemaGenerator: Given a function, parse the docstring as YAML and return a dictionary of info. """ docstring = func_or_method.__doc__ - return yaml.safe_load(docstring) if docstring else {} + if not docstring: + return {} + + # We support having regular docstrings before the schema + # definition. Here we return just the schema part from + # the docstring. + docstring = docstring.split("---")[-1] + + parsed = yaml.safe_load(docstring) + + if not isinstance(parsed, dict): + # A regular docstring (not yaml formatted) can return + # a simple string here, which wouldn't follow the schema. + return {} + + return parsed def OpenAPIResponse(self, request: Request) -> Response: routes = request.app.routes @@ -91,9 +119,15 @@ class SchemaGenerator(BaseSchemaGenerator): endpoints_info = self.get_endpoints(routes) for endpoint in endpoints_info: + + parsed = self.parse_docstring(endpoint.func) + + if not parsed: + continue + if endpoint.path not in schema["paths"]: schema["paths"][endpoint.path] = {} - schema["paths"][endpoint.path][endpoint.http_method] = self.parse_docstring( - endpoint.func - ) + + schema["paths"][endpoint.path][endpoint.http_method] = parsed + return schema diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 3f52e08f..ae06003d 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -10,6 +10,10 @@ schemas = SchemaGenerator( app = Starlette() +subapp = Starlette() +app.mount("/subapp", subapp) + + @app.websocket_route("/ws") def ws(session): """ws""" @@ -63,6 +67,43 @@ class OrganisationsEndpoint(HTTPEndpoint): pass # pragma: no cover +@app.route("/regular-docstring-and-schema") +def regular_docstring_and_schema(request): + """ + This a regular docstring example (not included in schema) + + --- + + responses: + 200: + description: This is included in the schema. + """ + pass # pragma: no cover + + +@app.route("/regular-docstring") +def regular_docstring(request): + """ + This a regular docstring example (not included in schema) + """ + pass # pragma: no cover + + +@app.route("/no-docstring") +def no_docstring(request): + pass # pragma: no cover + + +@subapp.route("/subapp-endpoint") +def subapp_endpoint(request): + """ + responses: + 200: + description: This endpoint is part of a subapp. + """ + pass # pragma: no cover + + @app.route("/schema", methods=["GET"], include_in_schema=False) def schema(request): return schemas.OpenAPIResponse(request=request) @@ -92,6 +133,20 @@ def test_schema_generation(): } }, }, + "/regular-docstring-and-schema": { + "get": { + "responses": { + 200: {"description": "This is included in the schema."} + } + } + }, + "/subapp/subapp-endpoint": { + "get": { + "responses": { + 200: {"description": "This endpoint is part of a subapp."} + } + } + }, "/users": { "get": { "responses": { @@ -136,6 +191,16 @@ paths: description: An organisation. examples: name: Foo Corp. + /regular-docstring-and-schema: + get: + responses: + 200: + description: This is included in the schema. + /subapp/subapp-endpoint: + get: + responses: + 200: + description: This endpoint is part of a subapp. /users: get: responses: