diff --git a/README.md b/README.md index 06d07dc7..6628e999 100644 --- a/README.md +++ b/README.md @@ -161,6 +161,30 @@ class App: await response(receive, send) ``` +### FileResponse + +Asynchronously streams a file as the response. + +Takes a different set of arguments to instantiate than the other response types: + +* `path` - The filepath to the file to stream. +* `headers` - Any custom headers to include, as a dictionary. +* `media_type` - A string giving the media type. If unset, the filename or path will be used to infer a media type. +* `filename` - If set, this will be included in the response `Content-Disposition`. + +```python +from starlette import FileResponse + + +class App: + def __init__(self, scope): + self.scope = scope + + async def __call__(self, receive, send): + response = FileResponse('/statics/favicon.ico') + await response(receive, send) +``` + --- ## Requests diff --git a/requirements.txt b/requirements.txt index 96b8130d..fbec2e51 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +aiofiles requests # Testing diff --git a/setup.py b/setup.py index 3777b4ec..646c7613 100644 --- a/setup.py +++ b/setup.py @@ -44,6 +44,7 @@ setup( author_email='tom@tomchristie.com', packages=get_packages('starlette'), install_requires=[ + 'aiofiles', 'requests', ], classifiers=[ diff --git a/starlette/__init__.py b/starlette/__init__.py index 22ae9352..26ac565a 100644 --- a/starlette/__init__.py +++ b/starlette/__init__.py @@ -1,5 +1,6 @@ from starlette.decorators import asgi_application from starlette.response import ( + FileResponse, HTMLResponse, JSONResponse, Response, @@ -13,6 +14,7 @@ from starlette.testclient import TestClient __all__ = ( "asgi_application", + "FileResponse", "HTMLResponse", "JSONResponse", "Path", @@ -24,4 +26,4 @@ __all__ = ( "Request", "TestClient", ) -__version__ = "0.1.6" +__version__ = "0.1.7" diff --git a/starlette/datastructures.py b/starlette/datastructures.py index cfb24bf9..247157ca 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -75,7 +75,7 @@ class QueryParams(typing.Mapping[str, str]): self._dict = {k: v for k, v in reversed(items)} self._list = items - def get_list(self, key: str) -> typing.List[str]: + def getlist(self, key: str) -> typing.List[str]: return [item_value for item_key, item_value in self._list if item_key == key] def keys(self): @@ -119,16 +119,15 @@ class Headers(typing.Mapping[str, str]): An immutable, case-insensitive multidict. """ - def __init__(self, value: typing.Union[StrDict, StrPairs] = None) -> None: - if value is None: + def __init__(self, raw_headers=None) -> None: + if raw_headers is None: self._list = [] else: - assert isinstance(value, list) - for header_key, header_value in value: + for header_key, header_value in raw_headers: assert isinstance(header_key, bytes) assert isinstance(header_value, bytes) assert header_key == header_key.lower() - self._list = value + self._list = raw_headers def keys(self): return [key.decode("latin-1") for key, value in self._list] @@ -148,7 +147,7 @@ class Headers(typing.Mapping[str, str]): except KeyError: return default - def get_list(self, key: str) -> typing.List[str]: + def getlist(self, key: str) -> typing.List[str]: get_header_key = key.lower().encode("latin-1") return [ item_value.decode("latin-1") @@ -164,7 +163,11 @@ class Headers(typing.Mapping[str, str]): raise KeyError(key) def __contains__(self, key: str): - return key.lower() in self.keys() + get_header_key = key.lower().encode("latin-1") + for header_key, header_value in self._list: + if header_key == get_header_key: + return True + return False def __iter__(self): return iter(self.items()) @@ -183,6 +186,9 @@ class Headers(typing.Mapping[str, str]): class MutableHeaders(Headers): def __setitem__(self, key: str, value: str): + """ + Set the header `key` to `value`, removing any duplicate entries. + """ set_key = key.lower().encode("latin-1") set_value = value.encode("latin-1") @@ -195,3 +201,31 @@ class MutableHeaders(Headers): del self._list[idx] self._list.append((set_key, set_value)) + + def __delitem__(self, key: str): + """ + Remove the header `key`. + """ + del_key = key.lower().encode("latin-1") + + pop_indexes = [] + for idx, (item_key, item_value) in enumerate(self._list): + if item_key == del_key: + pop_indexes.append(idx) + + for idx in reversed(pop_indexes): + del (self._list[idx]) + + def setdefault(self, key: str, value: str): + """ + If the header `key` does not exist, then set it to `value`. + Returns the header value. + """ + set_key = key.lower().encode("latin-1") + set_value = value.encode("latin-1") + + for idx, (item_key, item_value) in enumerate(self._list): + if item_key == set_key: + return item_value.decode("latin-1") + self._list.append((set_key, set_value)) + return value diff --git a/starlette/response.py b/starlette/response.py index 28d0efd8..c3c98197 100644 --- a/starlette/response.py +++ b/starlette/response.py @@ -1,7 +1,10 @@ +from mimetypes import guess_type from starlette.datastructures import MutableHeaders from starlette.types import Receive, Send +import aiofiles import json import typing +import os class Response: @@ -19,36 +22,37 @@ class Response: self.status_code = status_code if media_type is not None: self.media_type = media_type - self.set_default_headers(headers) + self.init_headers(headers) def render(self, content: typing.Any) -> bytes: if isinstance(content, bytes): return content return content.encode(self.charset) - def set_default_headers(self, headers: dict = None): + def init_headers(self, headers): if headers is None: raw_headers = [] - missing_content_length = True - missing_content_type = True + populate_content_length = True + populate_content_type = True else: raw_headers = [ (k.lower().encode("latin-1"), v.encode("latin-1")) for k, v in headers.items() ] - missing_content_length = "content-length" not in headers - missing_content_type = "content-type" not in headers + keys = [h[0] for h in raw_headers] + populate_content_length = b"content-length" in keys + populate_content_type = b"content-type" in keys - if missing_content_length: - content_length = str(len(self.body)).encode() - raw_headers.append((b"content-length", content_length)) + body = getattr(self, "body", None) + if body is not None and populate_content_length: + content_length = str(len(body)) + raw_headers.append((b"content-length", content_length.encode("latin-1"))) - if self.media_type is not None and missing_content_type: - content_type = self.media_type - if content_type.startswith("text/") and self.charset is not None: - content_type += "; charset=%s" % self.charset - content_type_value = content_type.encode("latin-1") - raw_headers.append((b"content-type", content_type_value)) + content_type = self.media_type + if content_type is not None and populate_content_type: + if content_type.startswith("text/"): + content_type += "; charset=" + self.charset + raw_headers.append((b"content-type", content_type.encode("latin-1"))) self.raw_headers = raw_headers @@ -100,18 +104,15 @@ class StreamingResponse(Response): ) -> None: self.body_iterator = content self.status_code = status_code - if media_type is not None: - self.media_type = media_type - self.set_default_headers(headers) + self.media_type = self.media_type if media_type is None else media_type + self.init_headers(headers) async def __call__(self, receive: Receive, send: Send) -> None: await send( { "type": "http.response.start", "status": self.status_code, - "headers": [ - [key.encode(), value.encode()] for key, value in self.headers - ], + "headers": self.raw_headers, } ) async for chunk in self.body_iterator: @@ -120,22 +121,41 @@ class StreamingResponse(Response): await send({"type": "http.response.body", "body": chunk, "more_body": True}) await send({"type": "http.response.body", "body": b"", "more_body": False}) - def set_default_headers(self, headers: dict = None): - if headers is None: - raw_headers = [] - missing_content_type = True - else: - raw_headers = [ - (k.lower().encode("latin-1"), v.encode("latin-1")) - for k, v in headers.items() - ] - missing_content_type = "content-type" not in headers - if self.media_type is not None and missing_content_type: - content_type = self.media_type - if content_type.startswith("text/") and self.charset is not None: - content_type += "; charset=%s" % self.charset - content_type_value = content_type.encode("latin-1") - raw_headers.append((b"content-type", content_type_value)) +class FileResponse(Response): + chunk_size = 4096 - self.raw_headers = raw_headers + def __init__( + self, + path: str, + headers: dict = None, + media_type: str = None, + filename: str = None, + ) -> None: + self.path = path + self.status_code = 200 + self.filename = filename + if media_type is None: + media_type = guess_type(filename or path)[0] or "text/plain" + self.media_type = media_type + self.init_headers(headers) + if self.filename is not None: + content_disposition = 'attachment; filename="{}"'.format(self.filename) + self.headers.setdefault("content-disposition", content_disposition) + + async def __call__(self, receive: Receive, send: Send) -> None: + await send( + { + "type": "http.response.start", + "status": self.status_code, + "headers": self.raw_headers, + } + ) + async with aiofiles.open(self.path, mode="rb") as file: + more_body = True + while more_body: + chunk = await file.read(self.chunk_size) + more_body = len(chunk) == self.chunk_size + await send( + {"type": "http.response.body", "body": chunk, "more_body": False} + ) diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index c510d003..0a5a8441 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -1,4 +1,4 @@ -from starlette.datastructures import Headers, QueryParams, URL +from starlette.datastructures import Headers, MutableHeaders, QueryParams, URL def test_url(): @@ -25,7 +25,7 @@ def test_headers(): assert h["a"] == "123" assert h.get("a") == "123" assert h.get("nope", default=None) is None - assert h.get_list("a") == ["123", "456"] + assert h.getlist("a") == ["123", "456"] assert h.keys() == ["a", "a", "b"] assert h.values() == ["123", "456", "789"] assert h.items() == [("a", "123"), ("a", "456"), ("b", "789")] @@ -34,8 +34,21 @@ def test_headers(): assert repr(h) == "Headers([('a', '123'), ('a', '456'), ('b', '789')])" assert h == Headers([(b"a", b"123"), (b"b", b"789"), (b"a", b"456")]) assert h != [(b"a", b"123"), (b"A", b"456"), (b"b", b"789")] - h = Headers() - assert not h.items() + + +def test_mutable_headers(): + h = MutableHeaders() + assert dict(h) == {} + h["a"] = "1" + assert dict(h) == {"a": "1"} + h["a"] = "2" + assert dict(h) == {"a": "2"} + h.setdefault("a", "3") + assert dict(h) == {"a": "2"} + h.setdefault("b", "4") + assert dict(h) == {"a": "2", "b": "4"} + del h["a"] + assert dict(h) == {"b": "4"} def test_queryparams(): @@ -46,7 +59,7 @@ def test_queryparams(): assert q["a"] == "123" assert q.get("a") == "123" assert q.get("nope", default=None) is None - assert q.get_list("a") == ["123", "456"] + assert q.getlist("a") == ["123", "456"] assert q.keys() == ["a", "a", "b"] assert q.values() == ["123", "456", "789"] assert q.items() == [("a", "123"), ("a", "456"), ("b", "789")] diff --git a/tests/test_response.py b/tests/test_response.py index edcabb84..d4850fc8 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -1,4 +1,4 @@ -from starlette import Response, StreamingResponse, TestClient +from starlette import FileResponse, Response, StreamingResponse, TestClient import asyncio @@ -67,22 +67,18 @@ def test_response_headers(): assert response.headers["x-header-2"] == "789" -def test_streaming_response_headers(): +def test_file_response(tmpdir): + with open("xyz", "wb") as file: + file.write(b"") + def app(scope): - async def asgi(receive, send): - async def stream(msg): - yield "hello, world" - - headers = {"x-header-1": "123", "x-header-2": "456"} - response = StreamingResponse( - stream("hello, world"), media_type="text/plain", headers=headers - ) - response.headers["x-header-2"] = "789" - await response(receive, send) - - return asgi + return FileResponse(path="xyz", filename="example.png") client = TestClient(app) response = client.get("/") - assert response.headers["x-header-1"] == "123" - assert response.headers["x-header-2"] == "789" + assert response.status_code == 200 + assert response.content == b"" + assert response.headers["content-type"] == "image/png" + assert ( + response.headers["content-disposition"] == 'attachment; filename="example.png"' + )