Additional headers for WS `accept` message. (#1361)

* Additional headers for WS accept message.

* Update tests/test_websockets.py

Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>

* fixup! Additional headers for WS accept message.

* Update tests/test_websockets.py

Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
This commit is contained in:
matiuszka 2022-01-06 11:55:29 +01:00 committed by GitHub
parent bff81f83d0
commit 9d686a7125
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 25 additions and 3 deletions

View File

@ -51,7 +51,7 @@ For example: `websocket.path_params['username']`
### Accepting the connection
* `await websocket.accept(subprotocol=None)`
* `await websocket.accept(subprotocol=None, headers=None)`
### Sending data

View File

@ -298,6 +298,7 @@ class WebSocketTestSession:
self.app = app
self.scope = scope
self.accepted_subprotocol = None
self.extra_headers = None
self.portal_factory = portal_factory
self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue()
self._send_queue: "queue.Queue[typing.Any]" = queue.Queue()
@ -315,6 +316,7 @@ class WebSocketTestSession:
self.exit_stack.close()
raise
self.accepted_subprotocol = message.get("subprotocol", None)
self.extra_headers = message.get("headers", None)
return self
def __exit__(self, *args: typing.Any) -> None:

View File

@ -69,11 +69,17 @@ class WebSocket(HTTPConnection):
else:
raise RuntimeError('Cannot call "send" once a close message has been sent.')
async def accept(self, subprotocol: str = None) -> None:
async def accept(
self,
subprotocol: str = None,
headers: typing.Iterable[typing.Tuple[bytes, bytes]] = None,
) -> None:
if self.client_state == WebSocketState.CONNECTING:
# If we haven't yet seen the 'connect' message, then wait for it first.
await self.receive()
await self.send({"type": "websocket.accept", "subprotocol": subprotocol})
await self.send(
{"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers}
)
def _raise_on_disconnect(self, message: Message) -> None:
if message["type"] == "websocket.disconnect":

View File

@ -301,6 +301,20 @@ def test_subprotocol(test_client_factory):
assert websocket.accepted_subprotocol == "wamp"
def test_additional_headers(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept(headers=[(b"additional", b"header")])
await websocket.close()
return asgi
client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
assert websocket.extra_headers == [(b"additional", b"header")]
def test_websocket_exception(test_client_factory):
def app(scope):
async def asgi(receive, send):