Return strings from `url_path_for` (#169)

Schema generation & return strings from `url_path_for`
This commit is contained in:
Tom Christie 2018-11-01 10:40:08 +00:00 committed by GitHub
parent 5488af66d7
commit de010e7796
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 225 additions and 25 deletions

View File

@ -3,6 +3,7 @@ aiofiles
graphene
itsdangerous
python-multipart
pyyaml
requests
ujson

View File

@ -1,9 +1,10 @@
import typing
from starlette.datastructures import URL
from starlette.datastructures import URL, URLPath
from starlette.exceptions import ExceptionMiddleware
from starlette.lifespan import LifespanHandler
from starlette.routing import BaseRoute, Router
from starlette.schemas import BaseSchemaGenerator
from starlette.types import ASGIApp, ASGIInstance, Scope
@ -13,6 +14,7 @@ class Starlette:
self.lifespan_handler = LifespanHandler()
self.app = self.router
self.exception_middleware = ExceptionMiddleware(self.router, debug=debug)
self.schema_generator = None # type: typing.Optional[BaseSchemaGenerator]
@property
def routes(self) -> typing.List[BaseRoute]:
@ -26,6 +28,11 @@ class Starlette:
def debug(self, value: bool) -> None:
self.exception_middleware.debug = value
@property
def schema(self) -> dict:
assert self.schema_generator is not None
return self.schema_generator.get_schema(self.routes)
def on_event(self, event_type: str) -> typing.Callable:
return self.lifespan_handler.on_event(event_type)
@ -73,7 +80,7 @@ class Starlette:
return decorator
def url_path_for(self, name: str, **path_params: str) -> URL:
def url_path_for(self, name: str, **path_params: str) -> URLPath:
return self.router.url_path_for(name, **path_params)
def __call__(self, scope: Scope) -> ASGIInstance:

View File

@ -92,12 +92,6 @@ class URL:
kwargs["netloc"] = hostname
else:
kwargs["netloc"] = "%s:%d" % (hostname, port)
if "secure" in kwargs:
secure = kwargs.pop("secure")
if self.scheme in ("http", "https"):
kwargs["scheme"] = "https" if secure else "http"
elif self.scheme in ("ws", "wss"):
kwargs["scheme"] = "wss" if secure else "ws"
components = self.components._replace(**kwargs)
return URL(components.geturl())
@ -105,14 +99,36 @@ class URL:
return str(self) == str(other)
def __str__(self) -> str:
if self.scheme and not self.netloc:
return str(self.replace(scheme=""))
return self._url
def __repr__(self) -> str:
return "%s(%s)" % (self.__class__.__name__, repr(self._url))
class URLPath(str):
"""
A URL path string that also holds an associated protocol.
Used by the routing to return `url_path_for` matches.
"""
def __new__(cls, path: str, protocol: str) -> str:
assert protocol in ("http", "websocket")
return str.__new__(cls, path) # type: ignore
def __init__(self, path: str, protocol: str) -> None:
self.protocol = protocol
def make_absolute_url(self, base_url: typing.Union[str, URL]) -> str:
if isinstance(base_url, str):
base_url = URL(base_url)
scheme = {
"http": {True: "https", False: "http"},
"websocket": {True: "wss", False: "ws"},
}[self.protocol][base_url.is_secure]
netloc = base_url.netloc
return str(URL(scheme=scheme, netloc=base_url.netloc, path=str(self)))
class QueryParams(typing.Mapping[str, str]):
"""
An immutable multidict.

View File

@ -85,10 +85,10 @@ class Request(Mapping):
def receive(self) -> Receive:
return self._receive
def url_for(self, name: str, **path_params: typing.Any) -> URL:
def url_for(self, name: str, **path_params: typing.Any) -> str:
router = self._scope["router"]
url = router.url_path_for(name, **path_params)
return url.replace(secure=self.url.is_secure, netloc=self.url.netloc)
url_path = router.url_path_for(name, **path_params)
return url_path.make_absolute_url(base_url=self.url)
async def stream(self) -> typing.AsyncGenerator[bytes, None]:
if hasattr(self, "_body"):

View File

@ -5,7 +5,7 @@ import typing
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from starlette.datastructures import URL
from starlette.datastructures import URL, URLPath
from starlette.exceptions import HTTPException
from starlette.graphql import GraphQLApp
from starlette.requests import Request
@ -85,7 +85,7 @@ class BaseRoute:
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
raise NotImplementedError() # pragma: no cover
def url_path_for(self, name: str, **path_params: str) -> URL:
def url_path_for(self, name: str, **path_params: str) -> URLPath:
raise NotImplementedError() # pragma: no cover
def __call__(self, scope: Scope) -> ASGIInstance:
@ -101,10 +101,12 @@ class Route(BaseRoute):
self.name = get_name(endpoint)
if inspect.isfunction(endpoint) or inspect.ismethod(endpoint):
# Endpoint is function or method. Treat it as `func(request) -> response`.
self.app = request_response(endpoint)
if methods is None:
methods = ["GET"]
else:
# Endpoint is a class. Treat it as ASGI.
self.app = endpoint
self.methods = methods
@ -127,12 +129,12 @@ class Route(BaseRoute):
return Match.FULL, child_scope
return Match.NONE, {}
def url_path_for(self, name: str, **path_params: str) -> URL:
def url_path_for(self, name: str, **path_params: str) -> URLPath:
if name != self.name or self.param_names != set(path_params.keys()):
raise NoMatchFound()
path, remaining_params = replace_params(self.path, **path_params)
assert not remaining_params
return URL(scheme="http", path=path)
return URLPath(path=path, protocol="http")
def __call__(self, scope: Scope) -> ASGIInstance:
if self.methods and scope["method"] not in self.methods:
@ -157,8 +159,10 @@ class WebSocketRoute(BaseRoute):
self.name = get_name(endpoint)
if inspect.isfunction(endpoint) or inspect.ismethod(endpoint):
# Endpoint is function or method. Treat it as `func(websocket)`.
self.app = websocket_session(endpoint)
else:
# Endpoint is a class. Treat it as ASGI.
self.app = endpoint
regex = "^" + path + "$"
@ -177,12 +181,12 @@ class WebSocketRoute(BaseRoute):
return Match.FULL, child_scope
return Match.NONE, {}
def url_path_for(self, name: str, **path_params: str) -> URL:
def url_path_for(self, name: str, **path_params: str) -> URLPath:
if name != self.name or self.param_names != set(path_params.keys()):
raise NoMatchFound()
path, remaining_params = replace_params(self.path, **path_params)
assert not remaining_params
return URL(scheme="ws", path=path)
return URLPath(path=path, protocol="websocket")
def __call__(self, scope: Scope) -> ASGIInstance:
return self.app(scope)
@ -219,12 +223,12 @@ class Mount(BaseRoute):
return Match.FULL, child_scope
return Match.NONE, {}
def url_path_for(self, name: str, **path_params: str) -> URL:
def url_path_for(self, name: str, **path_params: str) -> URLPath:
path, remaining_params = replace_params(self.path, **path_params)
for route in self.routes or []:
try:
url = route.url_path_for(name, **remaining_params)
return URL(scheme=url.scheme, path=path + url.path)
return URLPath(path=path + str(url), protocol=url.protocol)
except NoMatchFound as exc:
pass
raise NoMatchFound()
@ -292,7 +296,7 @@ class Router:
raise HTTPException(status_code=404)
return PlainTextResponse("Not Found", status_code=404)
def url_path_for(self, name: str, **path_params: str) -> URL:
def url_path_for(self, name: str, **path_params: str) -> URLPath:
for route in self.routes:
try:
return route.url_path_for(name, **path_params)

50
starlette/schemas.py Normal file
View File

@ -0,0 +1,50 @@
import inspect
import typing
from starlette.routing import BaseRoute, Route
try:
import yaml
except ImportError: # pragma: nocover
yaml = None # type: ignore
class BaseSchemaGenerator:
def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
raise NotImplementedError() # pragma: no cover
class SchemaGenerator(BaseSchemaGenerator):
def __init__(self, base_schema: dict) -> None:
assert yaml is not None, "`pyyaml` must be installed to use SchemaGenerator."
self.base_schema = base_schema
def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
paths = {} # type: dict
for route in routes:
if not isinstance(route, Route):
continue
if inspect.isfunction(route.endpoint) or inspect.ismethod(route.endpoint):
docstring = route.endpoint.__doc__
for method in route.methods or ["GET"]:
if method == "HEAD":
continue
if route.path not in paths:
paths[route.path] = {}
data = yaml.safe_load(docstring) if docstring else {}
paths[route.path][method.lower()] = data
else:
for method in ["get", "post", "put", "patch", "delete", "options"]:
if not hasattr(route.endpoint, method):
continue
docstring = getattr(route.endpoint, method).__doc__
if route.path not in paths:
paths[route.path] = {}
data = yaml.safe_load(docstring) if docstring else {}
paths[route.path][method] = data
schema = dict(self.base_schema)
schema["paths"] = paths
return schema

View File

@ -59,10 +59,10 @@ class WebSocket(Mapping):
def path_params(self) -> dict:
return self._scope.get("path_params", {})
def url_for(self, name: str, **path_params: typing.Any) -> URL:
def url_for(self, name: str, **path_params: typing.Any) -> str:
router = self._scope["router"]
url = router.url_path_for(name, **path_params)
return url.replace(secure=self.url.is_secure, netloc=self.url.netloc)
url_path = router.url_path_for(name, **path_params)
return url_path.make_absolute_url(base_url=self.url)
async def receive(self) -> Message:
"""

View File

@ -99,6 +99,25 @@ def test_url_path_for():
assert app.url_path_for("broken")
def test_url_for():
assert (
app.url_path_for("homepage").make_absolute_url(base_url="https://example.org")
== "https://example.org/"
)
assert (
app.url_path_for("user", username="tomchristie").make_absolute_url(
base_url="https://example.org"
)
== "https://example.org/users/tomchristie"
)
assert (
app.url_path_for("websocket_endpoint").make_absolute_url(
base_url="https://example.org"
)
== "wss://example.org/ws"
)
def test_router_add_route():
response = client.get("/func")
assert response.status_code == 200

103
tests/test_schemas.py Normal file
View File

@ -0,0 +1,103 @@
from starlette.applications import Starlette
from starlette.endpoints import HTTPEndpoint
from starlette.schemas import SchemaGenerator
app = Starlette()
app.schema_generator = SchemaGenerator(
{"openapi": "3.0.0", "info": {"title": "Example API", "version": "1.0"}}
)
@app.websocket_route("/ws")
def ws(session):
"""ws"""
pass # pragma: no cover
@app.route("/users", methods=["GET", "HEAD"])
def list_users(request):
"""
responses:
200:
description: A list of users.
examples:
[{"username": "tom"}, {"username": "lucy"}]
"""
pass # pragma: no cover
@app.route("/users", methods=["POST"])
def create_user(request):
"""
responses:
200:
description: A user.
examples:
{"username": "tom"}
"""
pass # pragma: no cover
@app.route("/orgs")
class OrganisationsEndpoint(HTTPEndpoint):
def get(self, request):
"""
responses:
200:
description: A list of organisations.
examples:
[{"name": "Foo Corp."}, {"name": "Acme Ltd."}]
"""
pass # pragma: no cover
def post(self, request):
"""
responses:
200:
description: An organisation.
examples:
{"name": "Foo Corp."}
"""
pass # pragma: no cover
def test_schema_generation():
assert app.schema == {
"openapi": "3.0.0",
"info": {"title": "Example API", "version": "1.0"},
"paths": {
"/orgs": {
"get": {
"responses": {
200: {
"description": "A list of " "organisations.",
"examples": [{"name": "Foo Corp."}, {"name": "Acme Ltd."}],
}
}
},
"post": {
"responses": {
200: {
"description": "An organisation.",
"examples": {"name": "Foo Corp."},
}
}
},
},
"/users": {
"get": {
"responses": {
200: {
"description": "A list of users.",
"examples": [{"username": "tom"}, {"username": "lucy"}],
}
}
},
"post": {
"responses": {
200: {"description": "A user.", "examples": {"username": "tom"}}
}
},
},
},
}