mirror of https://github.com/encode/starlette.git
Return strings from `url_path_for` (#169)
Schema generation & return strings from `url_path_for`
This commit is contained in:
parent
5488af66d7
commit
de010e7796
|
@ -3,6 +3,7 @@ aiofiles
|
|||
graphene
|
||||
itsdangerous
|
||||
python-multipart
|
||||
pyyaml
|
||||
requests
|
||||
ujson
|
||||
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
import typing
|
||||
|
||||
from starlette.datastructures import URL
|
||||
from starlette.datastructures import URL, URLPath
|
||||
from starlette.exceptions import ExceptionMiddleware
|
||||
from starlette.lifespan import LifespanHandler
|
||||
from starlette.routing import BaseRoute, Router
|
||||
from starlette.schemas import BaseSchemaGenerator
|
||||
from starlette.types import ASGIApp, ASGIInstance, Scope
|
||||
|
||||
|
||||
|
@ -13,6 +14,7 @@ class Starlette:
|
|||
self.lifespan_handler = LifespanHandler()
|
||||
self.app = self.router
|
||||
self.exception_middleware = ExceptionMiddleware(self.router, debug=debug)
|
||||
self.schema_generator = None # type: typing.Optional[BaseSchemaGenerator]
|
||||
|
||||
@property
|
||||
def routes(self) -> typing.List[BaseRoute]:
|
||||
|
@ -26,6 +28,11 @@ class Starlette:
|
|||
def debug(self, value: bool) -> None:
|
||||
self.exception_middleware.debug = value
|
||||
|
||||
@property
|
||||
def schema(self) -> dict:
|
||||
assert self.schema_generator is not None
|
||||
return self.schema_generator.get_schema(self.routes)
|
||||
|
||||
def on_event(self, event_type: str) -> typing.Callable:
|
||||
return self.lifespan_handler.on_event(event_type)
|
||||
|
||||
|
@ -73,7 +80,7 @@ class Starlette:
|
|||
|
||||
return decorator
|
||||
|
||||
def url_path_for(self, name: str, **path_params: str) -> URL:
|
||||
def url_path_for(self, name: str, **path_params: str) -> URLPath:
|
||||
return self.router.url_path_for(name, **path_params)
|
||||
|
||||
def __call__(self, scope: Scope) -> ASGIInstance:
|
||||
|
|
|
@ -92,12 +92,6 @@ class URL:
|
|||
kwargs["netloc"] = hostname
|
||||
else:
|
||||
kwargs["netloc"] = "%s:%d" % (hostname, port)
|
||||
if "secure" in kwargs:
|
||||
secure = kwargs.pop("secure")
|
||||
if self.scheme in ("http", "https"):
|
||||
kwargs["scheme"] = "https" if secure else "http"
|
||||
elif self.scheme in ("ws", "wss"):
|
||||
kwargs["scheme"] = "wss" if secure else "ws"
|
||||
components = self.components._replace(**kwargs)
|
||||
return URL(components.geturl())
|
||||
|
||||
|
@ -105,14 +99,36 @@ class URL:
|
|||
return str(self) == str(other)
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.scheme and not self.netloc:
|
||||
return str(self.replace(scheme=""))
|
||||
return self._url
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "%s(%s)" % (self.__class__.__name__, repr(self._url))
|
||||
|
||||
|
||||
class URLPath(str):
|
||||
"""
|
||||
A URL path string that also holds an associated protocol.
|
||||
Used by the routing to return `url_path_for` matches.
|
||||
"""
|
||||
|
||||
def __new__(cls, path: str, protocol: str) -> str:
|
||||
assert protocol in ("http", "websocket")
|
||||
return str.__new__(cls, path) # type: ignore
|
||||
|
||||
def __init__(self, path: str, protocol: str) -> None:
|
||||
self.protocol = protocol
|
||||
|
||||
def make_absolute_url(self, base_url: typing.Union[str, URL]) -> str:
|
||||
if isinstance(base_url, str):
|
||||
base_url = URL(base_url)
|
||||
scheme = {
|
||||
"http": {True: "https", False: "http"},
|
||||
"websocket": {True: "wss", False: "ws"},
|
||||
}[self.protocol][base_url.is_secure]
|
||||
netloc = base_url.netloc
|
||||
return str(URL(scheme=scheme, netloc=base_url.netloc, path=str(self)))
|
||||
|
||||
|
||||
class QueryParams(typing.Mapping[str, str]):
|
||||
"""
|
||||
An immutable multidict.
|
||||
|
|
|
@ -85,10 +85,10 @@ class Request(Mapping):
|
|||
def receive(self) -> Receive:
|
||||
return self._receive
|
||||
|
||||
def url_for(self, name: str, **path_params: typing.Any) -> URL:
|
||||
def url_for(self, name: str, **path_params: typing.Any) -> str:
|
||||
router = self._scope["router"]
|
||||
url = router.url_path_for(name, **path_params)
|
||||
return url.replace(secure=self.url.is_secure, netloc=self.url.netloc)
|
||||
url_path = router.url_path_for(name, **path_params)
|
||||
return url_path.make_absolute_url(base_url=self.url)
|
||||
|
||||
async def stream(self) -> typing.AsyncGenerator[bytes, None]:
|
||||
if hasattr(self, "_body"):
|
||||
|
|
|
@ -5,7 +5,7 @@ import typing
|
|||
from concurrent.futures import ThreadPoolExecutor
|
||||
from enum import Enum
|
||||
|
||||
from starlette.datastructures import URL
|
||||
from starlette.datastructures import URL, URLPath
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.graphql import GraphQLApp
|
||||
from starlette.requests import Request
|
||||
|
@ -85,7 +85,7 @@ class BaseRoute:
|
|||
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
def url_path_for(self, name: str, **path_params: str) -> URL:
|
||||
def url_path_for(self, name: str, **path_params: str) -> URLPath:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
def __call__(self, scope: Scope) -> ASGIInstance:
|
||||
|
@ -101,10 +101,12 @@ class Route(BaseRoute):
|
|||
self.name = get_name(endpoint)
|
||||
|
||||
if inspect.isfunction(endpoint) or inspect.ismethod(endpoint):
|
||||
# Endpoint is function or method. Treat it as `func(request) -> response`.
|
||||
self.app = request_response(endpoint)
|
||||
if methods is None:
|
||||
methods = ["GET"]
|
||||
else:
|
||||
# Endpoint is a class. Treat it as ASGI.
|
||||
self.app = endpoint
|
||||
|
||||
self.methods = methods
|
||||
|
@ -127,12 +129,12 @@ class Route(BaseRoute):
|
|||
return Match.FULL, child_scope
|
||||
return Match.NONE, {}
|
||||
|
||||
def url_path_for(self, name: str, **path_params: str) -> URL:
|
||||
def url_path_for(self, name: str, **path_params: str) -> URLPath:
|
||||
if name != self.name or self.param_names != set(path_params.keys()):
|
||||
raise NoMatchFound()
|
||||
path, remaining_params = replace_params(self.path, **path_params)
|
||||
assert not remaining_params
|
||||
return URL(scheme="http", path=path)
|
||||
return URLPath(path=path, protocol="http")
|
||||
|
||||
def __call__(self, scope: Scope) -> ASGIInstance:
|
||||
if self.methods and scope["method"] not in self.methods:
|
||||
|
@ -157,8 +159,10 @@ class WebSocketRoute(BaseRoute):
|
|||
self.name = get_name(endpoint)
|
||||
|
||||
if inspect.isfunction(endpoint) or inspect.ismethod(endpoint):
|
||||
# Endpoint is function or method. Treat it as `func(websocket)`.
|
||||
self.app = websocket_session(endpoint)
|
||||
else:
|
||||
# Endpoint is a class. Treat it as ASGI.
|
||||
self.app = endpoint
|
||||
|
||||
regex = "^" + path + "$"
|
||||
|
@ -177,12 +181,12 @@ class WebSocketRoute(BaseRoute):
|
|||
return Match.FULL, child_scope
|
||||
return Match.NONE, {}
|
||||
|
||||
def url_path_for(self, name: str, **path_params: str) -> URL:
|
||||
def url_path_for(self, name: str, **path_params: str) -> URLPath:
|
||||
if name != self.name or self.param_names != set(path_params.keys()):
|
||||
raise NoMatchFound()
|
||||
path, remaining_params = replace_params(self.path, **path_params)
|
||||
assert not remaining_params
|
||||
return URL(scheme="ws", path=path)
|
||||
return URLPath(path=path, protocol="websocket")
|
||||
|
||||
def __call__(self, scope: Scope) -> ASGIInstance:
|
||||
return self.app(scope)
|
||||
|
@ -219,12 +223,12 @@ class Mount(BaseRoute):
|
|||
return Match.FULL, child_scope
|
||||
return Match.NONE, {}
|
||||
|
||||
def url_path_for(self, name: str, **path_params: str) -> URL:
|
||||
def url_path_for(self, name: str, **path_params: str) -> URLPath:
|
||||
path, remaining_params = replace_params(self.path, **path_params)
|
||||
for route in self.routes or []:
|
||||
try:
|
||||
url = route.url_path_for(name, **remaining_params)
|
||||
return URL(scheme=url.scheme, path=path + url.path)
|
||||
return URLPath(path=path + str(url), protocol=url.protocol)
|
||||
except NoMatchFound as exc:
|
||||
pass
|
||||
raise NoMatchFound()
|
||||
|
@ -292,7 +296,7 @@ class Router:
|
|||
raise HTTPException(status_code=404)
|
||||
return PlainTextResponse("Not Found", status_code=404)
|
||||
|
||||
def url_path_for(self, name: str, **path_params: str) -> URL:
|
||||
def url_path_for(self, name: str, **path_params: str) -> URLPath:
|
||||
for route in self.routes:
|
||||
try:
|
||||
return route.url_path_for(name, **path_params)
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
import inspect
|
||||
import typing
|
||||
|
||||
from starlette.routing import BaseRoute, Route
|
||||
|
||||
try:
|
||||
import yaml
|
||||
except ImportError: # pragma: nocover
|
||||
yaml = None # type: ignore
|
||||
|
||||
|
||||
class BaseSchemaGenerator:
|
||||
def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
|
||||
class SchemaGenerator(BaseSchemaGenerator):
|
||||
def __init__(self, base_schema: dict) -> None:
|
||||
assert yaml is not None, "`pyyaml` must be installed to use SchemaGenerator."
|
||||
self.base_schema = base_schema
|
||||
|
||||
def get_schema(self, routes: typing.List[BaseRoute]) -> dict:
|
||||
paths = {} # type: dict
|
||||
|
||||
for route in routes:
|
||||
if not isinstance(route, Route):
|
||||
continue
|
||||
|
||||
if inspect.isfunction(route.endpoint) or inspect.ismethod(route.endpoint):
|
||||
docstring = route.endpoint.__doc__
|
||||
for method in route.methods or ["GET"]:
|
||||
if method == "HEAD":
|
||||
continue
|
||||
if route.path not in paths:
|
||||
paths[route.path] = {}
|
||||
data = yaml.safe_load(docstring) if docstring else {}
|
||||
paths[route.path][method.lower()] = data
|
||||
else:
|
||||
for method in ["get", "post", "put", "patch", "delete", "options"]:
|
||||
if not hasattr(route.endpoint, method):
|
||||
continue
|
||||
docstring = getattr(route.endpoint, method).__doc__
|
||||
if route.path not in paths:
|
||||
paths[route.path] = {}
|
||||
data = yaml.safe_load(docstring) if docstring else {}
|
||||
paths[route.path][method] = data
|
||||
|
||||
schema = dict(self.base_schema)
|
||||
schema["paths"] = paths
|
||||
return schema
|
|
@ -59,10 +59,10 @@ class WebSocket(Mapping):
|
|||
def path_params(self) -> dict:
|
||||
return self._scope.get("path_params", {})
|
||||
|
||||
def url_for(self, name: str, **path_params: typing.Any) -> URL:
|
||||
def url_for(self, name: str, **path_params: typing.Any) -> str:
|
||||
router = self._scope["router"]
|
||||
url = router.url_path_for(name, **path_params)
|
||||
return url.replace(secure=self.url.is_secure, netloc=self.url.netloc)
|
||||
url_path = router.url_path_for(name, **path_params)
|
||||
return url_path.make_absolute_url(base_url=self.url)
|
||||
|
||||
async def receive(self) -> Message:
|
||||
"""
|
||||
|
|
|
@ -99,6 +99,25 @@ def test_url_path_for():
|
|||
assert app.url_path_for("broken")
|
||||
|
||||
|
||||
def test_url_for():
|
||||
assert (
|
||||
app.url_path_for("homepage").make_absolute_url(base_url="https://example.org")
|
||||
== "https://example.org/"
|
||||
)
|
||||
assert (
|
||||
app.url_path_for("user", username="tomchristie").make_absolute_url(
|
||||
base_url="https://example.org"
|
||||
)
|
||||
== "https://example.org/users/tomchristie"
|
||||
)
|
||||
assert (
|
||||
app.url_path_for("websocket_endpoint").make_absolute_url(
|
||||
base_url="https://example.org"
|
||||
)
|
||||
== "wss://example.org/ws"
|
||||
)
|
||||
|
||||
|
||||
def test_router_add_route():
|
||||
response = client.get("/func")
|
||||
assert response.status_code == 200
|
||||
|
|
|
@ -0,0 +1,103 @@
|
|||
from starlette.applications import Starlette
|
||||
from starlette.endpoints import HTTPEndpoint
|
||||
from starlette.schemas import SchemaGenerator
|
||||
|
||||
app = Starlette()
|
||||
app.schema_generator = SchemaGenerator(
|
||||
{"openapi": "3.0.0", "info": {"title": "Example API", "version": "1.0"}}
|
||||
)
|
||||
|
||||
|
||||
@app.websocket_route("/ws")
|
||||
def ws(session):
|
||||
"""ws"""
|
||||
pass # pragma: no cover
|
||||
|
||||
|
||||
@app.route("/users", methods=["GET", "HEAD"])
|
||||
def list_users(request):
|
||||
"""
|
||||
responses:
|
||||
200:
|
||||
description: A list of users.
|
||||
examples:
|
||||
[{"username": "tom"}, {"username": "lucy"}]
|
||||
"""
|
||||
pass # pragma: no cover
|
||||
|
||||
|
||||
@app.route("/users", methods=["POST"])
|
||||
def create_user(request):
|
||||
"""
|
||||
responses:
|
||||
200:
|
||||
description: A user.
|
||||
examples:
|
||||
{"username": "tom"}
|
||||
"""
|
||||
pass # pragma: no cover
|
||||
|
||||
|
||||
@app.route("/orgs")
|
||||
class OrganisationsEndpoint(HTTPEndpoint):
|
||||
def get(self, request):
|
||||
"""
|
||||
responses:
|
||||
200:
|
||||
description: A list of organisations.
|
||||
examples:
|
||||
[{"name": "Foo Corp."}, {"name": "Acme Ltd."}]
|
||||
"""
|
||||
pass # pragma: no cover
|
||||
|
||||
def post(self, request):
|
||||
"""
|
||||
responses:
|
||||
200:
|
||||
description: An organisation.
|
||||
examples:
|
||||
{"name": "Foo Corp."}
|
||||
"""
|
||||
pass # pragma: no cover
|
||||
|
||||
|
||||
def test_schema_generation():
|
||||
assert app.schema == {
|
||||
"openapi": "3.0.0",
|
||||
"info": {"title": "Example API", "version": "1.0"},
|
||||
"paths": {
|
||||
"/orgs": {
|
||||
"get": {
|
||||
"responses": {
|
||||
200: {
|
||||
"description": "A list of " "organisations.",
|
||||
"examples": [{"name": "Foo Corp."}, {"name": "Acme Ltd."}],
|
||||
}
|
||||
}
|
||||
},
|
||||
"post": {
|
||||
"responses": {
|
||||
200: {
|
||||
"description": "An organisation.",
|
||||
"examples": {"name": "Foo Corp."},
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
"/users": {
|
||||
"get": {
|
||||
"responses": {
|
||||
200: {
|
||||
"description": "A list of users.",
|
||||
"examples": [{"username": "tom"}, {"username": "lucy"}],
|
||||
}
|
||||
}
|
||||
},
|
||||
"post": {
|
||||
"responses": {
|
||||
200: {"description": "A user.", "examples": {"username": "tom"}}
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
Loading…
Reference in New Issue