This commit is contained in:
Tom Christie 2018-07-11 16:54:09 +01:00
commit 342a52f185
8 changed files with 159 additions and 68 deletions

View File

@ -161,6 +161,30 @@ class App:
await response(receive, send)
```
### FileResponse
Asynchronously streams a file as the response.
Takes a different set of arguments to instantiate than the other response types:
* `path` - The filepath to the file to stream.
* `headers` - Any custom headers to include, as a dictionary.
* `media_type` - A string giving the media type. If unset, the filename or path will be used to infer a media type.
* `filename` - If set, this will be included in the response `Content-Disposition`.
```python
from starlette import FileResponse
class App:
def __init__(self, scope):
self.scope = scope
async def __call__(self, receive, send):
response = FileResponse('/statics/favicon.ico')
await response(receive, send)
```
---
## Requests

View File

@ -1,3 +1,4 @@
aiofiles
requests
# Testing

View File

@ -44,6 +44,7 @@ setup(
author_email='tom@tomchristie.com',
packages=get_packages('starlette'),
install_requires=[
'aiofiles',
'requests',
],
classifiers=[

View File

@ -1,5 +1,6 @@
from starlette.decorators import asgi_application
from starlette.response import (
FileResponse,
HTMLResponse,
JSONResponse,
Response,
@ -13,6 +14,7 @@ from starlette.testclient import TestClient
__all__ = (
"asgi_application",
"FileResponse",
"HTMLResponse",
"JSONResponse",
"Path",
@ -24,4 +26,4 @@ __all__ = (
"Request",
"TestClient",
)
__version__ = "0.1.6"
__version__ = "0.1.7"

View File

@ -75,7 +75,7 @@ class QueryParams(typing.Mapping[str, str]):
self._dict = {k: v for k, v in reversed(items)}
self._list = items
def get_list(self, key: str) -> typing.List[str]:
def getlist(self, key: str) -> typing.List[str]:
return [item_value for item_key, item_value in self._list if item_key == key]
def keys(self):
@ -119,16 +119,15 @@ class Headers(typing.Mapping[str, str]):
An immutable, case-insensitive multidict.
"""
def __init__(self, value: typing.Union[StrDict, StrPairs] = None) -> None:
if value is None:
def __init__(self, raw_headers=None) -> None:
if raw_headers is None:
self._list = []
else:
assert isinstance(value, list)
for header_key, header_value in value:
for header_key, header_value in raw_headers:
assert isinstance(header_key, bytes)
assert isinstance(header_value, bytes)
assert header_key == header_key.lower()
self._list = value
self._list = raw_headers
def keys(self):
return [key.decode("latin-1") for key, value in self._list]
@ -148,7 +147,7 @@ class Headers(typing.Mapping[str, str]):
except KeyError:
return default
def get_list(self, key: str) -> typing.List[str]:
def getlist(self, key: str) -> typing.List[str]:
get_header_key = key.lower().encode("latin-1")
return [
item_value.decode("latin-1")
@ -164,7 +163,11 @@ class Headers(typing.Mapping[str, str]):
raise KeyError(key)
def __contains__(self, key: str):
return key.lower() in self.keys()
get_header_key = key.lower().encode("latin-1")
for header_key, header_value in self._list:
if header_key == get_header_key:
return True
return False
def __iter__(self):
return iter(self.items())
@ -183,6 +186,9 @@ class Headers(typing.Mapping[str, str]):
class MutableHeaders(Headers):
def __setitem__(self, key: str, value: str):
"""
Set the header `key` to `value`, removing any duplicate entries.
"""
set_key = key.lower().encode("latin-1")
set_value = value.encode("latin-1")
@ -195,3 +201,31 @@ class MutableHeaders(Headers):
del self._list[idx]
self._list.append((set_key, set_value))
def __delitem__(self, key: str):
"""
Remove the header `key`.
"""
del_key = key.lower().encode("latin-1")
pop_indexes = []
for idx, (item_key, item_value) in enumerate(self._list):
if item_key == del_key:
pop_indexes.append(idx)
for idx in reversed(pop_indexes):
del (self._list[idx])
def setdefault(self, key: str, value: str):
"""
If the header `key` does not exist, then set it to `value`.
Returns the header value.
"""
set_key = key.lower().encode("latin-1")
set_value = value.encode("latin-1")
for idx, (item_key, item_value) in enumerate(self._list):
if item_key == set_key:
return item_value.decode("latin-1")
self._list.append((set_key, set_value))
return value

View File

@ -1,7 +1,10 @@
from mimetypes import guess_type
from starlette.datastructures import MutableHeaders
from starlette.types import Receive, Send
import aiofiles
import json
import typing
import os
class Response:
@ -19,36 +22,37 @@ class Response:
self.status_code = status_code
if media_type is not None:
self.media_type = media_type
self.set_default_headers(headers)
self.init_headers(headers)
def render(self, content: typing.Any) -> bytes:
if isinstance(content, bytes):
return content
return content.encode(self.charset)
def set_default_headers(self, headers: dict = None):
def init_headers(self, headers):
if headers is None:
raw_headers = []
missing_content_length = True
missing_content_type = True
populate_content_length = True
populate_content_type = True
else:
raw_headers = [
(k.lower().encode("latin-1"), v.encode("latin-1"))
for k, v in headers.items()
]
missing_content_length = "content-length" not in headers
missing_content_type = "content-type" not in headers
keys = [h[0] for h in raw_headers]
populate_content_length = b"content-length" in keys
populate_content_type = b"content-type" in keys
if missing_content_length:
content_length = str(len(self.body)).encode()
raw_headers.append((b"content-length", content_length))
body = getattr(self, "body", None)
if body is not None and populate_content_length:
content_length = str(len(body))
raw_headers.append((b"content-length", content_length.encode("latin-1")))
if self.media_type is not None and missing_content_type:
content_type = self.media_type
if content_type.startswith("text/") and self.charset is not None:
content_type += "; charset=%s" % self.charset
content_type_value = content_type.encode("latin-1")
raw_headers.append((b"content-type", content_type_value))
content_type = self.media_type
if content_type is not None and populate_content_type:
if content_type.startswith("text/"):
content_type += "; charset=" + self.charset
raw_headers.append((b"content-type", content_type.encode("latin-1")))
self.raw_headers = raw_headers
@ -100,18 +104,15 @@ class StreamingResponse(Response):
) -> None:
self.body_iterator = content
self.status_code = status_code
if media_type is not None:
self.media_type = media_type
self.set_default_headers(headers)
self.media_type = self.media_type if media_type is None else media_type
self.init_headers(headers)
async def __call__(self, receive: Receive, send: Send) -> None:
await send(
{
"type": "http.response.start",
"status": self.status_code,
"headers": [
[key.encode(), value.encode()] for key, value in self.headers
],
"headers": self.raw_headers,
}
)
async for chunk in self.body_iterator:
@ -120,22 +121,41 @@ class StreamingResponse(Response):
await send({"type": "http.response.body", "body": chunk, "more_body": True})
await send({"type": "http.response.body", "body": b"", "more_body": False})
def set_default_headers(self, headers: dict = None):
if headers is None:
raw_headers = []
missing_content_type = True
else:
raw_headers = [
(k.lower().encode("latin-1"), v.encode("latin-1"))
for k, v in headers.items()
]
missing_content_type = "content-type" not in headers
if self.media_type is not None and missing_content_type:
content_type = self.media_type
if content_type.startswith("text/") and self.charset is not None:
content_type += "; charset=%s" % self.charset
content_type_value = content_type.encode("latin-1")
raw_headers.append((b"content-type", content_type_value))
class FileResponse(Response):
chunk_size = 4096
self.raw_headers = raw_headers
def __init__(
self,
path: str,
headers: dict = None,
media_type: str = None,
filename: str = None,
) -> None:
self.path = path
self.status_code = 200
self.filename = filename
if media_type is None:
media_type = guess_type(filename or path)[0] or "text/plain"
self.media_type = media_type
self.init_headers(headers)
if self.filename is not None:
content_disposition = 'attachment; filename="{}"'.format(self.filename)
self.headers.setdefault("content-disposition", content_disposition)
async def __call__(self, receive: Receive, send: Send) -> None:
await send(
{
"type": "http.response.start",
"status": self.status_code,
"headers": self.raw_headers,
}
)
async with aiofiles.open(self.path, mode="rb") as file:
more_body = True
while more_body:
chunk = await file.read(self.chunk_size)
more_body = len(chunk) == self.chunk_size
await send(
{"type": "http.response.body", "body": chunk, "more_body": False}
)

View File

@ -1,4 +1,4 @@
from starlette.datastructures import Headers, QueryParams, URL
from starlette.datastructures import Headers, MutableHeaders, QueryParams, URL
def test_url():
@ -25,7 +25,7 @@ def test_headers():
assert h["a"] == "123"
assert h.get("a") == "123"
assert h.get("nope", default=None) is None
assert h.get_list("a") == ["123", "456"]
assert h.getlist("a") == ["123", "456"]
assert h.keys() == ["a", "a", "b"]
assert h.values() == ["123", "456", "789"]
assert h.items() == [("a", "123"), ("a", "456"), ("b", "789")]
@ -34,8 +34,21 @@ def test_headers():
assert repr(h) == "Headers([('a', '123'), ('a', '456'), ('b', '789')])"
assert h == Headers([(b"a", b"123"), (b"b", b"789"), (b"a", b"456")])
assert h != [(b"a", b"123"), (b"A", b"456"), (b"b", b"789")]
h = Headers()
assert not h.items()
def test_mutable_headers():
h = MutableHeaders()
assert dict(h) == {}
h["a"] = "1"
assert dict(h) == {"a": "1"}
h["a"] = "2"
assert dict(h) == {"a": "2"}
h.setdefault("a", "3")
assert dict(h) == {"a": "2"}
h.setdefault("b", "4")
assert dict(h) == {"a": "2", "b": "4"}
del h["a"]
assert dict(h) == {"b": "4"}
def test_queryparams():
@ -46,7 +59,7 @@ def test_queryparams():
assert q["a"] == "123"
assert q.get("a") == "123"
assert q.get("nope", default=None) is None
assert q.get_list("a") == ["123", "456"]
assert q.getlist("a") == ["123", "456"]
assert q.keys() == ["a", "a", "b"]
assert q.values() == ["123", "456", "789"]
assert q.items() == [("a", "123"), ("a", "456"), ("b", "789")]

View File

@ -1,4 +1,4 @@
from starlette import Response, StreamingResponse, TestClient
from starlette import FileResponse, Response, StreamingResponse, TestClient
import asyncio
@ -67,22 +67,18 @@ def test_response_headers():
assert response.headers["x-header-2"] == "789"
def test_streaming_response_headers():
def test_file_response(tmpdir):
with open("xyz", "wb") as file:
file.write(b"<file content>")
def app(scope):
async def asgi(receive, send):
async def stream(msg):
yield "hello, world"
headers = {"x-header-1": "123", "x-header-2": "456"}
response = StreamingResponse(
stream("hello, world"), media_type="text/plain", headers=headers
)
response.headers["x-header-2"] = "789"
await response(receive, send)
return asgi
return FileResponse(path="xyz", filename="example.png")
client = TestClient(app)
response = client.get("/")
assert response.headers["x-header-1"] == "123"
assert response.headers["x-header-2"] == "789"
assert response.status_code == 200
assert response.content == b"<file content>"
assert response.headers["content-type"] == "image/png"
assert (
response.headers["content-disposition"] == 'attachment; filename="example.png"'
)