From 9d686a7125131fe026071d04e0d0e0a726e1afc5 Mon Sep 17 00:00:00 2001 From: matiuszka <40184215+matiuszka@users.noreply.github.com> Date: Thu, 6 Jan 2022 11:55:29 +0100 Subject: [PATCH] Additional headers for WS `accept` message. (#1361) * Additional headers for WS accept message. * Update tests/test_websockets.py Co-authored-by: Marcelo Trylesinski * fixup! Additional headers for WS accept message. * Update tests/test_websockets.py Co-authored-by: Marcelo Trylesinski --- docs/websockets.md | 2 +- starlette/testclient.py | 2 ++ starlette/websockets.py | 10 ++++++++-- tests/test_websockets.py | 14 ++++++++++++++ 4 files changed, 25 insertions(+), 3 deletions(-) diff --git a/docs/websockets.md b/docs/websockets.md index 80749618..43406ace 100644 --- a/docs/websockets.md +++ b/docs/websockets.md @@ -51,7 +51,7 @@ For example: `websocket.path_params['username']` ### Accepting the connection -* `await websocket.accept(subprotocol=None)` +* `await websocket.accept(subprotocol=None, headers=None)` ### Sending data diff --git a/starlette/testclient.py b/starlette/testclient.py index 40220fb4..0b4bc78d 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -298,6 +298,7 @@ class WebSocketTestSession: self.app = app self.scope = scope self.accepted_subprotocol = None + self.extra_headers = None self.portal_factory = portal_factory self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue() self._send_queue: "queue.Queue[typing.Any]" = queue.Queue() @@ -315,6 +316,7 @@ class WebSocketTestSession: self.exit_stack.close() raise self.accepted_subprotocol = message.get("subprotocol", None) + self.extra_headers = message.get("headers", None) return self def __exit__(self, *args: typing.Any) -> None: diff --git a/starlette/websockets.py b/starlette/websockets.py index b9b8844d..7632b28c 100644 --- a/starlette/websockets.py +++ b/starlette/websockets.py @@ -69,11 +69,17 @@ class WebSocket(HTTPConnection): else: raise RuntimeError('Cannot call "send" once a close message has been sent.') - async def accept(self, subprotocol: str = None) -> None: + async def accept( + self, + subprotocol: str = None, + headers: typing.Iterable[typing.Tuple[bytes, bytes]] = None, + ) -> None: if self.client_state == WebSocketState.CONNECTING: # If we haven't yet seen the 'connect' message, then wait for it first. await self.receive() - await self.send({"type": "websocket.accept", "subprotocol": subprotocol}) + await self.send( + {"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers} + ) def _raise_on_disconnect(self, message: Message) -> None: if message["type"] == "websocket.disconnect": diff --git a/tests/test_websockets.py b/tests/test_websockets.py index e02d433d..bf025330 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -301,6 +301,20 @@ def test_subprotocol(test_client_factory): assert websocket.accepted_subprotocol == "wamp" +def test_additional_headers(test_client_factory): + def app(scope): + async def asgi(receive, send): + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.accept(headers=[(b"additional", b"header")]) + await websocket.close() + + return asgi + + client = test_client_factory(app) + with client.websocket_connect("/") as websocket: + assert websocket.extra_headers == [(b"additional", b"header")] + + def test_websocket_exception(test_client_factory): def app(scope): async def asgi(receive, send):