Support app.url_path_for and request.url_for (#153)

This commit is contained in:
Tom Christie 2018-10-29 11:14:42 +00:00 committed by GitHub
parent b52fd11f32
commit c047fe4e75
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 70 additions and 25 deletions

View File

@ -19,9 +19,17 @@ The path parameters are available on the request as `request.path_params`.
This is different to most Python webframeworks, but I think it actually ends up
being much more nicely consistent all the way through.
### app.url_for(name, **path_params)
### request.url_for(name, **path_params)
Applications now support URL reversing with `app.url_for(name, **path_params)`.
Request and WebSocketSession now support URL reversing with `request.url_for(name, **path_params)`.
This method returns a fully qualified `URL` instance.
The URL instance is a string-like object.
### app.url_path_for(name, **path_params)
Applications now support URL path reversing with `app.url_path_for(name, **path_params)`.
This method returns a `URL` instance with the path and scheme set.
The URL instance is a string-like object, and will return only the path if coerced to a string.
### app.routes

View File

@ -1,5 +1,6 @@
import typing
from starlette.datastructures import URL
from starlette.exceptions import ExceptionMiddleware
from starlette.lifespan import LifespanHandler
from starlette.routing import BaseRoute, Router
@ -72,11 +73,12 @@ class Starlette:
return decorator
def url_for(self, name: str, **path_params: str) -> str:
return self.router.url_for(name, **path_params)
def url_path_for(self, name: str, **path_params: str) -> URL:
return self.router.url_path_for(name, **path_params)
def __call__(self, scope: Scope) -> ASGIInstance:
scope["app"] = self
scope["router"] = self.router
if scope["type"] == "lifespan":
return self.lifespan_handler(scope)
return self.exception_middleware(scope)

View File

@ -4,9 +4,12 @@ from urllib.parse import parse_qsl, unquote, urlencode, urlparse, ParseResult
class URL:
def __init__(self, url: str = "", scope: Scope = None) -> None:
def __init__(
self, url: str = "", scope: Scope = None, **components: typing.Any
) -> None:
if scope is not None:
assert not url, 'Cannot set both "url" and "scope".'
assert not components, 'Cannot set both "scope" and "**components".'
scheme = scope.get("scheme", "http")
server = scope.get("server", None)
path = scope.get("root_path", "") + scope["path"]
@ -24,6 +27,10 @@ class URL:
if query_string:
url += "?" + query_string.decode()
elif components:
assert not url, 'Cannot set both "scope" and "**components".'
url = URL("").replace(**components).components.geturl()
self._url = url
@property
@ -72,6 +79,10 @@ class URL:
def port(self) -> typing.Optional[int]:
return self.components.port
@property
def is_secure(self) -> bool:
return self.scheme in ("https", "wss")
def replace(self, **kwargs: typing.Any) -> "URL":
if "hostname" in kwargs or "port" in kwargs:
hostname = kwargs.pop("hostname", self.hostname)
@ -80,6 +91,12 @@ class URL:
kwargs["netloc"] = hostname
else:
kwargs["netloc"] = "%s:%d" % (hostname, port)
if "secure" in kwargs:
secure = kwargs.pop("secure")
if self.scheme in ("http", "https"):
kwargs["scheme"] = "https" if secure else "http"
elif self.scheme in ("ws", "wss"):
kwargs["scheme"] = "wss" if secure else "ws"
components = self.components._replace(**kwargs)
return URL(components.geturl())
@ -87,6 +104,8 @@ class URL:
return str(self) == str(other)
def __str__(self) -> str:
if self.scheme and not self.netloc:
return str(self.replace(scheme=""))
return self._url
def __repr__(self) -> str:

View File

@ -72,6 +72,11 @@ class Request(Mapping):
self._cookies = cookies
return self._cookies
def url_for(self, name: str, **path_params: typing.Any) -> URL:
router = self._scope["router"]
url = router.url_path_for(name, **path_params)
return url.replace(secure=self.url.is_secure, netloc=self.url.netloc)
async def stream(self) -> typing.AsyncGenerator[bytes, None]:
if hasattr(self, "_body"):
yield self._body

View File

@ -5,6 +5,7 @@ import asyncio
from concurrent.futures import ThreadPoolExecutor
from starlette.requests import Request
from starlette.datastructures import URL
from starlette.exceptions import HTTPException
from starlette.responses import PlainTextResponse
from starlette.types import Scope, ASGIApp, ASGIInstance, Send, Receive
@ -68,7 +69,7 @@ class BaseRoute:
def matches(self, scope: Scope) -> typing.Tuple[bool, Scope]:
raise NotImplementedError() # pragma: no cover
def url_for(self, name: str, **path_params: str) -> str:
def url_path_for(self, name: str, **path_params: str) -> URL:
raise NotImplementedError() # pragma: no cover
def __call__(self, scope: Scope) -> ASGIInstance:
@ -107,10 +108,10 @@ class Route(BaseRoute):
return True, child_scope
return False, {}
def url_for(self, name: str, **path_params: str) -> str:
def url_path_for(self, name: str, **path_params: str) -> URL:
if name != self.name or self.param_names != set(path_params.keys()):
raise NoMatchFound()
return replace_params(self.path, **path_params)
return URL(scheme="http", path=replace_params(self.path, **path_params))
def __call__(self, scope: Scope) -> ASGIInstance:
if self.methods and scope["method"] not in self.methods:
@ -155,10 +156,10 @@ class WebSocketRoute(BaseRoute):
return True, child_scope
return False, {}
def url_for(self, name: str, **path_params: str) -> str:
def url_path_for(self, name: str, **path_params: str) -> URL:
if name != self.name or self.param_names != set(path_params.keys()):
raise NoMatchFound()
return replace_params(self.path, **path_params)
return URL(scheme="ws", path=replace_params(self.path, **path_params))
def __call__(self, scope: Scope) -> ASGIInstance:
return self.app(scope)
@ -195,10 +196,11 @@ class Mount(BaseRoute):
return True, child_scope
return False, {}
def url_for(self, name: str, **path_params: str) -> str:
def url_path_for(self, name: str, **path_params: str) -> URL:
for route in self.routes or []:
try:
return self.path + route.url_for(name, **path_params)
url = route.url_path_for(name, **path_params)
return URL(scheme=url.scheme, path=self.path + url.path)
except NoMatchFound as exc:
pass
raise NoMatchFound()
@ -266,10 +268,10 @@ class Router:
raise HTTPException(status_code=404)
return PlainTextResponse("Not Found", status_code=404)
def url_for(self, name: str, **path_params: str) -> str:
def url_path_for(self, name: str, **path_params: str) -> URL:
for route in self.routes:
try:
return route.url_for(name, **path_params)
return route.url_path_for(name, **path_params)
except NoMatchFound as exc:
pass
raise NoMatchFound()
@ -277,6 +279,9 @@ class Router:
def __call__(self, scope: Scope) -> ASGIInstance:
assert scope["type"] in ("http", "websocket")
if "router" not in scope:
scope["router"] = self
for route in self.routes:
matched, child_scope = route.matches(scope)
if matched:

View File

@ -60,6 +60,11 @@ class WebSocket(Mapping):
def path_params(self) -> dict:
return self._scope.get("path_params", {})
def url_for(self, name: str, **path_params: typing.Any) -> URL:
router = self._scope["router"]
url = router.url_path_for(name, **path_params)
return url.replace(secure=self.url.is_secure, netloc=self.url.netloc)
async def receive(self) -> Message:
"""
Receive ASGI websocket messages, ensuring valid state transitions.

View File

@ -86,8 +86,8 @@ async def websocket_endpoint(session):
client = TestClient(app)
def test_url_for():
assert app.url_for("func_homepage") == "/func"
def test_url_path_for():
assert app.url_path_for("func_homepage") == "/func"
def test_func_route():

View File

@ -85,12 +85,12 @@ def test_router():
assert response.text == "xxxxx"
def test_url_for():
assert app.url_for("homepage") == "/"
assert app.url_for("user", username="tomchristie") == "/users/tomchristie"
assert app.url_for("websocket_endpoint") == "/ws"
def test_url_path_for():
assert app.url_path_for("homepage") == "/"
assert app.url_path_for("user", username="tomchristie") == "/users/tomchristie"
assert app.url_path_for("websocket_endpoint") == "/ws"
with pytest.raises(NoMatchFound):
assert app.url_for("broken")
assert app.url_path_for("broken")
def test_router_add_route():
@ -110,7 +110,8 @@ def test_router_add_websocket_route():
def http_endpoint(request):
return Response("Hello, world", media_type="text/plain")
url = request.url_for("http_endpoint")
return Response("URL: %s" % url, media_type="text/plain")
class WebsocketEndpoint:
@ -120,7 +121,7 @@ class WebsocketEndpoint:
async def __call__(self, receive, send):
session = WebSocket(scope=self.scope, receive=receive, send=send)
await session.accept()
await session.send_json({"hello": "world"})
await session.send_json({"URL": str(session.url_for("WebsocketEndpoint"))})
await session.close()
@ -137,10 +138,10 @@ def test_protocol_switch():
response = client.get("/")
assert response.status_code == 200
assert response.text == "Hello, world"
assert response.text == "URL: http://testserver/"
with client.websocket_connect("/") as session:
assert session.receive_json() == {"hello": "world"}
assert session.receive_json() == {"URL": "ws://testserver/"}
with pytest.raises(WebSocketDisconnect):
client.websocket_connect("/404")