From 9dc9d2e92919306fae683746e38530ff55ed3092 Mon Sep 17 00:00:00 2001 From: Hao Guan Date: Tue, 19 Nov 2024 03:28:43 +0800 Subject: [PATCH] fix(testclient): exclude query sting from `raw_path` (#2716) Co-authored-by: Marcelo Trylesinski --- starlette/testclient.py | 4 ++-- tests/test_testclient.py | 24 ++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index 5143c4c5..645ca109 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -281,7 +281,7 @@ class _TestClientTransport(httpx.BaseTransport): scope = { "type": "websocket", "path": unquote(path), - "raw_path": raw_path, + "raw_path": raw_path.split(b"?", 1)[0], "root_path": self.root_path, "scheme": scheme, "query_string": query.encode(), @@ -300,7 +300,7 @@ class _TestClientTransport(httpx.BaseTransport): "http_version": "1.1", "method": request.method, "path": unquote(path), - "raw_path": raw_path, + "raw_path": raw_path.split(b"?", 1)[0], "root_path": self.root_path, "scheme": scheme, "query_string": query.encode(), diff --git a/tests/test_testclient.py b/tests/test_testclient.py index 92f16d33..68593a9a 100644 --- a/tests/test_testclient.py +++ b/tests/test_testclient.py @@ -378,3 +378,27 @@ def test_merge_url(test_client_factory: TestClientFactory) -> None: client = test_client_factory(app, base_url="http://testserver/api/v1/") response = client.get("/bar") assert response.text == "/api/v1/bar" + + +def test_raw_path_with_querystring(test_client_factory: TestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + response = Response(scope.get("raw_path")) + await response(scope, receive, send) + + client = test_client_factory(app) + response = client.get("/hello-world", params={"foo": "bar"}) + assert response.content == b"/hello-world" + + +def test_websocket_raw_path_without_params(test_client_factory: TestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.accept() + raw_path = scope.get("raw_path") + assert raw_path is not None + await websocket.send_bytes(raw_path) + + client = test_client_factory(app) + with client.websocket_connect("/hello-world", params={"foo": "bar"}) as websocket: + data = websocket.receive_bytes() + assert data == b"/hello-world"