From c047fe4e75a6af4264f34feac230e09d3d272ab6 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 29 Oct 2018 11:14:42 +0000 Subject: [PATCH] Support app.url_path_for and request.url_for (#153) --- docs/release-notes.md | 12 ++++++++++-- starlette/applications.py | 6 ++++-- starlette/datastructures.py | 21 ++++++++++++++++++++- starlette/requests.py | 5 +++++ starlette/routing.py | 23 ++++++++++++++--------- starlette/websockets.py | 5 +++++ tests/test_applications.py | 4 ++-- tests/test_routing.py | 19 ++++++++++--------- 8 files changed, 70 insertions(+), 25 deletions(-) diff --git a/docs/release-notes.md b/docs/release-notes.md index 1d0aa534..d7cdeb58 100644 --- a/docs/release-notes.md +++ b/docs/release-notes.md @@ -19,9 +19,17 @@ The path parameters are available on the request as `request.path_params`. This is different to most Python webframeworks, but I think it actually ends up being much more nicely consistent all the way through. -### app.url_for(name, **path_params) +### request.url_for(name, **path_params) -Applications now support URL reversing with `app.url_for(name, **path_params)`. +Request and WebSocketSession now support URL reversing with `request.url_for(name, **path_params)`. +This method returns a fully qualified `URL` instance. +The URL instance is a string-like object. + +### app.url_path_for(name, **path_params) + +Applications now support URL path reversing with `app.url_path_for(name, **path_params)`. +This method returns a `URL` instance with the path and scheme set. +The URL instance is a string-like object, and will return only the path if coerced to a string. ### app.routes diff --git a/starlette/applications.py b/starlette/applications.py index b1bf1a12..fde36016 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -1,5 +1,6 @@ import typing +from starlette.datastructures import URL from starlette.exceptions import ExceptionMiddleware from starlette.lifespan import LifespanHandler from starlette.routing import BaseRoute, Router @@ -72,11 +73,12 @@ class Starlette: return decorator - def url_for(self, name: str, **path_params: str) -> str: - return self.router.url_for(name, **path_params) + def url_path_for(self, name: str, **path_params: str) -> URL: + return self.router.url_path_for(name, **path_params) def __call__(self, scope: Scope) -> ASGIInstance: scope["app"] = self + scope["router"] = self.router if scope["type"] == "lifespan": return self.lifespan_handler(scope) return self.exception_middleware(scope) diff --git a/starlette/datastructures.py b/starlette/datastructures.py index 019dab01..55e90fbe 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -4,9 +4,12 @@ from urllib.parse import parse_qsl, unquote, urlencode, urlparse, ParseResult class URL: - def __init__(self, url: str = "", scope: Scope = None) -> None: + def __init__( + self, url: str = "", scope: Scope = None, **components: typing.Any + ) -> None: if scope is not None: assert not url, 'Cannot set both "url" and "scope".' + assert not components, 'Cannot set both "scope" and "**components".' scheme = scope.get("scheme", "http") server = scope.get("server", None) path = scope.get("root_path", "") + scope["path"] @@ -24,6 +27,10 @@ class URL: if query_string: url += "?" + query_string.decode() + elif components: + assert not url, 'Cannot set both "scope" and "**components".' + url = URL("").replace(**components).components.geturl() + self._url = url @property @@ -72,6 +79,10 @@ class URL: def port(self) -> typing.Optional[int]: return self.components.port + @property + def is_secure(self) -> bool: + return self.scheme in ("https", "wss") + def replace(self, **kwargs: typing.Any) -> "URL": if "hostname" in kwargs or "port" in kwargs: hostname = kwargs.pop("hostname", self.hostname) @@ -80,6 +91,12 @@ class URL: kwargs["netloc"] = hostname else: kwargs["netloc"] = "%s:%d" % (hostname, port) + if "secure" in kwargs: + secure = kwargs.pop("secure") + if self.scheme in ("http", "https"): + kwargs["scheme"] = "https" if secure else "http" + elif self.scheme in ("ws", "wss"): + kwargs["scheme"] = "wss" if secure else "ws" components = self.components._replace(**kwargs) return URL(components.geturl()) @@ -87,6 +104,8 @@ class URL: return str(self) == str(other) def __str__(self) -> str: + if self.scheme and not self.netloc: + return str(self.replace(scheme="")) return self._url def __repr__(self) -> str: diff --git a/starlette/requests.py b/starlette/requests.py index 71a7d2e6..79b87048 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -72,6 +72,11 @@ class Request(Mapping): self._cookies = cookies return self._cookies + def url_for(self, name: str, **path_params: typing.Any) -> URL: + router = self._scope["router"] + url = router.url_path_for(name, **path_params) + return url.replace(secure=self.url.is_secure, netloc=self.url.netloc) + async def stream(self) -> typing.AsyncGenerator[bytes, None]: if hasattr(self, "_body"): yield self._body diff --git a/starlette/routing.py b/starlette/routing.py index 2a811c5c..bb55fd1c 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -5,6 +5,7 @@ import asyncio from concurrent.futures import ThreadPoolExecutor from starlette.requests import Request +from starlette.datastructures import URL from starlette.exceptions import HTTPException from starlette.responses import PlainTextResponse from starlette.types import Scope, ASGIApp, ASGIInstance, Send, Receive @@ -68,7 +69,7 @@ class BaseRoute: def matches(self, scope: Scope) -> typing.Tuple[bool, Scope]: raise NotImplementedError() # pragma: no cover - def url_for(self, name: str, **path_params: str) -> str: + def url_path_for(self, name: str, **path_params: str) -> URL: raise NotImplementedError() # pragma: no cover def __call__(self, scope: Scope) -> ASGIInstance: @@ -107,10 +108,10 @@ class Route(BaseRoute): return True, child_scope return False, {} - def url_for(self, name: str, **path_params: str) -> str: + def url_path_for(self, name: str, **path_params: str) -> URL: if name != self.name or self.param_names != set(path_params.keys()): raise NoMatchFound() - return replace_params(self.path, **path_params) + return URL(scheme="http", path=replace_params(self.path, **path_params)) def __call__(self, scope: Scope) -> ASGIInstance: if self.methods and scope["method"] not in self.methods: @@ -155,10 +156,10 @@ class WebSocketRoute(BaseRoute): return True, child_scope return False, {} - def url_for(self, name: str, **path_params: str) -> str: + def url_path_for(self, name: str, **path_params: str) -> URL: if name != self.name or self.param_names != set(path_params.keys()): raise NoMatchFound() - return replace_params(self.path, **path_params) + return URL(scheme="ws", path=replace_params(self.path, **path_params)) def __call__(self, scope: Scope) -> ASGIInstance: return self.app(scope) @@ -195,10 +196,11 @@ class Mount(BaseRoute): return True, child_scope return False, {} - def url_for(self, name: str, **path_params: str) -> str: + def url_path_for(self, name: str, **path_params: str) -> URL: for route in self.routes or []: try: - return self.path + route.url_for(name, **path_params) + url = route.url_path_for(name, **path_params) + return URL(scheme=url.scheme, path=self.path + url.path) except NoMatchFound as exc: pass raise NoMatchFound() @@ -266,10 +268,10 @@ class Router: raise HTTPException(status_code=404) return PlainTextResponse("Not Found", status_code=404) - def url_for(self, name: str, **path_params: str) -> str: + def url_path_for(self, name: str, **path_params: str) -> URL: for route in self.routes: try: - return route.url_for(name, **path_params) + return route.url_path_for(name, **path_params) except NoMatchFound as exc: pass raise NoMatchFound() @@ -277,6 +279,9 @@ class Router: def __call__(self, scope: Scope) -> ASGIInstance: assert scope["type"] in ("http", "websocket") + if "router" not in scope: + scope["router"] = self + for route in self.routes: matched, child_scope = route.matches(scope) if matched: diff --git a/starlette/websockets.py b/starlette/websockets.py index c1d73bfd..45431716 100644 --- a/starlette/websockets.py +++ b/starlette/websockets.py @@ -60,6 +60,11 @@ class WebSocket(Mapping): def path_params(self) -> dict: return self._scope.get("path_params", {}) + def url_for(self, name: str, **path_params: typing.Any) -> URL: + router = self._scope["router"] + url = router.url_path_for(name, **path_params) + return url.replace(secure=self.url.is_secure, netloc=self.url.netloc) + async def receive(self) -> Message: """ Receive ASGI websocket messages, ensuring valid state transitions. diff --git a/tests/test_applications.py b/tests/test_applications.py index 90d54f9c..b718bf19 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -86,8 +86,8 @@ async def websocket_endpoint(session): client = TestClient(app) -def test_url_for(): - assert app.url_for("func_homepage") == "/func" +def test_url_path_for(): + assert app.url_path_for("func_homepage") == "/func" def test_func_route(): diff --git a/tests/test_routing.py b/tests/test_routing.py index 3e0b5735..725f3691 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -85,12 +85,12 @@ def test_router(): assert response.text == "xxxxx" -def test_url_for(): - assert app.url_for("homepage") == "/" - assert app.url_for("user", username="tomchristie") == "/users/tomchristie" - assert app.url_for("websocket_endpoint") == "/ws" +def test_url_path_for(): + assert app.url_path_for("homepage") == "/" + assert app.url_path_for("user", username="tomchristie") == "/users/tomchristie" + assert app.url_path_for("websocket_endpoint") == "/ws" with pytest.raises(NoMatchFound): - assert app.url_for("broken") + assert app.url_path_for("broken") def test_router_add_route(): @@ -110,7 +110,8 @@ def test_router_add_websocket_route(): def http_endpoint(request): - return Response("Hello, world", media_type="text/plain") + url = request.url_for("http_endpoint") + return Response("URL: %s" % url, media_type="text/plain") class WebsocketEndpoint: @@ -120,7 +121,7 @@ class WebsocketEndpoint: async def __call__(self, receive, send): session = WebSocket(scope=self.scope, receive=receive, send=send) await session.accept() - await session.send_json({"hello": "world"}) + await session.send_json({"URL": str(session.url_for("WebsocketEndpoint"))}) await session.close() @@ -137,10 +138,10 @@ def test_protocol_switch(): response = client.get("/") assert response.status_code == 200 - assert response.text == "Hello, world" + assert response.text == "URL: http://testserver/" with client.websocket_connect("/") as session: - assert session.receive_json() == {"hello": "world"} + assert session.receive_json() == {"URL": "ws://testserver/"} with pytest.raises(WebSocketDisconnect): client.websocket_connect("/404")