import anyio
import pytest

from starlette import status
from starlette.types import Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState


def test_websocket_url(test_client_factory):
    """The app sees the full ws:// URL, including path and query string."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.accept()
        payload = {"url": str(ws.url)}
        await ws.send_json(payload)
        await ws.close()

    test_client = test_client_factory(app)
    with test_client.websocket_connect("/123?a=abc") as session:
        assert session.receive_json() == {"url": "ws://testserver/123?a=abc"}


def test_websocket_binary_json(test_client_factory):
    """JSON round-trips through binary frames when mode="binary" on both ends."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.accept()
        incoming = await ws.receive_json(mode="binary")
        await ws.send_json(incoming, mode="binary")
        await ws.close()

    test_client = test_client_factory(app)
    with test_client.websocket_connect("/123?a=abc") as session:
        session.send_json({"test": "data"}, mode="binary")
        assert session.receive_json(mode="binary") == {"test": "data"}


def test_websocket_query_params(test_client_factory):
    """Query-string parameters are exposed via ``query_params`` as strings."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        # Read the params before accepting, exactly as a handshake handler would.
        params = dict(ws.query_params)
        await ws.accept()
        await ws.send_json({"params": params})
        await ws.close()

    test_client = test_client_factory(app)
    with test_client.websocket_connect("/?a=abc&b=456") as session:
        assert session.receive_json() == {"params": {"a": "abc", "b": "456"}}


def test_websocket_headers(test_client_factory):
    """The test client's default handshake headers reach the app unchanged."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        seen_headers = dict(ws.headers)
        await ws.accept()
        await ws.send_json({"headers": seen_headers})
        await ws.close()

    wanted = {
        "accept": "*/*",
        "accept-encoding": "gzip, deflate",
        "connection": "upgrade",
        "host": "testserver",
        "user-agent": "testclient",
        "sec-websocket-key": "testserver==",
        "sec-websocket-version": "13",
    }
    test_client = test_client_factory(app)
    with test_client.websocket_connect("/") as session:
        assert session.receive_json() == {"headers": wanted}


def test_websocket_port(test_client_factory):
    """An explicit port in the connect URL is reported by ``url.port``."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.accept()
        await ws.send_json({"port": ws.url.port})
        await ws.close()

    test_client = test_client_factory(app)
    with test_client.websocket_connect("ws://example.com:123/123?a=abc") as session:
        assert session.receive_json() == {"port": 123}


def test_websocket_send_and_receive_text(test_client_factory):
    """A text frame from the client is echoed back with a prefix."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.accept()
        text = await ws.receive_text()
        await ws.send_text(f"Message was: {text}")
        await ws.close()

    test_client = test_client_factory(app)
    with test_client.websocket_connect("/") as session:
        session.send_text("Hello, world!")
        assert session.receive_text() == "Message was: Hello, world!"


def test_websocket_send_and_receive_bytes(test_client_factory):
    """A binary frame from the client is echoed back with a bytes prefix."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.accept()
        blob = await ws.receive_bytes()
        await ws.send_bytes(b"Message was: " + blob)
        await ws.close()

    test_client = test_client_factory(app)
    with test_client.websocket_connect("/") as session:
        session.send_bytes(b"Hello, world!")
        assert session.receive_bytes() == b"Message was: Hello, world!"
def test_websocket_send_and_receive_json(test_client_factory):
    """A JSON object from the client is echoed back nested under "message"."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.accept()
        incoming = await ws.receive_json()
        await ws.send_json({"message": incoming})
        await ws.close()

    test_client = test_client_factory(app)
    with test_client.websocket_connect("/") as session:
        session.send_json({"hello": "world"})
        assert session.receive_json() == {"message": {"hello": "world"}}


def test_websocket_iter_text(test_client_factory):
    """``iter_text()`` yields each text frame; the app echoes it with a prefix."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.accept()
        async for text in ws.iter_text():
            await ws.send_text(f"Message was: {text}")

    test_client = test_client_factory(app)
    with test_client.websocket_connect("/") as session:
        session.send_text("Hello, world!")
        assert session.receive_text() == "Message was: Hello, world!"


def test_websocket_iter_bytes(test_client_factory):
    """``iter_bytes()`` yields each binary frame; the app echoes it with a prefix."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.accept()
        async for blob in ws.iter_bytes():
            await ws.send_bytes(b"Message was: " + blob)

    test_client = test_client_factory(app)
    with test_client.websocket_connect("/") as session:
        session.send_bytes(b"Hello, world!")
        assert session.receive_bytes() == b"Message was: Hello, world!"
def test_websocket_iter_json(test_client_factory):
    """``iter_json()`` yields each decoded JSON message; the app echoes it wrapped."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.accept()
        async for data in websocket.iter_json():
            await websocket.send_json({"message": data})

    client = test_client_factory(app)
    with client.websocket_connect("/") as websocket:
        websocket.send_json({"hello": "world"})
        data = websocket.receive_json()
        assert data == {"message": {"hello": "world"}}


def test_websocket_concurrency_pattern(test_client_factory):
    """A reader task and a writer coroutine can share one WebSocket.

    The reader forwards every JSON message from the socket into an in-memory
    anyio stream. The writer drains that stream back onto the socket. The
    reader's ``async with stream_send`` closes the send side once its loop
    ends, which is what lets the writer's ``async for`` terminate.
    """
    stream_send, stream_receive = anyio.create_memory_object_stream()

    async def reader(websocket):
        async with stream_send:
            async for data in websocket.iter_json():
                await stream_send.send(data)

    async def writer(websocket):
        async with stream_receive:
            async for message in stream_receive:
                await websocket.send_json(message)

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.accept()
        async with anyio.create_task_group() as task_group:
            task_group.start_soon(reader, websocket)
            await writer(websocket)
        await websocket.close()

    client = test_client_factory(app)
    with client.websocket_connect("/") as websocket:
        websocket.send_json({"hello": "world"})
        data = websocket.receive_json()
        assert data == {"hello": "world"}


def test_client_close(test_client_factory):
    """A client-initiated close surfaces in the app as WebSocketDisconnect with the code."""
    close_code = None

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        nonlocal close_code
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.accept()
        try:
            await websocket.receive_text()
        except WebSocketDisconnect as exc:
            close_code = exc.code

    client = test_client_factory(app)
    with client.websocket_connect("/") as websocket:
        websocket.close(code=status.WS_1001_GOING_AWAY)
    assert close_code == status.WS_1001_GOING_AWAY


def test_application_close(test_client_factory):
    """An app-initiated close after accept raises WebSocketDisconnect on the client."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.accept()
        await websocket.close(status.WS_1001_GOING_AWAY)

    client = test_client_factory(app)
    with client.websocket_connect("/") as websocket:
        with pytest.raises(WebSocketDisconnect) as exc:
            websocket.receive_text()
        assert exc.value.code == status.WS_1001_GOING_AWAY


def test_rejected_connection(test_client_factory):
    """Closing before accept rejects the handshake; connect itself raises."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.close(status.WS_1001_GOING_AWAY)

    client = test_client_factory(app)
    with pytest.raises(WebSocketDisconnect) as exc:
        with client.websocket_connect("/"):
            pass  # pragma: nocover
    assert exc.value.code == status.WS_1001_GOING_AWAY


def test_subprotocol(test_client_factory):
    """Requested subprotocols appear in the scope; the accepted one reaches the client."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        assert websocket["subprotocols"] == ["soap", "wamp"]
        await websocket.accept(subprotocol="wamp")
        await websocket.close()

    client = test_client_factory(app)
    with client.websocket_connect("/", subprotocols=["soap", "wamp"]) as websocket:
        assert websocket.accepted_subprotocol == "wamp"


def test_additional_headers(test_client_factory):
    """Headers passed to ``accept()`` are exposed as ``extra_headers`` on the client."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.accept(headers=[(b"additional", b"header")])
        await websocket.close()

    client = test_client_factory(app)
    with client.websocket_connect("/") as websocket:
        assert websocket.extra_headers == [(b"additional", b"header")]


def test_no_additional_headers(test_client_factory):
    """Without ``headers=``, the client sees an empty ``extra_headers`` list."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.accept()
        await websocket.close()

    client = test_client_factory(app)
    with client.websocket_connect("/") as websocket:
        assert websocket.extra_headers == []


def test_websocket_exception(test_client_factory):
    """An exception raised inside the app propagates out of ``websocket_connect``."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        # Raise explicitly: a bare `assert False` is stripped under `python -O`,
        # which would make the app silently succeed and the test fail.
        raise AssertionError()

    client = test_client_factory(app)
    with pytest.raises(AssertionError):
        with client.websocket_connect("/123?a=abc"):
            pass  # pragma: nocover


def test_duplicate_close(test_client_factory):
    """Calling ``close()`` twice raises RuntimeError."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.accept()
        await websocket.close()
        await websocket.close()

    client = test_client_factory(app)
    with pytest.raises(RuntimeError):
        with client.websocket_connect("/"):
            pass  # pragma: nocover


def test_duplicate_disconnect(test_client_factory):
    """Receiving again after a disconnect message raises RuntimeError."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.accept()
        message = await websocket.receive()
        assert message["type"] == "websocket.disconnect"
        # Second receive is only here to trigger the expected RuntimeError;
        # its result is never used.
        await websocket.receive()

    client = test_client_factory(app)
    with pytest.raises(RuntimeError):
        with client.websocket_connect("/") as websocket:
            websocket.close()


def test_websocket_scope_interface():
    """
    A WebSocket can be instantiated with a scope, and presents a `Mapping`
    interface.
    """

    async def mock_receive():
        pass  # pragma: no cover

    async def mock_send(message):
        pass  # pragma: no cover

    websocket = WebSocket(
        {"type": "websocket", "path": "/abc/", "headers": []},
        receive=mock_receive,
        send=mock_send,
    )
    assert websocket["type"] == "websocket"
    assert dict(websocket) == {"type": "websocket", "path": "/abc/", "headers": []}
    assert len(websocket) == 3

    # check __eq__ and __hash__: equality is by identity, not by scope contents.
    assert websocket != WebSocket(
        {"type": "websocket", "path": "/abc/", "headers": []},
        receive=mock_receive,
        send=mock_send,
    )
    assert websocket == websocket
    assert websocket in {websocket}
    assert {websocket} == {websocket}


def test_websocket_close_reason(test_client_factory) -> None:
    """A close code and reason sent by the app are both visible to the client."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.accept()
        await websocket.close(code=status.WS_1001_GOING_AWAY, reason="Going Away")

    client = test_client_factory(app)
    with client.websocket_connect("/") as websocket:
        with pytest.raises(WebSocketDisconnect) as exc:
            websocket.receive_text()
        assert exc.value.code == status.WS_1001_GOING_AWAY
        assert exc.value.reason == "Going Away"


def test_send_json_invalid_mode(test_client_factory):
    """``send_json`` with an unknown ``mode`` raises RuntimeError."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.accept()
        await websocket.send_json({}, mode="invalid")

    client = test_client_factory(app)
    with pytest.raises(RuntimeError):
        with client.websocket_connect("/"):
            pass  # pragma: nocover


def test_receive_json_invalid_mode(test_client_factory):
    """``receive_json`` with an unknown ``mode`` raises RuntimeError."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.accept()
        await websocket.receive_json(mode="invalid")

    client = test_client_factory(app)
    with pytest.raises(RuntimeError):
        with client.websocket_connect("/"):
            pass  # pragma: nocover


def test_receive_text_before_accept(test_client_factory):
    """``receive_text()`` without a prior ``accept()`` raises RuntimeError."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.receive_text()

    client = test_client_factory(app)
    with pytest.raises(RuntimeError):
        with client.websocket_connect("/"):
            pass  # pragma: nocover


def test_receive_bytes_before_accept(test_client_factory):
    """``receive_bytes()`` without a prior ``accept()`` raises RuntimeError."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.receive_bytes()

    client = test_client_factory(app)
    with pytest.raises(RuntimeError):
        with client.websocket_connect("/"):
            pass  # pragma: nocover


def test_receive_json_before_accept(test_client_factory):
    """``receive_json()`` without a prior ``accept()`` raises RuntimeError."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.receive_json()

    client = test_client_factory(app)
    with pytest.raises(RuntimeError):
        with client.websocket_connect("/"):
            pass  # pragma: nocover


def test_send_before_accept(test_client_factory):
    """Sending a "websocket.send" message before accepting raises RuntimeError."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.send({"type": "websocket.send"})

    client = test_client_factory(app)
    with pytest.raises(RuntimeError):
        with client.websocket_connect("/"):
            pass  # pragma: nocover


def test_send_wrong_message_type(test_client_factory):
    """Sending "websocket.accept" a second time raises RuntimeError."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.send({"type": "websocket.accept"})
        await websocket.send({"type": "websocket.accept"})

    client = test_client_factory(app)
    with pytest.raises(RuntimeError):
        with client.websocket_connect("/"):
            pass  # pragma: nocover


def test_receive_before_accept(test_client_factory):
    """With ``client_state`` forced back to CONNECTING, ``receive()`` of a
    "websocket.send" message from the client is expected to raise RuntimeError.
    """

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.accept()
        # Rewind the state machine to simulate receiving before the handshake.
        websocket.client_state = WebSocketState.CONNECTING
        await websocket.receive()

    client = test_client_factory(app)
    with pytest.raises(RuntimeError):
        with client.websocket_connect("/") as websocket:
            websocket.send({"type": "websocket.send"})


def test_receive_wrong_message_type(test_client_factory):
    """After accept, receiving a "websocket.connect" message raises RuntimeError."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        await websocket.accept()
        await websocket.receive()

    client = test_client_factory(app)
    with pytest.raises(RuntimeError):
        with client.websocket_connect("/") as websocket:
            websocket.send({"type": "websocket.connect"})