mirror of https://github.com/encode/starlette.git
Additional headers for WS `accept` message. (#1361)
* Additional headers for WS accept message. * Update tests/test_websockets.py Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com> * fixup! Additional headers for WS accept message. * Update tests/test_websockets.py Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
This commit is contained in:
parent
bff81f83d0
commit
9d686a7125
|
@ -51,7 +51,7 @@ For example: `websocket.path_params['username']`
|
|||
|
||||
### Accepting the connection
|
||||
|
||||
* `await websocket.accept(subprotocol=None)`
|
||||
* `await websocket.accept(subprotocol=None, headers=None)`
|
||||
|
||||
### Sending data
|
||||
|
||||
|
|
|
@ -298,6 +298,7 @@ class WebSocketTestSession:
|
|||
self.app = app
|
||||
self.scope = scope
|
||||
self.accepted_subprotocol = None
|
||||
self.extra_headers = None
|
||||
self.portal_factory = portal_factory
|
||||
self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue()
|
||||
self._send_queue: "queue.Queue[typing.Any]" = queue.Queue()
|
||||
|
@ -315,6 +316,7 @@ class WebSocketTestSession:
|
|||
self.exit_stack.close()
|
||||
raise
|
||||
self.accepted_subprotocol = message.get("subprotocol", None)
|
||||
self.extra_headers = message.get("headers", None)
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: typing.Any) -> None:
|
||||
|
|
|
@ -69,11 +69,17 @@ class WebSocket(HTTPConnection):
|
|||
else:
|
||||
raise RuntimeError('Cannot call "send" once a close message has been sent.')
|
||||
|
||||
async def accept(self, subprotocol: str = None) -> None:
|
||||
async def accept(
|
||||
self,
|
||||
subprotocol: str = None,
|
||||
headers: typing.Iterable[typing.Tuple[bytes, bytes]] = None,
|
||||
) -> None:
|
||||
if self.client_state == WebSocketState.CONNECTING:
|
||||
# If we haven't yet seen the 'connect' message, then wait for it first.
|
||||
await self.receive()
|
||||
await self.send({"type": "websocket.accept", "subprotocol": subprotocol})
|
||||
await self.send(
|
||||
{"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers}
|
||||
)
|
||||
|
||||
def _raise_on_disconnect(self, message: Message) -> None:
|
||||
if message["type"] == "websocket.disconnect":
|
||||
|
|
|
@ -301,6 +301,20 @@ def test_subprotocol(test_client_factory):
|
|||
assert websocket.accepted_subprotocol == "wamp"
|
||||
|
||||
|
||||
def test_additional_headers(test_client_factory):
|
||||
def app(scope):
|
||||
async def asgi(receive, send):
|
||||
websocket = WebSocket(scope, receive=receive, send=send)
|
||||
await websocket.accept(headers=[(b"additional", b"header")])
|
||||
await websocket.close()
|
||||
|
||||
return asgi
|
||||
|
||||
client = test_client_factory(app)
|
||||
with client.websocket_connect("/") as websocket:
|
||||
assert websocket.extra_headers == [(b"additional", b"header")]
|
||||
|
||||
|
||||
def test_websocket_exception(test_client_factory):
|
||||
def app(scope):
|
||||
async def asgi(receive, send):
|
||||
|
|
Loading…
Reference in New Issue