From de010e7796a3ab155c0efbe1a3b995c8cad14870 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 1 Nov 2018 10:40:08 +0000 Subject: [PATCH] Return strings from `url_path_for` (#169) Schema generation & return strings from `url_path_for` --- requirements.txt | 1 + starlette/applications.py | 11 +++- starlette/datastructures.py | 32 ++++++++--- starlette/requests.py | 6 +-- starlette/routing.py | 22 ++++---- starlette/schemas.py | 50 +++++++++++++++++ starlette/websockets.py | 6 +-- tests/test_routing.py | 19 +++++++ tests/test_schemas.py | 103 ++++++++++++++++++++++++++++++++++++ 9 files changed, 225 insertions(+), 25 deletions(-) create mode 100644 starlette/schemas.py create mode 100644 tests/test_schemas.py diff --git a/requirements.txt b/requirements.txt index 78f0f415..e135e521 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ aiofiles graphene itsdangerous python-multipart +pyyaml requests ujson diff --git a/starlette/applications.py b/starlette/applications.py index fde36016..f00505ac 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -1,9 +1,10 @@ import typing -from starlette.datastructures import URL +from starlette.datastructures import URL, URLPath from starlette.exceptions import ExceptionMiddleware from starlette.lifespan import LifespanHandler from starlette.routing import BaseRoute, Router +from starlette.schemas import BaseSchemaGenerator from starlette.types import ASGIApp, ASGIInstance, Scope @@ -13,6 +14,7 @@ class Starlette: self.lifespan_handler = LifespanHandler() self.app = self.router self.exception_middleware = ExceptionMiddleware(self.router, debug=debug) + self.schema_generator = None # type: typing.Optional[BaseSchemaGenerator] @property def routes(self) -> typing.List[BaseRoute]: @@ -26,6 +28,11 @@ class Starlette: def debug(self, value: bool) -> None: self.exception_middleware.debug = value + @property + def schema(self) -> dict: + assert self.schema_generator is not None + return self.schema_generator.get_schema(self.routes) + def on_event(self, event_type: str) -> typing.Callable: return self.lifespan_handler.on_event(event_type) @@ -73,7 +80,7 @@ class Starlette: return decorator - def url_path_for(self, name: str, **path_params: str) -> URL: + def url_path_for(self, name: str, **path_params: str) -> URLPath: return self.router.url_path_for(name, **path_params) def __call__(self, scope: Scope) -> ASGIInstance: diff --git a/starlette/datastructures.py b/starlette/datastructures.py index 3e32d5e1..6b1868cd 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -92,12 +92,6 @@ class URL: kwargs["netloc"] = hostname else: kwargs["netloc"] = "%s:%d" % (hostname, port) - if "secure" in kwargs: - secure = kwargs.pop("secure") - if self.scheme in ("http", "https"): - kwargs["scheme"] = "https" if secure else "http" - elif self.scheme in ("ws", "wss"): - kwargs["scheme"] = "wss" if secure else "ws" components = self.components._replace(**kwargs) return URL(components.geturl()) @@ -105,14 +99,36 @@ class URL: return str(self) == str(other) def __str__(self) -> str: - if self.scheme and not self.netloc: - return str(self.replace(scheme="")) return self._url def __repr__(self) -> str: return "%s(%s)" % (self.__class__.__name__, repr(self._url)) +class URLPath(str): + """ + A URL path string that also holds an associated protocol. + Used by the routing to return `url_path_for` matches. + """ + + def __new__(cls, path: str, protocol: str) -> str: + assert protocol in ("http", "websocket") + return str.__new__(cls, path) # type: ignore + + def __init__(self, path: str, protocol: str) -> None: + self.protocol = protocol + + def make_absolute_url(self, base_url: typing.Union[str, URL]) -> str: + if isinstance(base_url, str): + base_url = URL(base_url) + scheme = { + "http": {True: "https", False: "http"}, + "websocket": {True: "wss", False: "ws"}, + }[self.protocol][base_url.is_secure] + netloc = base_url.netloc + return str(URL(scheme=scheme, netloc=base_url.netloc, path=str(self))) + + class QueryParams(typing.Mapping[str, str]): """ An immutable multidict. diff --git a/starlette/requests.py b/starlette/requests.py index 56c7b946..96450f9a 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -85,10 +85,10 @@ class Request(Mapping): def receive(self) -> Receive: return self._receive - def url_for(self, name: str, **path_params: typing.Any) -> URL: + def url_for(self, name: str, **path_params: typing.Any) -> str: router = self._scope["router"] - url = router.url_path_for(name, **path_params) - return url.replace(secure=self.url.is_secure, netloc=self.url.netloc) + url_path = router.url_path_for(name, **path_params) + return url_path.make_absolute_url(base_url=self.url) async def stream(self) -> typing.AsyncGenerator[bytes, None]: if hasattr(self, "_body"): diff --git a/starlette/routing.py b/starlette/routing.py index 13729351..fb17b776 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -5,7 +5,7 @@ import typing from concurrent.futures import ThreadPoolExecutor from enum import Enum -from starlette.datastructures import URL +from starlette.datastructures import URL, URLPath from starlette.exceptions import HTTPException from starlette.graphql import GraphQLApp from starlette.requests import Request @@ -85,7 +85,7 @@ class BaseRoute: def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: raise NotImplementedError() # pragma: no cover - def url_path_for(self, name: str, **path_params: str) -> URL: + def url_path_for(self, name: str, **path_params: str) -> URLPath: raise NotImplementedError() # pragma: no cover def __call__(self, scope: Scope) -> ASGIInstance: @@ -101,10 +101,12 @@ class Route(BaseRoute): self.name = get_name(endpoint) if inspect.isfunction(endpoint) or inspect.ismethod(endpoint): + # Endpoint is function or method. Treat it as `func(request) -> response`. self.app = request_response(endpoint) if methods is None: methods = ["GET"] else: + # Endpoint is a class. Treat it as ASGI. self.app = endpoint self.methods = methods @@ -127,12 +129,12 @@ class Route(BaseRoute): return Match.FULL, child_scope return Match.NONE, {} - def url_path_for(self, name: str, **path_params: str) -> URL: + def url_path_for(self, name: str, **path_params: str) -> URLPath: if name != self.name or self.param_names != set(path_params.keys()): raise NoMatchFound() path, remaining_params = replace_params(self.path, **path_params) assert not remaining_params - return URL(scheme="http", path=path) + return URLPath(path=path, protocol="http") def __call__(self, scope: Scope) -> ASGIInstance: if self.methods and scope["method"] not in self.methods: @@ -157,8 +159,10 @@ class WebSocketRoute(BaseRoute): self.name = get_name(endpoint) if inspect.isfunction(endpoint) or inspect.ismethod(endpoint): + # Endpoint is function or method. Treat it as `func(websocket)`. self.app = websocket_session(endpoint) else: + # Endpoint is a class. Treat it as ASGI. self.app = endpoint regex = "^" + path + "$" @@ -177,12 +181,12 @@ class WebSocketRoute(BaseRoute): return Match.FULL, child_scope return Match.NONE, {} - def url_path_for(self, name: str, **path_params: str) -> URL: + def url_path_for(self, name: str, **path_params: str) -> URLPath: if name != self.name or self.param_names != set(path_params.keys()): raise NoMatchFound() path, remaining_params = replace_params(self.path, **path_params) assert not remaining_params - return URL(scheme="ws", path=path) + return URLPath(path=path, protocol="websocket") def __call__(self, scope: Scope) -> ASGIInstance: return self.app(scope) @@ -219,12 +223,12 @@ class Mount(BaseRoute): return Match.FULL, child_scope return Match.NONE, {} - def url_path_for(self, name: str, **path_params: str) -> URL: + def url_path_for(self, name: str, **path_params: str) -> URLPath: path, remaining_params = replace_params(self.path, **path_params) for route in self.routes or []: try: url = route.url_path_for(name, **remaining_params) - return URL(scheme=url.scheme, path=path + url.path) + return URLPath(path=path + str(url), protocol=url.protocol) except NoMatchFound as exc: pass raise NoMatchFound() @@ -292,7 +296,7 @@ class Router: raise HTTPException(status_code=404) return PlainTextResponse("Not Found", status_code=404) - def url_path_for(self, name: str, **path_params: str) -> URL: + def url_path_for(self, name: str, **path_params: str) -> URLPath: for route in self.routes: try: return route.url_path_for(name, **path_params) diff --git a/starlette/schemas.py b/starlette/schemas.py new file mode 100644 index 00000000..1e3a1e8d --- /dev/null +++ b/starlette/schemas.py @@ -0,0 +1,50 @@ +import inspect +import typing + +from starlette.routing import BaseRoute, Route + +try: + import yaml +except ImportError: # pragma: nocover + yaml = None # type: ignore + + +class BaseSchemaGenerator: + def get_schema(self, routes: typing.List[BaseRoute]) -> dict: + raise NotImplementedError() # pragma: no cover + + +class SchemaGenerator(BaseSchemaGenerator): + def __init__(self, base_schema: dict) -> None: + assert yaml is not None, "`pyyaml` must be installed to use SchemaGenerator." + self.base_schema = base_schema + + def get_schema(self, routes: typing.List[BaseRoute]) -> dict: + paths = {} # type: dict + + for route in routes: + if not isinstance(route, Route): + continue + + if inspect.isfunction(route.endpoint) or inspect.ismethod(route.endpoint): + docstring = route.endpoint.__doc__ + for method in route.methods or ["GET"]: + if method == "HEAD": + continue + if route.path not in paths: + paths[route.path] = {} + data = yaml.safe_load(docstring) if docstring else {} + paths[route.path][method.lower()] = data + else: + for method in ["get", "post", "put", "patch", "delete", "options"]: + if not hasattr(route.endpoint, method): + continue + docstring = getattr(route.endpoint, method).__doc__ + if route.path not in paths: + paths[route.path] = {} + data = yaml.safe_load(docstring) if docstring else {} + paths[route.path][method] = data + + schema = dict(self.base_schema) + schema["paths"] = paths + return schema diff --git a/starlette/websockets.py b/starlette/websockets.py index a33b38ad..f48b060e 100644 --- a/starlette/websockets.py +++ b/starlette/websockets.py @@ -59,10 +59,10 @@ class WebSocket(Mapping): def path_params(self) -> dict: return self._scope.get("path_params", {}) - def url_for(self, name: str, **path_params: typing.Any) -> URL: + def url_for(self, name: str, **path_params: typing.Any) -> str: router = self._scope["router"] - url = router.url_path_for(name, **path_params) - return url.replace(secure=self.url.is_secure, netloc=self.url.netloc) + url_path = router.url_path_for(name, **path_params) + return url_path.make_absolute_url(base_url=self.url) async def receive(self) -> Message: """ diff --git a/tests/test_routing.py b/tests/test_routing.py index 43ec8308..6988b9f1 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -99,6 +99,25 @@ def test_url_path_for(): assert app.url_path_for("broken") +def test_url_for(): + assert ( + app.url_path_for("homepage").make_absolute_url(base_url="https://example.org") + == "https://example.org/" + ) + assert ( + app.url_path_for("user", username="tomchristie").make_absolute_url( + base_url="https://example.org" + ) + == "https://example.org/users/tomchristie" + ) + assert ( + app.url_path_for("websocket_endpoint").make_absolute_url( + base_url="https://example.org" + ) + == "wss://example.org/ws" + ) + + def test_router_add_route(): response = client.get("/func") assert response.status_code == 200 diff --git a/tests/test_schemas.py b/tests/test_schemas.py new file mode 100644 index 00000000..caedbf0d --- /dev/null +++ b/tests/test_schemas.py @@ -0,0 +1,103 @@ +from starlette.applications import Starlette +from starlette.endpoints import HTTPEndpoint +from starlette.schemas import SchemaGenerator + +app = Starlette() +app.schema_generator = SchemaGenerator( + {"openapi": "3.0.0", "info": {"title": "Example API", "version": "1.0"}} +) + + +@app.websocket_route("/ws") +def ws(session): + """ws""" + pass # pragma: no cover + + +@app.route("/users", methods=["GET", "HEAD"]) +def list_users(request): + """ + responses: + 200: + description: A list of users. + examples: + [{"username": "tom"}, {"username": "lucy"}] + """ + pass # pragma: no cover + + +@app.route("/users", methods=["POST"]) +def create_user(request): + """ + responses: + 200: + description: A user. + examples: + {"username": "tom"} + """ + pass # pragma: no cover + + +@app.route("/orgs") +class OrganisationsEndpoint(HTTPEndpoint): + def get(self, request): + """ + responses: + 200: + description: A list of organisations. + examples: + [{"name": "Foo Corp."}, {"name": "Acme Ltd."}] + """ + pass # pragma: no cover + + def post(self, request): + """ + responses: + 200: + description: An organisation. + examples: + {"name": "Foo Corp."} + """ + pass # pragma: no cover + + +def test_schema_generation(): + assert app.schema == { + "openapi": "3.0.0", + "info": {"title": "Example API", "version": "1.0"}, + "paths": { + "/orgs": { + "get": { + "responses": { + 200: { + "description": "A list of " "organisations.", + "examples": [{"name": "Foo Corp."}, {"name": "Acme Ltd."}], + } + } + }, + "post": { + "responses": { + 200: { + "description": "An organisation.", + "examples": {"name": "Foo Corp."}, + } + } + }, + }, + "/users": { + "get": { + "responses": { + 200: { + "description": "A list of users.", + "examples": [{"username": "tom"}, {"username": "lucy"}], + } + } + }, + "post": { + "responses": { + 200: {"description": "A user.", "examples": {"username": "tom"}} + } + }, + }, + }, + }