diff --git a/README.md b/README.md index 922a4165..dc9ac625 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ Starlette is a small library for working with [ASGI](https://asgi.readthedocs.io/en/latest/). -It gives you `Request` and `Response` classes, request routing, websocket support, +It gives you `Request` and `Response` classes, websocket support, routing, static files support, and a test client. **Requirements:** @@ -36,7 +36,7 @@ pip3 install starlette **Example:** ```python -from starlette import Response +from starlette.response import Response class App: @@ -72,6 +72,7 @@ You can run the application with any ASGI server, including [uvicorn](http://www * [Static Files](#static-files) * [Test Client](#test-client) * [Debugging](#debugging) +* [Applications](#applications) --- @@ -111,7 +112,7 @@ class App: Takes some text or bytes and returns an HTML response. ```python -from starlette import HTMLResponse +from starlette.response import HTMLResponse class App: @@ -128,7 +129,7 @@ class App: Takes some text or bytes and returns an plain text response. ```python -from starlette import PlainTextResponse +from starlette.response import PlainTextResponse class App: @@ -145,7 +146,7 @@ class App: Takes some data and returns an `application/json` encoded response. ```python -from starlette import JSONResponse +from starlette.response import JSONResponse class App: @@ -162,7 +163,7 @@ class App: Returns an HTTP redirect. Uses a 302 status code by default. ```python -from starlette import PlainTextResponse, RedirectResponse +from starlette.response import PlainTextResponse, RedirectResponse class App: @@ -182,7 +183,8 @@ class App: Takes an async generator and streams the response body. ```python -from starlette import Request, StreamingResponse +from starlette.request import Request +from starlette.response import StreamingResponse import asyncio @@ -218,7 +220,7 @@ Takes a different set of arguments to instantiate than the other response types: File responses will include appropriate `Content-Length`, `Last-Modified` and `ETag` headers. ```python -from starlette import FileResponse +from starlette.response import FileResponse class App: @@ -514,7 +516,8 @@ The test client allows you to make requests against your ASGI application, using the `requests` library. ```python -from starlette import HTMLResponse, TestClient +from starlette.response import HTMLResponse +from starlette.testclient import TestClient class App: @@ -562,7 +565,7 @@ class App: def test_app(): client = TestClient(App) - with client.wsconnect('/') as session: + with client.websocket_connect('/') as session: data = session.receive_text() assert data == 'Hello, world!' ``` @@ -576,10 +579,16 @@ always raised by the test client. #### Establishing a test session -* `.wsconnect(url, subprotocols=None, **options)` - Takes the same set of arguments as `requests.get()`. +* `.websocket_connect(url, subprotocols=None, **options)` - Takes the same set of arguments as `requests.get()`. May raise `starlette.websockets.Disconnect` if the application does not accept the websocket connection. +#### Sending data + +* `.send_text(data)` - Send the given text to the application. +* `.send_bytes(data)` - Send the given bytes to the application. +* `.send_json(data)` - Send the given data to the application. + #### Receiving data * `.receive_text()` - Wait for incoming text sent by the application and return it. @@ -615,4 +624,48 @@ app = DebugMiddleware(App) --- +## Applications + +Starlette also includes an `App` class that nicely ties together all of +its other functionality. + +```python +from starlette.app import App +from starlette.response import PlainTextResponse +from starlette.staticfiles import StaticFiles + + +app = App() +app.mount("/static", StaticFiles(directory="static")) + + +@app.route('/') +def homepage(request): + return PlainTextResponse('Hello, world!') + + +@app.route('/user/{username}') +def user(request, username): + return PlainTextResponse('Hello, %s!' % username) + + +@app.websocket_route('/ws') +async def websocket_endpoint(session): + await session.accept() + await session.send_text('Hello, websocket!') + await session.close() +``` + +### Adding routes to the application + +You can use any of the following to add handled routes to the application: + +* `.add_route(path, func, methods=["GET"])` - Add an HTTP route. The function may be either a coroutine or a regular function, with a signature like `func(request **kwargs) -> response`. +* `.add_websocket_route(path, func)` - Add a websocket session route. The function must be a coroutine, with a signature like `func(session, **kwargs)`. +* `.mount(prefix, app)` - Include an ASGI app, mounted under the given path prefix +* `.route(path)` - Add an HTTP route, decorator style. +* `.websocket_route(path)` - Add a WebSocket route, decorator style. + +--- +

Starlette is BSD licensed code.
Designed & built in Brighton, England.

— ⭐️ —

diff --git a/starlette/__init__.py b/starlette/__init__.py index a0819db8..e163e824 100644 --- a/starlette/__init__.py +++ b/starlette/__init__.py @@ -1,3 +1,4 @@ +from starlette.app import App from starlette.response import ( FileResponse, HTMLResponse, @@ -12,6 +13,7 @@ from starlette.testclient import TestClient __all__ = ( + "App", "FileResponse", "HTMLResponse", "JSONResponse", @@ -22,4 +24,4 @@ __all__ = ( "Request", "TestClient", ) -__version__ = "0.1.17" +__version__ = "0.2.0" diff --git a/starlette/app.py b/starlette/app.py new file mode 100644 index 00000000..3120d686 --- /dev/null +++ b/starlette/app.py @@ -0,0 +1,77 @@ +from starlette.request import Request +from starlette.routing import Path, PathPrefix, Router +from starlette.types import ASGIApp, ASGIInstance, Receive, Scope, Send +from starlette.websockets import WebSocketSession +import asyncio + + +def request_response(func): + """ + Taks a function or coroutine `func(request, **kwargs) -> response`, + and returns an ASGI application. + """ + is_coroutine = asyncio.iscoroutinefunction(func) + + def app(scope: Scope) -> ASGIInstance: + async def awaitable(receive: Receive, send: Send) -> None: + request = Request(scope, receive=receive) + kwargs = scope.get("kwargs", {}) + if is_coroutine: + response = await func(request, **kwargs) + else: + response = func(request, **kwargs) + await response(receive, send) + + return awaitable + + return app + + +def websocket_session(func): + """ + Takes a coroutine `func(session, **kwargs)`, and returns an ASGI application. + """ + + def app(scope: Scope) -> ASGIInstance: + async def awaitable(receive: Receive, send: Send) -> None: + session = WebSocketSession(scope, receive=receive, send=send) + kwargs = scope.get("kwargs", {}) + await func(session, **kwargs) + + return awaitable + + return app + + +class App: + def __init__(self) -> None: + self.router = Router(routes=[]) + + def mount(self, path: str, app: ASGIApp): + prefix = PathPrefix(path, app=app) + self.router.routes.append(prefix) + + def add_route(self, path: str, route, methods=None) -> None: + if methods is None: + methods = ["GET"] + instance = Path(path, request_response(route), protocol="http", methods=methods) + self.router.routes.append(instance) + + def add_websocket_route(self, path: str, route) -> None: + instance = Path(path, websocket_session(route), protocol="websocket") + self.router.routes.append(instance) + + def route(self, path: str): + def decorator(func): + self.add_route(path, func) + + return decorator + + def websocket_route(self, path: str): + def decorator(func): + self.add_websocket_route(path, func) + + return decorator + + def __call__(self, scope: Scope) -> ASGIInstance: + return self.router(scope) diff --git a/starlette/routing.py b/starlette/routing.py index 8832135a..4e0cd051 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -1,4 +1,4 @@ -from starlette import Response +from starlette.response import Response from starlette.types import Scope, ASGIApp, ASGIInstance import re import typing @@ -14,23 +14,29 @@ class Route: class Path(Route): def __init__( - self, path: str, app: ASGIApp, methods: typing.Sequence[str] = () + self, + path: str, + app: ASGIApp, + methods: typing.Sequence[str] = (), + protocol: str = None, ) -> None: self.path = path self.app = app + self.protocol = protocol self.methods = methods regex = "^" + path + "$" regex = re.sub("{([a-zA-Z_][a-zA-Z0-9_]*)}", r"(?P<\1>[^/]+)", regex) self.path_regex = re.compile(regex) def matches(self, scope: Scope) -> typing.Tuple[bool, Scope]: - match = self.path_regex.match(scope["path"]) - if match: - kwargs = dict(scope.get("kwargs", {})) - kwargs.update(match.groupdict()) - child_scope = dict(scope) - child_scope["kwargs"] = kwargs - return True, child_scope + if self.protocol is None or scope["type"] == self.protocol: + match = self.path_regex.match(scope["path"]) + if match: + kwargs = dict(scope.get("kwargs", {})) + kwargs.update(match.groupdict()) + child_scope = dict(scope) + child_scope["kwargs"] = kwargs + return True, child_scope return False, {} def __call__(self, scope: Scope) -> ASGIInstance: @@ -81,6 +87,12 @@ class Router: return self.not_found(scope) def not_found(self, scope: Scope) -> ASGIInstance: + if scope["type"] == "websocket": + + async def close(receive, send): + await send({"type": "websocket.close"}) + + return close return Response("Not found", 404, media_type="text/plain") diff --git a/starlette/testclient.py b/starlette/testclient.py index b3fb4f6c..2118b50e 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -140,11 +140,11 @@ class WebSocketTestSession: self._receive_queue = queue.Queue() self._send_queue = queue.Queue() self._thread = threading.Thread(target=self._run) - self._receive_queue.put({"type": "websocket.connect"}) + self.send({"type": "websocket.connect"}) self._thread.start() - message = self._send_queue.get() - self._raise_on_close_or_exception(message) - self.accepted_subprotocol = message["subprotocol"] + message = self.receive() + self._raise_on_close(message) + self.accepted_subprotocol = message.get("subprotocol", None) def __enter__(self): return self @@ -174,38 +174,45 @@ class WebSocketTestSession: async def _asgi_send(self, message): self._send_queue.put(message) - def _raise_on_close_or_exception(self, message): - if isinstance(message, BaseException): - raise message + def _raise_on_close(self, message): if message["type"] == "websocket.close": - raise WebSocketDisconnect(message["code"]) + raise WebSocketDisconnect(message.get("code", 1000)) + + def send(self, message): + self._receive_queue.put(message) def send_text(self, data): - self._receive_queue.put({"type": "websocket.receive", "text": data}) + self.send({"type": "websocket.receive", "text": data}) def send_bytes(self, data): - self._receive_queue.put({"type": "websocket.receive", "bytes": data}) + self.send({"type": "websocket.receive", "bytes": data}) def send_json(self, data): encoded = json.dumps(data).encode("utf-8") - self._receive_queue.put({"type": "websocket.receive", "bytes": encoded}) + self.send({"type": "websocket.receive", "bytes": encoded}) def close(self, code=1000): - self._receive_queue.put({"type": "websocket.disconnect", "code": code}) + self.send({"type": "websocket.disconnect", "code": code}) + + def receive(self): + message = self._send_queue.get() + if isinstance(message, BaseException): + raise message + return message def receive_text(self): - message = self._send_queue.get() - self._raise_on_close_or_exception(message) + message = self.receive() + self._raise_on_close(message) return message["text"] def receive_bytes(self): - message = self._send_queue.get() - self._raise_on_close_or_exception(message) + message = self.receive() + self._raise_on_close(message) return message["bytes"] def receive_json(self): - message = self._send_queue.get() - self._raise_on_close_or_exception(message) + message = self.receive() + self._raise_on_close(message) encoded = message["bytes"] return json.loads(encoded.decode("utf-8")) @@ -225,7 +232,9 @@ class _TestClient(requests.Session): url = urljoin(self.base_url, url) return super().request(method, url, **kwargs) - def wsconnect(self, url: str, subprotocols=None, **kwargs) -> WebSocketTestSession: + def websocket_connect( + self, url: str, subprotocols=None, **kwargs + ) -> WebSocketTestSession: url = urljoin("ws://testserver", url) headers = kwargs.get("headers", {}) headers.setdefault("connection", "upgrade") diff --git a/tests/test_app.py b/tests/test_app.py new file mode 100644 index 00000000..73b6a479 --- /dev/null +++ b/tests/test_app.py @@ -0,0 +1,75 @@ +from starlette import App +from starlette.response import PlainTextResponse +from starlette.staticfiles import StaticFiles +from starlette.testclient import TestClient +import os + + +app = App() + + +@app.route("/func") +def func_homepage(request): + return PlainTextResponse("Hello, world!") + + +@app.route("/async") +async def async_homepage(request): + return PlainTextResponse("Hello, world!") + + +@app.route("/user/{username}") +def user_page(request, username): + return PlainTextResponse("Hello, %s!" % username) + + +@app.websocket_route("/ws") +async def websocket_endpoint(session): + await session.accept() + await session.send_text("Hello, world!") + await session.close() + + +client = TestClient(app) + + +def test_func_route(): + response = client.get("/func") + assert response.status_code == 200 + assert response.text == "Hello, world!" + + +def test_async_route(): + response = client.get("/async") + assert response.status_code == 200 + assert response.text == "Hello, world!" + + +def test_route_kwargs(): + response = client.get("/user/tomchristie") + assert response.status_code == 200 + assert response.text == "Hello, tomchristie!" + + +def test_websocket_route(): + with client.websocket_connect("/ws") as session: + text = session.receive_text() + assert text == "Hello, world!" + + +def test_400(): + response = client.get("/404") + assert response.status_code == 404 + + +def test_app_mount(tmpdir): + path = os.path.join(tmpdir, "example.txt") + with open(path, "w") as file: + file.write("") + + app = App() + app.mount("/static", StaticFiles(directory=tmpdir)) + client = TestClient(app) + response = client.get("/static/example.txt") + assert response.status_code == 200 + assert response.text == "" diff --git a/tests/test_routing.py b/tests/test_routing.py index 16eeb41a..3cad2c8c 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -1,6 +1,7 @@ from starlette import Response, TestClient from starlette.routing import Path, PathPrefix, Router, ProtocolRouter -from starlette.websockets import WebSocketSession +from starlette.websockets import WebSocketSession, WebSocketDisconnect +import pytest def homepage(scope): @@ -78,7 +79,10 @@ def websocket_endpoint(scope): mixed_protocol_app = ProtocolRouter( - {"http": http_endpoint, "websocket": websocket_endpoint} + { + "http": Router([Path("/", app=http_endpoint)]), + "websocket": Router([Path("/", app=websocket_endpoint)]), + } ) @@ -89,5 +93,8 @@ def test_protocol_switch(): assert response.status_code == 200 assert response.text == "Hello, world" - with client.wsconnect("/") as session: + with client.websocket_connect("/") as session: assert session.receive_json() == {"hello": "world"} + + with pytest.raises(WebSocketDisconnect): + client.websocket_connect("/404") diff --git a/tests/test_websockets.py b/tests/test_websockets.py index 2416a315..ed528fec 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -14,7 +14,7 @@ def test_session_url(): return asgi client = TestClient(app) - with client.wsconnect("/123?a=abc") as session: + with client.websocket_connect("/123?a=abc") as session: data = session.receive_json() assert data == {"url": "ws://testserver/123?a=abc"} @@ -31,7 +31,7 @@ def test_session_query_params(): return asgi client = TestClient(app) - with client.wsconnect("/?a=abc&b=456") as session: + with client.websocket_connect("/?a=abc&b=456") as session: data = session.receive_json() assert data == {"params": {"a": "abc", "b": "456"}} @@ -48,7 +48,7 @@ def test_session_headers(): return asgi client = TestClient(app) - with client.wsconnect("/") as session: + with client.websocket_connect("/") as session: expected_headers = { "accept": "*/*", "accept-encoding": "gzip, deflate", @@ -73,7 +73,7 @@ def test_session_port(): return asgi client = TestClient(app) - with client.wsconnect("ws://example.com:123/123?a=abc") as session: + with client.websocket_connect("ws://example.com:123/123?a=abc") as session: data = session.receive_json() assert data == {"port": 123} @@ -90,7 +90,7 @@ def test_session_send_and_receive_text(): return asgi client = TestClient(app) - with client.wsconnect("/") as session: + with client.websocket_connect("/") as session: session.send_text("Hello, world!") data = session.receive_text() assert data == "Message was: Hello, world!" @@ -108,7 +108,7 @@ def test_session_send_and_receive_bytes(): return asgi client = TestClient(app) - with client.wsconnect("/") as session: + with client.websocket_connect("/") as session: session.send_bytes(b"Hello, world!") data = session.receive_bytes() assert data == b"Message was: Hello, world!" @@ -126,7 +126,7 @@ def test_session_send_and_receive_json(): return asgi client = TestClient(app) - with client.wsconnect("/") as session: + with client.websocket_connect("/") as session: session.send_json({"hello": "world"}) data = session.receive_json() assert data == {"message": {"hello": "world"}} @@ -148,7 +148,7 @@ def test_client_close(): return asgi client = TestClient(app) - with client.wsconnect("/") as session: + with client.websocket_connect("/") as session: session.close(code=1001) assert close_code == 1001 @@ -163,7 +163,7 @@ def test_application_close(): return asgi client = TestClient(app) - with client.wsconnect("/") as session: + with client.websocket_connect("/") as session: with pytest.raises(WebSocketDisconnect) as exc: session.receive_text() assert exc.value.code == 1001 @@ -179,7 +179,7 @@ def test_rejected_connection(): client = TestClient(app) with pytest.raises(WebSocketDisconnect) as exc: - client.wsconnect("/") + client.websocket_connect("/") assert exc.value.code == 1001 @@ -194,7 +194,7 @@ def test_subprotocol(): return asgi client = TestClient(app) - with client.wsconnect("/", subprotocols=["soap", "wamp"]) as session: + with client.websocket_connect("/", subprotocols=["soap", "wamp"]) as session: assert session.accepted_subprotocol == "wamp" @@ -207,7 +207,7 @@ def test_session_exception(): client = TestClient(app) with pytest.raises(AssertionError): - client.wsconnect("/123?a=abc") + client.websocket_connect("/123?a=abc") def test_duplicate_close(): @@ -222,7 +222,7 @@ def test_duplicate_close(): client = TestClient(app) with pytest.raises(RuntimeError): - with client.wsconnect("/") as session: + with client.websocket_connect("/") as session: pass @@ -239,7 +239,7 @@ def test_duplicate_disconnect(): client = TestClient(app) with pytest.raises(RuntimeError): - with client.wsconnect("/") as session: + with client.websocket_connect("/") as session: session.close()