mirror of https://github.com/encode/starlette.git
Add support for functools.partial in WebsocketRoute (#1356)
* Add support for functools.partial in WebsocketRoute * remove commented code * Refactor tests for partian endpoint and ws
This commit is contained in:
parent
f53faba229
commit
76cd611b50
|
@ -276,7 +276,10 @@ class WebSocketRoute(BaseRoute):
|
||||||
self.endpoint = endpoint
|
self.endpoint = endpoint
|
||||||
self.name = get_name(endpoint) if name is None else name
|
self.name = get_name(endpoint) if name is None else name
|
||||||
|
|
||||||
if inspect.isfunction(endpoint) or inspect.ismethod(endpoint):
|
endpoint_handler = endpoint
|
||||||
|
while isinstance(endpoint_handler, functools.partial):
|
||||||
|
endpoint_handler = endpoint_handler.func
|
||||||
|
if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler):
|
||||||
# Endpoint is function or method. Treat it as `func(websocket)`.
|
# Endpoint is function or method. Treat it as `func(websocket)`.
|
||||||
self.app = websocket_session(endpoint)
|
self.app = websocket_session(endpoint)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -32,6 +32,28 @@ def user_no_match(request): # pragma: no cover
|
||||||
return Response(content, media_type="text/plain")
|
return Response(content, media_type="text/plain")
|
||||||
|
|
||||||
|
|
||||||
|
async def partial_endpoint(arg, request):
|
||||||
|
return JSONResponse({"arg": arg})
|
||||||
|
|
||||||
|
|
||||||
|
async def partial_ws_endpoint(websocket: WebSocket):
|
||||||
|
await websocket.accept()
|
||||||
|
await websocket.send_json({"url": str(websocket.url)})
|
||||||
|
await websocket.close()
|
||||||
|
|
||||||
|
|
||||||
|
class PartialRoutes:
|
||||||
|
@classmethod
|
||||||
|
async def async_endpoint(cls, arg, request):
|
||||||
|
return JSONResponse({"arg": arg})
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def async_ws_endpoint(cls, websocket: WebSocket):
|
||||||
|
await websocket.accept()
|
||||||
|
await websocket.send_json({"url": str(websocket.url)})
|
||||||
|
await websocket.close()
|
||||||
|
|
||||||
|
|
||||||
app = Router(
|
app = Router(
|
||||||
[
|
[
|
||||||
Route("/", endpoint=homepage, methods=["GET"]),
|
Route("/", endpoint=homepage, methods=["GET"]),
|
||||||
|
@ -44,6 +66,21 @@ app = Router(
|
||||||
Route("/nomatch", endpoint=user_no_match),
|
Route("/nomatch", endpoint=user_no_match),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
|
Mount(
|
||||||
|
"/partial",
|
||||||
|
routes=[
|
||||||
|
Route("/", endpoint=functools.partial(partial_endpoint, "foo")),
|
||||||
|
Route(
|
||||||
|
"/cls",
|
||||||
|
endpoint=functools.partial(PartialRoutes.async_endpoint, "foo"),
|
||||||
|
),
|
||||||
|
WebSocketRoute("/ws", endpoint=functools.partial(partial_ws_endpoint)),
|
||||||
|
WebSocketRoute(
|
||||||
|
"/ws/cls",
|
||||||
|
endpoint=functools.partial(PartialRoutes.async_ws_endpoint),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
Mount("/static", app=Response("xxxxx", media_type="image/png")),
|
Mount("/static", app=Response("xxxxx", media_type="image/png")),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
@ -91,14 +128,14 @@ def path_with_parentheses(request):
|
||||||
|
|
||||||
|
|
||||||
@app.websocket_route("/ws")
|
@app.websocket_route("/ws")
|
||||||
async def websocket_endpoint(session):
|
async def websocket_endpoint(session: WebSocket):
|
||||||
await session.accept()
|
await session.accept()
|
||||||
await session.send_text("Hello, world!")
|
await session.send_text("Hello, world!")
|
||||||
await session.close()
|
await session.close()
|
||||||
|
|
||||||
|
|
||||||
@app.websocket_route("/ws/{room}")
|
@app.websocket_route("/ws/{room}")
|
||||||
async def websocket_params(session):
|
async def websocket_params(session: WebSocket):
|
||||||
await session.accept()
|
await session.accept()
|
||||||
await session.send_text(f"Hello, {session.path_params['room']}!")
|
await session.send_text(f"Hello, {session.path_params['room']}!")
|
||||||
await session.close()
|
await session.close()
|
||||||
|
@ -628,40 +665,28 @@ def test_raise_on_shutdown(test_client_factory):
|
||||||
pass # pragma: nocover
|
pass # pragma: nocover
|
||||||
|
|
||||||
|
|
||||||
class AsyncEndpointClassMethod:
|
|
||||||
@classmethod
|
|
||||||
async def async_endpoint(cls, arg, request):
|
|
||||||
return JSONResponse({"arg": arg})
|
|
||||||
|
|
||||||
|
|
||||||
async def _partial_async_endpoint(arg, request):
|
|
||||||
return JSONResponse({"arg": arg})
|
|
||||||
|
|
||||||
|
|
||||||
partial_async_endpoint = functools.partial(_partial_async_endpoint, "foo")
|
|
||||||
partial_cls_async_endpoint = functools.partial(
|
|
||||||
AsyncEndpointClassMethod.async_endpoint, "foo"
|
|
||||||
)
|
|
||||||
|
|
||||||
partial_async_app = Router(
|
|
||||||
routes=[
|
|
||||||
Route("/", partial_async_endpoint),
|
|
||||||
Route("/cls", partial_cls_async_endpoint),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_partial_async_endpoint(test_client_factory):
|
def test_partial_async_endpoint(test_client_factory):
|
||||||
test_client = test_client_factory(partial_async_app)
|
test_client = test_client_factory(app)
|
||||||
response = test_client.get("/")
|
response = test_client.get("/partial")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json() == {"arg": "foo"}
|
assert response.json() == {"arg": "foo"}
|
||||||
|
|
||||||
cls_method_response = test_client.get("/cls")
|
cls_method_response = test_client.get("/partial/cls")
|
||||||
assert cls_method_response.status_code == 200
|
assert cls_method_response.status_code == 200
|
||||||
assert cls_method_response.json() == {"arg": "foo"}
|
assert cls_method_response.json() == {"arg": "foo"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_partial_async_ws_endpoint(test_client_factory):
|
||||||
|
test_client = test_client_factory(app)
|
||||||
|
with test_client.websocket_connect("/partial/ws") as websocket:
|
||||||
|
data = websocket.receive_json()
|
||||||
|
assert data == {"url": "ws://testserver/partial/ws"}
|
||||||
|
|
||||||
|
with test_client.websocket_connect("/partial/ws/cls") as websocket:
|
||||||
|
data = websocket.receive_json()
|
||||||
|
assert data == {"url": "ws://testserver/partial/ws/cls"}
|
||||||
|
|
||||||
|
|
||||||
def test_duplicated_param_names():
|
def test_duplicated_param_names():
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError,
|
ValueError,
|
||||||
|
|
Loading…
Reference in New Issue