mirror of https://github.com/encode/starlette.git
Add support for functools.partial in WebsocketRoute (#1356)
* Add support for functools.partial in WebsocketRoute * remove commented code * Refactor tests for partian endpoint and ws
This commit is contained in:
parent
f53faba229
commit
76cd611b50
|
@ -276,7 +276,10 @@ class WebSocketRoute(BaseRoute):
|
|||
self.endpoint = endpoint
|
||||
self.name = get_name(endpoint) if name is None else name
|
||||
|
||||
if inspect.isfunction(endpoint) or inspect.ismethod(endpoint):
|
||||
endpoint_handler = endpoint
|
||||
while isinstance(endpoint_handler, functools.partial):
|
||||
endpoint_handler = endpoint_handler.func
|
||||
if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler):
|
||||
# Endpoint is function or method. Treat it as `func(websocket)`.
|
||||
self.app = websocket_session(endpoint)
|
||||
else:
|
||||
|
|
|
@ -32,6 +32,28 @@ def user_no_match(request): # pragma: no cover
|
|||
return Response(content, media_type="text/plain")
|
||||
|
||||
|
||||
async def partial_endpoint(arg, request):
|
||||
return JSONResponse({"arg": arg})
|
||||
|
||||
|
||||
async def partial_ws_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
await websocket.send_json({"url": str(websocket.url)})
|
||||
await websocket.close()
|
||||
|
||||
|
||||
class PartialRoutes:
|
||||
@classmethod
|
||||
async def async_endpoint(cls, arg, request):
|
||||
return JSONResponse({"arg": arg})
|
||||
|
||||
@classmethod
|
||||
async def async_ws_endpoint(cls, websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
await websocket.send_json({"url": str(websocket.url)})
|
||||
await websocket.close()
|
||||
|
||||
|
||||
app = Router(
|
||||
[
|
||||
Route("/", endpoint=homepage, methods=["GET"]),
|
||||
|
@ -44,6 +66,21 @@ app = Router(
|
|||
Route("/nomatch", endpoint=user_no_match),
|
||||
],
|
||||
),
|
||||
Mount(
|
||||
"/partial",
|
||||
routes=[
|
||||
Route("/", endpoint=functools.partial(partial_endpoint, "foo")),
|
||||
Route(
|
||||
"/cls",
|
||||
endpoint=functools.partial(PartialRoutes.async_endpoint, "foo"),
|
||||
),
|
||||
WebSocketRoute("/ws", endpoint=functools.partial(partial_ws_endpoint)),
|
||||
WebSocketRoute(
|
||||
"/ws/cls",
|
||||
endpoint=functools.partial(PartialRoutes.async_ws_endpoint),
|
||||
),
|
||||
],
|
||||
),
|
||||
Mount("/static", app=Response("xxxxx", media_type="image/png")),
|
||||
]
|
||||
)
|
||||
|
@ -91,14 +128,14 @@ def path_with_parentheses(request):
|
|||
|
||||
|
||||
@app.websocket_route("/ws")
|
||||
async def websocket_endpoint(session):
|
||||
async def websocket_endpoint(session: WebSocket):
|
||||
await session.accept()
|
||||
await session.send_text("Hello, world!")
|
||||
await session.close()
|
||||
|
||||
|
||||
@app.websocket_route("/ws/{room}")
|
||||
async def websocket_params(session):
|
||||
async def websocket_params(session: WebSocket):
|
||||
await session.accept()
|
||||
await session.send_text(f"Hello, {session.path_params['room']}!")
|
||||
await session.close()
|
||||
|
@ -628,40 +665,28 @@ def test_raise_on_shutdown(test_client_factory):
|
|||
pass # pragma: nocover
|
||||
|
||||
|
||||
class AsyncEndpointClassMethod:
|
||||
@classmethod
|
||||
async def async_endpoint(cls, arg, request):
|
||||
return JSONResponse({"arg": arg})
|
||||
|
||||
|
||||
async def _partial_async_endpoint(arg, request):
|
||||
return JSONResponse({"arg": arg})
|
||||
|
||||
|
||||
partial_async_endpoint = functools.partial(_partial_async_endpoint, "foo")
|
||||
partial_cls_async_endpoint = functools.partial(
|
||||
AsyncEndpointClassMethod.async_endpoint, "foo"
|
||||
)
|
||||
|
||||
partial_async_app = Router(
|
||||
routes=[
|
||||
Route("/", partial_async_endpoint),
|
||||
Route("/cls", partial_cls_async_endpoint),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def test_partial_async_endpoint(test_client_factory):
|
||||
test_client = test_client_factory(partial_async_app)
|
||||
response = test_client.get("/")
|
||||
test_client = test_client_factory(app)
|
||||
response = test_client.get("/partial")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"arg": "foo"}
|
||||
|
||||
cls_method_response = test_client.get("/cls")
|
||||
cls_method_response = test_client.get("/partial/cls")
|
||||
assert cls_method_response.status_code == 200
|
||||
assert cls_method_response.json() == {"arg": "foo"}
|
||||
|
||||
|
||||
def test_partial_async_ws_endpoint(test_client_factory):
|
||||
test_client = test_client_factory(app)
|
||||
with test_client.websocket_connect("/partial/ws") as websocket:
|
||||
data = websocket.receive_json()
|
||||
assert data == {"url": "ws://testserver/partial/ws"}
|
||||
|
||||
with test_client.websocket_connect("/partial/ws/cls") as websocket:
|
||||
data = websocket.receive_json()
|
||||
assert data == {"url": "ws://testserver/partial/ws/cls"}
|
||||
|
||||
|
||||
def test_duplicated_param_names():
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
|
|
Loading…
Reference in New Issue