mirror of https://github.com/encode/starlette.git
Support app.url_path_for and request.url_for (#153)
This commit is contained in:
parent
b52fd11f32
commit
c047fe4e75
|
@ -19,9 +19,17 @@ The path parameters are available on the request as `request.path_params`.
|
|||
This is different to most Python webframeworks, but I think it actually ends up
|
||||
being much more nicely consistent all the way through.
|
||||
|
||||
### app.url_for(name, **path_params)
|
||||
### request.url_for(name, **path_params)
|
||||
|
||||
Applications now support URL reversing with `app.url_for(name, **path_params)`.
|
||||
Request and WebSocketSession now support URL reversing with `request.url_for(name, **path_params)`.
|
||||
This method returns a fully qualified `URL` instance.
|
||||
The URL instance is a string-like object.
|
||||
|
||||
### app.url_path_for(name, **path_params)
|
||||
|
||||
Applications now support URL path reversing with `app.url_path_for(name, **path_params)`.
|
||||
This method returns a `URL` instance with the path and scheme set.
|
||||
The URL instance is a string-like object, and will return only the path if coerced to a string.
|
||||
|
||||
### app.routes
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import typing
|
||||
|
||||
from starlette.datastructures import URL
|
||||
from starlette.exceptions import ExceptionMiddleware
|
||||
from starlette.lifespan import LifespanHandler
|
||||
from starlette.routing import BaseRoute, Router
|
||||
|
@ -72,11 +73,12 @@ class Starlette:
|
|||
|
||||
return decorator
|
||||
|
||||
def url_for(self, name: str, **path_params: str) -> str:
|
||||
return self.router.url_for(name, **path_params)
|
||||
def url_path_for(self, name: str, **path_params: str) -> URL:
|
||||
return self.router.url_path_for(name, **path_params)
|
||||
|
||||
def __call__(self, scope: Scope) -> ASGIInstance:
|
||||
scope["app"] = self
|
||||
scope["router"] = self.router
|
||||
if scope["type"] == "lifespan":
|
||||
return self.lifespan_handler(scope)
|
||||
return self.exception_middleware(scope)
|
||||
|
|
|
@ -4,9 +4,12 @@ from urllib.parse import parse_qsl, unquote, urlencode, urlparse, ParseResult
|
|||
|
||||
|
||||
class URL:
|
||||
def __init__(self, url: str = "", scope: Scope = None) -> None:
|
||||
def __init__(
|
||||
self, url: str = "", scope: Scope = None, **components: typing.Any
|
||||
) -> None:
|
||||
if scope is not None:
|
||||
assert not url, 'Cannot set both "url" and "scope".'
|
||||
assert not components, 'Cannot set both "scope" and "**components".'
|
||||
scheme = scope.get("scheme", "http")
|
||||
server = scope.get("server", None)
|
||||
path = scope.get("root_path", "") + scope["path"]
|
||||
|
@ -24,6 +27,10 @@ class URL:
|
|||
|
||||
if query_string:
|
||||
url += "?" + query_string.decode()
|
||||
elif components:
|
||||
assert not url, 'Cannot set both "scope" and "**components".'
|
||||
url = URL("").replace(**components).components.geturl()
|
||||
|
||||
self._url = url
|
||||
|
||||
@property
|
||||
|
@ -72,6 +79,10 @@ class URL:
|
|||
def port(self) -> typing.Optional[int]:
|
||||
return self.components.port
|
||||
|
||||
@property
|
||||
def is_secure(self) -> bool:
|
||||
return self.scheme in ("https", "wss")
|
||||
|
||||
def replace(self, **kwargs: typing.Any) -> "URL":
|
||||
if "hostname" in kwargs or "port" in kwargs:
|
||||
hostname = kwargs.pop("hostname", self.hostname)
|
||||
|
@ -80,6 +91,12 @@ class URL:
|
|||
kwargs["netloc"] = hostname
|
||||
else:
|
||||
kwargs["netloc"] = "%s:%d" % (hostname, port)
|
||||
if "secure" in kwargs:
|
||||
secure = kwargs.pop("secure")
|
||||
if self.scheme in ("http", "https"):
|
||||
kwargs["scheme"] = "https" if secure else "http"
|
||||
elif self.scheme in ("ws", "wss"):
|
||||
kwargs["scheme"] = "wss" if secure else "ws"
|
||||
components = self.components._replace(**kwargs)
|
||||
return URL(components.geturl())
|
||||
|
||||
|
@ -87,6 +104,8 @@ class URL:
|
|||
return str(self) == str(other)
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.scheme and not self.netloc:
|
||||
return str(self.replace(scheme=""))
|
||||
return self._url
|
||||
|
||||
def __repr__(self) -> str:
|
||||
|
|
|
@ -72,6 +72,11 @@ class Request(Mapping):
|
|||
self._cookies = cookies
|
||||
return self._cookies
|
||||
|
||||
def url_for(self, name: str, **path_params: typing.Any) -> URL:
|
||||
router = self._scope["router"]
|
||||
url = router.url_path_for(name, **path_params)
|
||||
return url.replace(secure=self.url.is_secure, netloc=self.url.netloc)
|
||||
|
||||
async def stream(self) -> typing.AsyncGenerator[bytes, None]:
|
||||
if hasattr(self, "_body"):
|
||||
yield self._body
|
||||
|
|
|
@ -5,6 +5,7 @@ import asyncio
|
|||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from starlette.requests import Request
|
||||
from starlette.datastructures import URL
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.responses import PlainTextResponse
|
||||
from starlette.types import Scope, ASGIApp, ASGIInstance, Send, Receive
|
||||
|
@ -68,7 +69,7 @@ class BaseRoute:
|
|||
def matches(self, scope: Scope) -> typing.Tuple[bool, Scope]:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
def url_for(self, name: str, **path_params: str) -> str:
|
||||
def url_path_for(self, name: str, **path_params: str) -> URL:
|
||||
raise NotImplementedError() # pragma: no cover
|
||||
|
||||
def __call__(self, scope: Scope) -> ASGIInstance:
|
||||
|
@ -107,10 +108,10 @@ class Route(BaseRoute):
|
|||
return True, child_scope
|
||||
return False, {}
|
||||
|
||||
def url_for(self, name: str, **path_params: str) -> str:
|
||||
def url_path_for(self, name: str, **path_params: str) -> URL:
|
||||
if name != self.name or self.param_names != set(path_params.keys()):
|
||||
raise NoMatchFound()
|
||||
return replace_params(self.path, **path_params)
|
||||
return URL(scheme="http", path=replace_params(self.path, **path_params))
|
||||
|
||||
def __call__(self, scope: Scope) -> ASGIInstance:
|
||||
if self.methods and scope["method"] not in self.methods:
|
||||
|
@ -155,10 +156,10 @@ class WebSocketRoute(BaseRoute):
|
|||
return True, child_scope
|
||||
return False, {}
|
||||
|
||||
def url_for(self, name: str, **path_params: str) -> str:
|
||||
def url_path_for(self, name: str, **path_params: str) -> URL:
|
||||
if name != self.name or self.param_names != set(path_params.keys()):
|
||||
raise NoMatchFound()
|
||||
return replace_params(self.path, **path_params)
|
||||
return URL(scheme="ws", path=replace_params(self.path, **path_params))
|
||||
|
||||
def __call__(self, scope: Scope) -> ASGIInstance:
|
||||
return self.app(scope)
|
||||
|
@ -195,10 +196,11 @@ class Mount(BaseRoute):
|
|||
return True, child_scope
|
||||
return False, {}
|
||||
|
||||
def url_for(self, name: str, **path_params: str) -> str:
|
||||
def url_path_for(self, name: str, **path_params: str) -> URL:
|
||||
for route in self.routes or []:
|
||||
try:
|
||||
return self.path + route.url_for(name, **path_params)
|
||||
url = route.url_path_for(name, **path_params)
|
||||
return URL(scheme=url.scheme, path=self.path + url.path)
|
||||
except NoMatchFound as exc:
|
||||
pass
|
||||
raise NoMatchFound()
|
||||
|
@ -266,10 +268,10 @@ class Router:
|
|||
raise HTTPException(status_code=404)
|
||||
return PlainTextResponse("Not Found", status_code=404)
|
||||
|
||||
def url_for(self, name: str, **path_params: str) -> str:
|
||||
def url_path_for(self, name: str, **path_params: str) -> URL:
|
||||
for route in self.routes:
|
||||
try:
|
||||
return route.url_for(name, **path_params)
|
||||
return route.url_path_for(name, **path_params)
|
||||
except NoMatchFound as exc:
|
||||
pass
|
||||
raise NoMatchFound()
|
||||
|
@ -277,6 +279,9 @@ class Router:
|
|||
def __call__(self, scope: Scope) -> ASGIInstance:
|
||||
assert scope["type"] in ("http", "websocket")
|
||||
|
||||
if "router" not in scope:
|
||||
scope["router"] = self
|
||||
|
||||
for route in self.routes:
|
||||
matched, child_scope = route.matches(scope)
|
||||
if matched:
|
||||
|
|
|
@ -60,6 +60,11 @@ class WebSocket(Mapping):
|
|||
def path_params(self) -> dict:
|
||||
return self._scope.get("path_params", {})
|
||||
|
||||
def url_for(self, name: str, **path_params: typing.Any) -> URL:
|
||||
router = self._scope["router"]
|
||||
url = router.url_path_for(name, **path_params)
|
||||
return url.replace(secure=self.url.is_secure, netloc=self.url.netloc)
|
||||
|
||||
async def receive(self) -> Message:
|
||||
"""
|
||||
Receive ASGI websocket messages, ensuring valid state transitions.
|
||||
|
|
|
@ -86,8 +86,8 @@ async def websocket_endpoint(session):
|
|||
client = TestClient(app)
|
||||
|
||||
|
||||
def test_url_for():
|
||||
assert app.url_for("func_homepage") == "/func"
|
||||
def test_url_path_for():
|
||||
assert app.url_path_for("func_homepage") == "/func"
|
||||
|
||||
|
||||
def test_func_route():
|
||||
|
|
|
@ -85,12 +85,12 @@ def test_router():
|
|||
assert response.text == "xxxxx"
|
||||
|
||||
|
||||
def test_url_for():
|
||||
assert app.url_for("homepage") == "/"
|
||||
assert app.url_for("user", username="tomchristie") == "/users/tomchristie"
|
||||
assert app.url_for("websocket_endpoint") == "/ws"
|
||||
def test_url_path_for():
|
||||
assert app.url_path_for("homepage") == "/"
|
||||
assert app.url_path_for("user", username="tomchristie") == "/users/tomchristie"
|
||||
assert app.url_path_for("websocket_endpoint") == "/ws"
|
||||
with pytest.raises(NoMatchFound):
|
||||
assert app.url_for("broken")
|
||||
assert app.url_path_for("broken")
|
||||
|
||||
|
||||
def test_router_add_route():
|
||||
|
@ -110,7 +110,8 @@ def test_router_add_websocket_route():
|
|||
|
||||
|
||||
def http_endpoint(request):
|
||||
return Response("Hello, world", media_type="text/plain")
|
||||
url = request.url_for("http_endpoint")
|
||||
return Response("URL: %s" % url, media_type="text/plain")
|
||||
|
||||
|
||||
class WebsocketEndpoint:
|
||||
|
@ -120,7 +121,7 @@ class WebsocketEndpoint:
|
|||
async def __call__(self, receive, send):
|
||||
session = WebSocket(scope=self.scope, receive=receive, send=send)
|
||||
await session.accept()
|
||||
await session.send_json({"hello": "world"})
|
||||
await session.send_json({"URL": str(session.url_for("WebsocketEndpoint"))})
|
||||
await session.close()
|
||||
|
||||
|
||||
|
@ -137,10 +138,10 @@ def test_protocol_switch():
|
|||
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
assert response.text == "Hello, world"
|
||||
assert response.text == "URL: http://testserver/"
|
||||
|
||||
with client.websocket_connect("/") as session:
|
||||
assert session.receive_json() == {"hello": "world"}
|
||||
assert session.receive_json() == {"URL": "ws://testserver/"}
|
||||
|
||||
with pytest.raises(WebSocketDisconnect):
|
||||
client.websocket_connect("/404")
|
||||
|
|
Loading…
Reference in New Issue