Add support for functools.partial in WebsocketRoute (#1356)

* Add support for functools.partial in WebsocketRoute

* remove commented code

* Refactor tests for partian endpoint and ws
This commit is contained in:
Amin Alaee 2021-12-11 14:35:23 +01:00 committed by GitHub
parent f53faba229
commit 76cd611b50
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 57 additions and 29 deletions

View File

@ -276,7 +276,10 @@ class WebSocketRoute(BaseRoute):
self.endpoint = endpoint
self.name = get_name(endpoint) if name is None else name
if inspect.isfunction(endpoint) or inspect.ismethod(endpoint):
endpoint_handler = endpoint
while isinstance(endpoint_handler, functools.partial):
endpoint_handler = endpoint_handler.func
if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler):
# Endpoint is function or method. Treat it as `func(websocket)`.
self.app = websocket_session(endpoint)
else:

View File

@ -32,6 +32,28 @@ def user_no_match(request): # pragma: no cover
return Response(content, media_type="text/plain")
async def partial_endpoint(arg, request):
return JSONResponse({"arg": arg})
async def partial_ws_endpoint(websocket: WebSocket):
await websocket.accept()
await websocket.send_json({"url": str(websocket.url)})
await websocket.close()
class PartialRoutes:
@classmethod
async def async_endpoint(cls, arg, request):
return JSONResponse({"arg": arg})
@classmethod
async def async_ws_endpoint(cls, websocket: WebSocket):
await websocket.accept()
await websocket.send_json({"url": str(websocket.url)})
await websocket.close()
app = Router(
[
Route("/", endpoint=homepage, methods=["GET"]),
@ -44,6 +66,21 @@ app = Router(
Route("/nomatch", endpoint=user_no_match),
],
),
Mount(
"/partial",
routes=[
Route("/", endpoint=functools.partial(partial_endpoint, "foo")),
Route(
"/cls",
endpoint=functools.partial(PartialRoutes.async_endpoint, "foo"),
),
WebSocketRoute("/ws", endpoint=functools.partial(partial_ws_endpoint)),
WebSocketRoute(
"/ws/cls",
endpoint=functools.partial(PartialRoutes.async_ws_endpoint),
),
],
),
Mount("/static", app=Response("xxxxx", media_type="image/png")),
]
)
@ -91,14 +128,14 @@ def path_with_parentheses(request):
@app.websocket_route("/ws")
async def websocket_endpoint(session):
async def websocket_endpoint(session: WebSocket):
await session.accept()
await session.send_text("Hello, world!")
await session.close()
@app.websocket_route("/ws/{room}")
async def websocket_params(session):
async def websocket_params(session: WebSocket):
await session.accept()
await session.send_text(f"Hello, {session.path_params['room']}!")
await session.close()
@ -628,40 +665,28 @@ def test_raise_on_shutdown(test_client_factory):
pass # pragma: nocover
class AsyncEndpointClassMethod:
@classmethod
async def async_endpoint(cls, arg, request):
return JSONResponse({"arg": arg})
async def _partial_async_endpoint(arg, request):
return JSONResponse({"arg": arg})
partial_async_endpoint = functools.partial(_partial_async_endpoint, "foo")
partial_cls_async_endpoint = functools.partial(
AsyncEndpointClassMethod.async_endpoint, "foo"
)
partial_async_app = Router(
routes=[
Route("/", partial_async_endpoint),
Route("/cls", partial_cls_async_endpoint),
]
)
def test_partial_async_endpoint(test_client_factory):
test_client = test_client_factory(partial_async_app)
response = test_client.get("/")
test_client = test_client_factory(app)
response = test_client.get("/partial")
assert response.status_code == 200
assert response.json() == {"arg": "foo"}
cls_method_response = test_client.get("/cls")
cls_method_response = test_client.get("/partial/cls")
assert cls_method_response.status_code == 200
assert cls_method_response.json() == {"arg": "foo"}
def test_partial_async_ws_endpoint(test_client_factory):
test_client = test_client_factory(app)
with test_client.websocket_connect("/partial/ws") as websocket:
data = websocket.receive_json()
assert data == {"url": "ws://testserver/partial/ws"}
with test_client.websocket_connect("/partial/ws/cls") as websocket:
data = websocket.receive_json()
assert data == {"url": "ws://testserver/partial/ws/cls"}
def test_duplicated_param_names():
with pytest.raises(
ValueError,