mirror of https://github.com/encode/starlette.git
635 lines
24 KiB
Python
635 lines
24 KiB
Python
import sys
|
|
from typing import Any, Callable, MutableMapping
|
|
|
|
import anyio
|
|
import pytest
|
|
from anyio.abc import ObjectReceiveStream, ObjectSendStream
|
|
|
|
from starlette import status
|
|
from starlette.responses import Response
|
|
from starlette.testclient import TestClient, WebSocketDenialResponse
|
|
from starlette.types import Message, Receive, Scope, Send
|
|
from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState
|
|
|
|
# Type of the `test_client_factory` argument taken by the tests in this module:
# a callable accepting any arguments and returning a `TestClient`.
TestClientFactory = Callable[..., TestClient]
|
|
|
|
|
|
def test_websocket_url(test_client_factory: TestClientFactory) -> None:
    """The application reports the URL it was reached on, path and query included."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        # Accept, echo back the stringified connection URL, then close.
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.accept()
        reported_url = str(ws.url)
        await ws.send_json({"url": reported_url})
        await ws.close()

    client = test_client_factory(app)
    with client.websocket_connect("/123?a=abc") as session:
        payload = session.receive_json()
        assert payload == {"url": "ws://testserver/123?a=abc"}
|
|
|
|
|
|
def test_websocket_binary_json(test_client_factory: TestClientFactory) -> None:
    """JSON sent in binary mode is echoed back unchanged in binary mode."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        # Echo a single binary-mode JSON message back to the client.
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.accept()
        incoming = await ws.receive_json(mode="binary")
        await ws.send_json(incoming, mode="binary")
        await ws.close()

    client = test_client_factory(app)
    with client.websocket_connect("/123?a=abc") as session:
        session.send_json({"test": "data"}, mode="binary")
        echoed = session.receive_json(mode="binary")
        assert echoed == {"test": "data"}
|
|
|
|
|
|
def test_websocket_ensure_unicode_on_send_json(
    test_client_factory: TestClientFactory,
) -> None:
    """Text-mode JSON keeps non-ASCII characters as-is in the raw text frame."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        # Echo a single text-mode JSON message back to the client.
        ws = WebSocket(scope, receive=receive, send=send)

        await ws.accept()
        incoming = await ws.receive_json(mode="text")
        await ws.send_json(incoming, mode="text")
        await ws.close()

    client = test_client_factory(app)
    with client.websocket_connect("/123?a=abc") as session:
        session.send_json({"test": "数据"}, mode="text")
        # Read the raw text so the exact serialisation can be compared.
        raw_text = session.receive_text()
        assert raw_text == '{"test":"数据"}'
|
|
|
|
|
|
def test_websocket_query_params(test_client_factory: TestClientFactory) -> None:
    """Query-string parameters are exposed to the application as key/value pairs."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        # Snapshot the query parameters before accepting the connection.
        params = dict(ws.query_params)
        await ws.accept()
        await ws.send_json({"params": params})
        await ws.close()

    client = test_client_factory(app)
    with client.websocket_connect("/?a=abc&b=456") as session:
        payload = session.receive_json()
        assert payload == {"params": {"a": "abc", "b": "456"}}
|
|
|
|
|
|
@pytest.mark.skipif(
    any(module in sys.modules for module in ("brotli", "brotlicffi")),
    reason='urllib3 includes "br" to the "accept-encoding" headers.',
)
def test_websocket_headers(test_client_factory: TestClientFactory) -> None:
    """The application receives the handshake headers sent by the test client."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        # Snapshot the request headers before accepting the connection.
        seen_headers = dict(ws.headers)
        await ws.accept()
        await ws.send_json({"headers": seen_headers})
        await ws.close()

    expected_headers = {
        "accept": "*/*",
        "accept-encoding": "gzip, deflate",
        "connection": "upgrade",
        "host": "testserver",
        "user-agent": "testclient",
        "sec-websocket-key": "testserver==",
        "sec-websocket-version": "13",
    }

    client = test_client_factory(app)
    with client.websocket_connect("/") as session:
        payload = session.receive_json()
        assert payload == {"headers": expected_headers}
|
|
|
|
|
|
def test_websocket_port(test_client_factory: TestClientFactory) -> None:
    """An explicit port in the connect URL is visible on `websocket.url.port`."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.accept()
        port = ws.url.port
        await ws.send_json({"port": port})
        await ws.close()

    client = test_client_factory(app)
    with client.websocket_connect("ws://example.com:123/123?a=abc") as session:
        payload = session.receive_json()
        assert payload == {"port": 123}
|
|
|
|
|
|
def test_websocket_send_and_receive_text(
    test_client_factory: TestClientFactory,
) -> None:
    """A text message round-trips through the application with a prefix added."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.accept()
        incoming = await ws.receive_text()
        await ws.send_text("Message was: " + incoming)
        await ws.close()

    client = test_client_factory(app)
    with client.websocket_connect("/") as session:
        session.send_text("Hello, world!")
        reply = session.receive_text()
        assert reply == "Message was: Hello, world!"
|
|
|
|
|
|
def test_websocket_send_and_receive_bytes(
    test_client_factory: TestClientFactory,
) -> None:
    """A bytes message round-trips through the application with a prefix added."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.accept()
        incoming = await ws.receive_bytes()
        await ws.send_bytes(b"Message was: " + incoming)
        await ws.close()

    client = test_client_factory(app)
    with client.websocket_connect("/") as session:
        session.send_bytes(b"Hello, world!")
        reply = session.receive_bytes()
        assert reply == b"Message was: Hello, world!"
|
|
|
|
|
|
def test_websocket_send_and_receive_json(
    test_client_factory: TestClientFactory,
) -> None:
    """A JSON message round-trips through the application wrapped in an envelope."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.accept()
        incoming = await ws.receive_json()
        await ws.send_json({"message": incoming})
        await ws.close()

    client = test_client_factory(app)
    with client.websocket_connect("/") as session:
        session.send_json({"hello": "world"})
        reply = session.receive_json()
        assert reply == {"message": {"hello": "world"}}
|
|
|
|
|
|
def test_websocket_iter_text(test_client_factory: TestClientFactory) -> None:
    """`iter_text()` yields each incoming text message to the application."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.accept()
        # Echo every text message for as long as the iterator keeps yielding.
        async for text in ws.iter_text():
            await ws.send_text("Message was: " + text)

    client = test_client_factory(app)
    with client.websocket_connect("/") as session:
        session.send_text("Hello, world!")
        reply = session.receive_text()
        assert reply == "Message was: Hello, world!"
|
|
|
|
|
|
def test_websocket_iter_bytes(test_client_factory: TestClientFactory) -> None:
    """`iter_bytes()` yields each incoming bytes message to the application."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.accept()
        # Echo every bytes message for as long as the iterator keeps yielding.
        async for chunk in ws.iter_bytes():
            await ws.send_bytes(b"Message was: " + chunk)

    client = test_client_factory(app)
    with client.websocket_connect("/") as session:
        session.send_bytes(b"Hello, world!")
        reply = session.receive_bytes()
        assert reply == b"Message was: Hello, world!"
|
|
|
|
|
|
def test_websocket_iter_json(test_client_factory: TestClientFactory) -> None:
    """`iter_json()` yields each incoming JSON message to the application."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.accept()
        # Wrap and echo every JSON message for as long as the iterator yields.
        async for item in ws.iter_json():
            await ws.send_json({"message": item})

    client = test_client_factory(app)
    with client.websocket_connect("/") as session:
        session.send_json({"hello": "world"})
        reply = session.receive_json()
        assert reply == {"message": {"hello": "world"}}
|
|
|
|
|
|
def test_websocket_concurrency_pattern(test_client_factory: TestClientFactory) -> None:
    """A reader task and a writer coroutine share one websocket via a memory stream."""
    stream_send: ObjectSendStream[MutableMapping[str, Any]]
    stream_receive: ObjectReceiveStream[MutableMapping[str, Any]]
    stream_send, stream_receive = anyio.create_memory_object_stream()

    async def reader(websocket: WebSocket) -> None:
        # Forward every JSON message from the websocket into the stream.
        async with stream_send:
            async for item in websocket.iter_json():
                await stream_send.send(item)

    async def writer(websocket: WebSocket) -> None:
        # Send every item taken from the stream back over the websocket.
        async with stream_receive:
            async for item in stream_receive:
                await websocket.send_json(item)

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.accept()
        async with anyio.create_task_group() as tg:
            # The reader runs as a background task; the writer runs inline.
            tg.start_soon(reader, ws)
            await writer(ws)
        await ws.close()

    client = test_client_factory(app)
    with client.websocket_connect("/") as session:
        session.send_json({"hello": "world"})
        reply = session.receive_json()
        assert reply == {"hello": "world"}
|
|
|
|
|
|
def test_client_close(test_client_factory: TestClientFactory) -> None:
    """A client-side close reaches the application as `WebSocketDisconnect`."""
    # Filled in by the application when it observes the disconnect.
    close_code = None
    close_reason = None

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        nonlocal close_code, close_reason
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.accept()
        try:
            await ws.receive_text()
        except WebSocketDisconnect as disconnect:
            close_code, close_reason = disconnect.code, disconnect.reason

    client = test_client_factory(app)
    with client.websocket_connect("/") as session:
        session.close(code=status.WS_1001_GOING_AWAY, reason="Going Away")
    # Checked after the connection context has exited.
    assert close_code == status.WS_1001_GOING_AWAY
    assert close_reason == "Going Away"
|
|
|
|
|
|
@pytest.mark.anyio
async def test_client_disconnect_on_send() -> None:
    """An `OSError` raised while sending is expected to surface as `WebSocketDisconnect`."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.accept()
        await ws.send_text("Hello, world!")

    async def fake_receive() -> Message:
        return {"type": "websocket.connect"}

    async def failing_send(message: Message) -> None:
        # Let the accept through; for anything else simulate the exception the
        # server would send to the application when the client disconnects.
        if message["type"] != "websocket.accept":
            raise OSError

    with pytest.raises(WebSocketDisconnect) as exc_info:
        await app({"type": "websocket", "path": "/"}, fake_receive, failing_send)
    assert exc_info.value.code == status.WS_1006_ABNORMAL_CLOSURE
|
|
|
|
|
|
def test_application_close(test_client_factory: TestClientFactory) -> None:
    """A close issued by the application is seen by the client with the same code."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.accept()
        await ws.close(status.WS_1001_GOING_AWAY)

    client = test_client_factory(app)
    with client.websocket_connect("/") as session:
        with pytest.raises(WebSocketDisconnect) as exc_info:
            session.receive_text()
        assert exc_info.value.code == status.WS_1001_GOING_AWAY
|
|
|
|
|
|
def test_rejected_connection(test_client_factory: TestClientFactory) -> None:
    """Closing instead of accepting makes the client's connect attempt raise."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        first_message = await ws.receive()
        assert first_message == {"type": "websocket.connect"}
        # Close without ever accepting the connection.
        await ws.close(status.WS_1001_GOING_AWAY)

    client = test_client_factory(app)
    with pytest.raises(WebSocketDisconnect) as exc_info:
        with client.websocket_connect("/"):
            pass  # pragma: no cover
    assert exc_info.value.code == status.WS_1001_GOING_AWAY
|
|
|
|
|
|
def test_send_denial_response(test_client_factory: TestClientFactory) -> None:
    """A denial response's status and body are carried on the client-side exception."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        first_message = await ws.receive()
        assert first_message == {"type": "websocket.connect"}
        denial = Response(status_code=404, content="foo")
        await ws.send_denial_response(denial)

    client = test_client_factory(app)
    with pytest.raises(WebSocketDenialResponse) as exc_info:
        with client.websocket_connect("/"):
            pass  # pragma: no cover
    assert exc_info.value.status_code == 404
    assert exc_info.value.content == b"foo"
|
|
|
|
|
|
def test_send_response_multi(test_client_factory: TestClientFactory) -> None:
    """A denial response sent as several raw ASGI messages is reassembled by the client."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        first_message = await ws.receive()
        assert first_message == {"type": "websocket.connect"}

        # One start message followed by a body split across two chunks.
        outgoing = [
            {
                "type": "websocket.http.response.start",
                "status": 404,
                "headers": [(b"content-type", b"text/plain"), (b"foo", b"bar")],
            },
            {
                "type": "websocket.http.response.body",
                "body": b"hard",
                "more_body": True,
            },
            {
                "type": "websocket.http.response.body",
                "body": b"body",
            },
        ]
        for asgi_message in outgoing:
            await ws.send(asgi_message)

    client = test_client_factory(app)
    with pytest.raises(WebSocketDenialResponse) as exc_info:
        with client.websocket_connect("/"):
            pass  # pragma: no cover
    assert exc_info.value.status_code == 404
    assert exc_info.value.content == b"hardbody"
    assert exc_info.value.headers["foo"] == "bar"
|
|
|
|
|
|
def test_send_response_unsupported(test_client_factory: TestClientFactory) -> None:
    """`send_denial_response` raises `RuntimeError` when the extension is missing.

    The application removes the "websocket.http.response" entry from the scope
    extensions, checks that sending a denial response then fails, and closes
    the connection instead. The client is expected to see a normal closure.
    """

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        # Drop the extension entry so the denial response is unsupported.
        del scope["extensions"]["websocket.http.response"]
        websocket = WebSocket(scope, receive=receive, send=send)
        msg = await websocket.receive()
        assert msg == {"type": "websocket.connect"}
        response = Response(status_code=404, content="foo")
        # `match` is a regular expression: the trailing period is escaped so it
        # matches a literal "." rather than any character.
        with pytest.raises(
            RuntimeError,
            match=r"The server doesn't support the Websocket Denial Response extension\.",
        ):
            await websocket.send_denial_response(response)
        await websocket.close()

    client = test_client_factory(app)
    with pytest.raises(WebSocketDisconnect) as exc:
        with client.websocket_connect("/"):
            pass  # pragma: no cover
    assert exc.value.code == status.WS_1000_NORMAL_CLOSURE
|
|
|
|
|
|
def test_send_response_duplicate_start(test_client_factory: TestClientFactory) -> None:
    """Sending two "websocket.http.response.start" messages in a row is an error.

    The client is expected to raise `RuntimeError` naming the message type it
    expected and the one it actually got.
    """

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        websocket = WebSocket(scope, receive=receive, send=send)
        msg = await websocket.receive()
        assert msg == {"type": "websocket.connect"}
        response = Response(status_code=404, content="foo")
        # Send the same start message twice; a fresh dict is built each time.
        for _ in range(2):
            await websocket.send(
                {
                    "type": "websocket.http.response.start",
                    "status": response.status_code,
                    "headers": response.raw_headers,
                }
            )

    client = test_client_factory(app)
    # `match` is a regular expression: dots in the ASGI message types are
    # escaped so they only match literal "." characters.
    with pytest.raises(
        RuntimeError,
        match=(
            r'Expected ASGI message "websocket\.http\.response\.body", but got '
            r"'websocket\.http\.response\.start'"
        ),
    ):
        with client.websocket_connect("/"):
            pass  # pragma: no cover
|
|
|
|
|
|
def test_subprotocol(test_client_factory: TestClientFactory) -> None:
    """The subprotocol chosen by the application is reported to the client."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        # The offered subprotocols are readable through the mapping interface.
        assert ws["subprotocols"] == ["soap", "wamp"]
        await ws.accept(subprotocol="wamp")
        await ws.close()

    client = test_client_factory(app)
    offered = ["soap", "wamp"]
    with client.websocket_connect("/", subprotocols=offered) as session:
        assert session.accepted_subprotocol == "wamp"
|
|
|
|
|
|
def test_additional_headers(test_client_factory: TestClientFactory) -> None:
    """Headers passed to `accept()` show up as `extra_headers` on the client session."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.accept(headers=[(b"additional", b"header")])
        await ws.close()

    client = test_client_factory(app)
    with client.websocket_connect("/") as session:
        assert session.extra_headers == [(b"additional", b"header")]
|
|
|
|
|
|
def test_no_additional_headers(test_client_factory: TestClientFactory) -> None:
    """Accepting without headers leaves the client session's `extra_headers` empty."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.accept()
        await ws.close()

    client = test_client_factory(app)
    with client.websocket_connect("/") as session:
        assert session.extra_headers == []
|
|
|
|
|
|
def test_websocket_exception(test_client_factory: TestClientFactory) -> None:
    """An exception raised by the application propagates to the connecting client."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        # Raise explicitly rather than `assert False`: assert statements are
        # removed under `python -O`, which would leave the app raising nothing.
        raise AssertionError

    client = test_client_factory(app)
    with pytest.raises(AssertionError):
        with client.websocket_connect("/123?a=abc"):
            pass  # pragma: no cover
|
|
|
|
|
|
def test_duplicate_close(test_client_factory: TestClientFactory) -> None:
    """Closing the websocket twice from the application is expected to raise `RuntimeError`."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.accept()
        # The second close is the invalid one.
        await ws.close()
        await ws.close()

    client = test_client_factory(app)
    with pytest.raises(RuntimeError), client.websocket_connect("/"):
        pass  # pragma: no cover
|
|
|
|
|
|
def test_duplicate_disconnect(test_client_factory: TestClientFactory) -> None:
    """Receiving again after a disconnect message is expected to raise `RuntimeError`."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.accept()
        disconnect = await ws.receive()
        assert disconnect["type"] == "websocket.disconnect"
        # A further receive after the disconnect is the invalid call.
        await ws.receive()

    client = test_client_factory(app)
    with pytest.raises(RuntimeError):
        with client.websocket_connect("/") as session:
            session.close()
|
|
|
|
|
|
def test_websocket_scope_interface() -> None:
    """
    A WebSocket built directly from a scope behaves like a `Mapping` over that
    scope, and compares and hashes as a distinct object per instance.
    """

    async def mock_receive() -> Message:  # type: ignore
        ...  # pragma: no cover

    async def mock_send(message: Message) -> None:
        ...  # pragma: no cover

    def build() -> WebSocket:
        # Each call uses its own, freshly created scope dictionary.
        return WebSocket(
            {"type": "websocket", "path": "/abc/", "headers": []},
            receive=mock_receive,
            send=mock_send,
        )

    websocket = build()
    assert websocket["type"] == "websocket"
    assert dict(websocket) == {"type": "websocket", "path": "/abc/", "headers": []}
    assert len(websocket) == 3

    # check __eq__ and __hash__
    assert websocket != build()
    assert websocket == websocket
    assert websocket in {websocket}
    assert {websocket} == {websocket}
|
|
|
|
|
|
def test_websocket_close_reason(test_client_factory: TestClientFactory) -> None:
    """Both the close code and reason set by the application reach the client."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.accept()
        await ws.close(code=status.WS_1001_GOING_AWAY, reason="Going Away")

    client = test_client_factory(app)
    with client.websocket_connect("/") as session:
        with pytest.raises(WebSocketDisconnect) as exc_info:
            session.receive_text()
        disconnect = exc_info.value
        assert disconnect.code == status.WS_1001_GOING_AWAY
        assert disconnect.reason == "Going Away"
|
|
|
|
|
|
def test_send_json_invalid_mode(test_client_factory: TestClientFactory) -> None:
    """`send_json` with an unknown mode is expected to raise `RuntimeError`."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.accept()
        await ws.send_json({}, mode="invalid")

    client = test_client_factory(app)
    with pytest.raises(RuntimeError), client.websocket_connect("/"):
        pass  # pragma: no cover
|
|
|
|
|
|
def test_receive_json_invalid_mode(test_client_factory: TestClientFactory) -> None:
    """`receive_json` with an unknown mode is expected to raise `RuntimeError`."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.accept()
        await ws.receive_json(mode="invalid")

    client = test_client_factory(app)
    with pytest.raises(RuntimeError), client.websocket_connect("/"):
        pass  # pragma: no cover
|
|
|
|
|
|
def test_receive_text_before_accept(test_client_factory: TestClientFactory) -> None:
    """Calling `receive_text` without accepting first is expected to raise `RuntimeError`."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        # No accept() call: go straight to reading text.
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.receive_text()

    client = test_client_factory(app)
    with pytest.raises(RuntimeError), client.websocket_connect("/"):
        pass  # pragma: no cover
|
|
|
|
|
|
def test_receive_bytes_before_accept(test_client_factory: TestClientFactory) -> None:
    """Calling `receive_bytes` without accepting first is expected to raise `RuntimeError`."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        # No accept() call: go straight to reading bytes.
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.receive_bytes()

    client = test_client_factory(app)
    with pytest.raises(RuntimeError), client.websocket_connect("/"):
        pass  # pragma: no cover
|
|
|
|
|
|
def test_receive_json_before_accept(test_client_factory: TestClientFactory) -> None:
    """Calling `receive_json` without accepting first is expected to raise `RuntimeError`."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        # No accept() call: go straight to reading JSON.
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.receive_json()

    client = test_client_factory(app)
    with pytest.raises(RuntimeError), client.websocket_connect("/"):
        pass  # pragma: no cover
|
|
|
|
|
|
def test_send_before_accept(test_client_factory: TestClientFactory) -> None:
    """Sending a data message before accepting is expected to raise `RuntimeError`."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        # No accept() call: the first outgoing message is a plain send.
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.send({"type": "websocket.send"})

    client = test_client_factory(app)
    with pytest.raises(RuntimeError), client.websocket_connect("/"):
        pass  # pragma: no cover
|
|
|
|
|
|
def test_send_wrong_message_type(test_client_factory: TestClientFactory) -> None:
    """Sending "websocket.accept" twice is expected to raise `RuntimeError`."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        # The second accept message is the invalid one.
        await ws.send({"type": "websocket.accept"})
        await ws.send({"type": "websocket.accept"})

    client = test_client_factory(app)
    with pytest.raises(RuntimeError), client.websocket_connect("/"):
        pass  # pragma: no cover
|
|
|
|
|
|
def test_receive_before_accept(test_client_factory: TestClientFactory) -> None:
    """Receiving while the client state is CONNECTING is expected to raise `RuntimeError`."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.accept()
        # Force the client state back to CONNECTING before the next receive.
        ws.client_state = WebSocketState.CONNECTING
        await ws.receive()

    client = test_client_factory(app)
    with pytest.raises(RuntimeError):
        with client.websocket_connect("/") as session:
            session.send({"type": "websocket.send"})
|
|
|
|
|
|
def test_receive_wrong_message_type(test_client_factory: TestClientFactory) -> None:
    """A "websocket.connect" message sent by the client after accept is expected to raise `RuntimeError`."""

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        ws = WebSocket(scope, receive=receive, send=send)
        await ws.accept()
        await ws.receive()

    client = test_client_factory(app)
    with pytest.raises(RuntimeError):
        with client.websocket_connect("/") as session:
            # Connect is not a valid message type at this point.
            session.send({"type": "websocket.connect"})
|