mirror of https://github.com/encode/starlette.git
Merge branch 'master' of https://github.com/encode/starlette
This commit is contained in:
commit
342a52f185
24
README.md
24
README.md
|
@ -161,6 +161,30 @@ class App:
|
|||
await response(receive, send)
|
||||
```
|
||||
|
||||
### FileResponse
|
||||
|
||||
Asynchronously streams a file as the response.
|
||||
|
||||
Takes a different set of arguments to instantiate than the other response types:
|
||||
|
||||
* `path` - The filepath to the file to stream.
|
||||
* `headers` - Any custom headers to include, as a dictionary.
|
||||
* `media_type` - A string giving the media type. If unset, the filename or path will be used to infer a media type.
|
||||
* `filename` - If set, this will be included in the response `Content-Disposition`.
|
||||
|
||||
```python
|
||||
from starlette import FileResponse
|
||||
|
||||
|
||||
class App:
|
||||
def __init__(self, scope):
|
||||
self.scope = scope
|
||||
|
||||
async def __call__(self, receive, send):
|
||||
response = FileResponse('/statics/favicon.ico')
|
||||
await response(receive, send)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Requests
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
aiofiles
|
||||
requests
|
||||
|
||||
# Testing
|
||||
|
|
1
setup.py
1
setup.py
|
@ -44,6 +44,7 @@ setup(
|
|||
author_email='tom@tomchristie.com',
|
||||
packages=get_packages('starlette'),
|
||||
install_requires=[
|
||||
'aiofiles',
|
||||
'requests',
|
||||
],
|
||||
classifiers=[
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from starlette.decorators import asgi_application
|
||||
from starlette.response import (
|
||||
FileResponse,
|
||||
HTMLResponse,
|
||||
JSONResponse,
|
||||
Response,
|
||||
|
@ -13,6 +14,7 @@ from starlette.testclient import TestClient
|
|||
|
||||
__all__ = (
|
||||
"asgi_application",
|
||||
"FileResponse",
|
||||
"HTMLResponse",
|
||||
"JSONResponse",
|
||||
"Path",
|
||||
|
@ -24,4 +26,4 @@ __all__ = (
|
|||
"Request",
|
||||
"TestClient",
|
||||
)
|
||||
__version__ = "0.1.6"
|
||||
__version__ = "0.1.7"
|
||||
|
|
|
@ -75,7 +75,7 @@ class QueryParams(typing.Mapping[str, str]):
|
|||
self._dict = {k: v for k, v in reversed(items)}
|
||||
self._list = items
|
||||
|
||||
def get_list(self, key: str) -> typing.List[str]:
|
||||
def getlist(self, key: str) -> typing.List[str]:
|
||||
return [item_value for item_key, item_value in self._list if item_key == key]
|
||||
|
||||
def keys(self):
|
||||
|
@ -119,16 +119,15 @@ class Headers(typing.Mapping[str, str]):
|
|||
An immutable, case-insensitive multidict.
|
||||
"""
|
||||
|
||||
def __init__(self, value: typing.Union[StrDict, StrPairs] = None) -> None:
|
||||
if value is None:
|
||||
def __init__(self, raw_headers=None) -> None:
|
||||
if raw_headers is None:
|
||||
self._list = []
|
||||
else:
|
||||
assert isinstance(value, list)
|
||||
for header_key, header_value in value:
|
||||
for header_key, header_value in raw_headers:
|
||||
assert isinstance(header_key, bytes)
|
||||
assert isinstance(header_value, bytes)
|
||||
assert header_key == header_key.lower()
|
||||
self._list = value
|
||||
self._list = raw_headers
|
||||
|
||||
def keys(self):
|
||||
return [key.decode("latin-1") for key, value in self._list]
|
||||
|
@ -148,7 +147,7 @@ class Headers(typing.Mapping[str, str]):
|
|||
except KeyError:
|
||||
return default
|
||||
|
||||
def get_list(self, key: str) -> typing.List[str]:
|
||||
def getlist(self, key: str) -> typing.List[str]:
|
||||
get_header_key = key.lower().encode("latin-1")
|
||||
return [
|
||||
item_value.decode("latin-1")
|
||||
|
@ -164,7 +163,11 @@ class Headers(typing.Mapping[str, str]):
|
|||
raise KeyError(key)
|
||||
|
||||
def __contains__(self, key: str):
|
||||
return key.lower() in self.keys()
|
||||
get_header_key = key.lower().encode("latin-1")
|
||||
for header_key, header_value in self._list:
|
||||
if header_key == get_header_key:
|
||||
return True
|
||||
return False
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.items())
|
||||
|
@ -183,6 +186,9 @@ class Headers(typing.Mapping[str, str]):
|
|||
|
||||
class MutableHeaders(Headers):
|
||||
def __setitem__(self, key: str, value: str):
|
||||
"""
|
||||
Set the header `key` to `value`, removing any duplicate entries.
|
||||
"""
|
||||
set_key = key.lower().encode("latin-1")
|
||||
set_value = value.encode("latin-1")
|
||||
|
||||
|
@ -195,3 +201,31 @@ class MutableHeaders(Headers):
|
|||
del self._list[idx]
|
||||
|
||||
self._list.append((set_key, set_value))
|
||||
|
||||
def __delitem__(self, key: str):
|
||||
"""
|
||||
Remove the header `key`.
|
||||
"""
|
||||
del_key = key.lower().encode("latin-1")
|
||||
|
||||
pop_indexes = []
|
||||
for idx, (item_key, item_value) in enumerate(self._list):
|
||||
if item_key == del_key:
|
||||
pop_indexes.append(idx)
|
||||
|
||||
for idx in reversed(pop_indexes):
|
||||
del (self._list[idx])
|
||||
|
||||
def setdefault(self, key: str, value: str):
|
||||
"""
|
||||
If the header `key` does not exist, then set it to `value`.
|
||||
Returns the header value.
|
||||
"""
|
||||
set_key = key.lower().encode("latin-1")
|
||||
set_value = value.encode("latin-1")
|
||||
|
||||
for idx, (item_key, item_value) in enumerate(self._list):
|
||||
if item_key == set_key:
|
||||
return item_value.decode("latin-1")
|
||||
self._list.append((set_key, set_value))
|
||||
return value
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
from mimetypes import guess_type
|
||||
from starlette.datastructures import MutableHeaders
|
||||
from starlette.types import Receive, Send
|
||||
import aiofiles
|
||||
import json
|
||||
import typing
|
||||
import os
|
||||
|
||||
|
||||
class Response:
|
||||
|
@ -19,36 +22,37 @@ class Response:
|
|||
self.status_code = status_code
|
||||
if media_type is not None:
|
||||
self.media_type = media_type
|
||||
self.set_default_headers(headers)
|
||||
self.init_headers(headers)
|
||||
|
||||
def render(self, content: typing.Any) -> bytes:
|
||||
if isinstance(content, bytes):
|
||||
return content
|
||||
return content.encode(self.charset)
|
||||
|
||||
def set_default_headers(self, headers: dict = None):
|
||||
def init_headers(self, headers):
|
||||
if headers is None:
|
||||
raw_headers = []
|
||||
missing_content_length = True
|
||||
missing_content_type = True
|
||||
populate_content_length = True
|
||||
populate_content_type = True
|
||||
else:
|
||||
raw_headers = [
|
||||
(k.lower().encode("latin-1"), v.encode("latin-1"))
|
||||
for k, v in headers.items()
|
||||
]
|
||||
missing_content_length = "content-length" not in headers
|
||||
missing_content_type = "content-type" not in headers
|
||||
keys = [h[0] for h in raw_headers]
|
||||
populate_content_length = b"content-length" in keys
|
||||
populate_content_type = b"content-type" in keys
|
||||
|
||||
if missing_content_length:
|
||||
content_length = str(len(self.body)).encode()
|
||||
raw_headers.append((b"content-length", content_length))
|
||||
body = getattr(self, "body", None)
|
||||
if body is not None and populate_content_length:
|
||||
content_length = str(len(body))
|
||||
raw_headers.append((b"content-length", content_length.encode("latin-1")))
|
||||
|
||||
if self.media_type is not None and missing_content_type:
|
||||
content_type = self.media_type
|
||||
if content_type.startswith("text/") and self.charset is not None:
|
||||
content_type += "; charset=%s" % self.charset
|
||||
content_type_value = content_type.encode("latin-1")
|
||||
raw_headers.append((b"content-type", content_type_value))
|
||||
content_type = self.media_type
|
||||
if content_type is not None and populate_content_type:
|
||||
if content_type.startswith("text/"):
|
||||
content_type += "; charset=" + self.charset
|
||||
raw_headers.append((b"content-type", content_type.encode("latin-1")))
|
||||
|
||||
self.raw_headers = raw_headers
|
||||
|
||||
|
@ -100,18 +104,15 @@ class StreamingResponse(Response):
|
|||
) -> None:
|
||||
self.body_iterator = content
|
||||
self.status_code = status_code
|
||||
if media_type is not None:
|
||||
self.media_type = media_type
|
||||
self.set_default_headers(headers)
|
||||
self.media_type = self.media_type if media_type is None else media_type
|
||||
self.init_headers(headers)
|
||||
|
||||
async def __call__(self, receive: Receive, send: Send) -> None:
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": self.status_code,
|
||||
"headers": [
|
||||
[key.encode(), value.encode()] for key, value in self.headers
|
||||
],
|
||||
"headers": self.raw_headers,
|
||||
}
|
||||
)
|
||||
async for chunk in self.body_iterator:
|
||||
|
@ -120,22 +121,41 @@ class StreamingResponse(Response):
|
|||
await send({"type": "http.response.body", "body": chunk, "more_body": True})
|
||||
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
||||
|
||||
def set_default_headers(self, headers: dict = None):
|
||||
if headers is None:
|
||||
raw_headers = []
|
||||
missing_content_type = True
|
||||
else:
|
||||
raw_headers = [
|
||||
(k.lower().encode("latin-1"), v.encode("latin-1"))
|
||||
for k, v in headers.items()
|
||||
]
|
||||
missing_content_type = "content-type" not in headers
|
||||
|
||||
if self.media_type is not None and missing_content_type:
|
||||
content_type = self.media_type
|
||||
if content_type.startswith("text/") and self.charset is not None:
|
||||
content_type += "; charset=%s" % self.charset
|
||||
content_type_value = content_type.encode("latin-1")
|
||||
raw_headers.append((b"content-type", content_type_value))
|
||||
class FileResponse(Response):
|
||||
chunk_size = 4096
|
||||
|
||||
self.raw_headers = raw_headers
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
headers: dict = None,
|
||||
media_type: str = None,
|
||||
filename: str = None,
|
||||
) -> None:
|
||||
self.path = path
|
||||
self.status_code = 200
|
||||
self.filename = filename
|
||||
if media_type is None:
|
||||
media_type = guess_type(filename or path)[0] or "text/plain"
|
||||
self.media_type = media_type
|
||||
self.init_headers(headers)
|
||||
if self.filename is not None:
|
||||
content_disposition = 'attachment; filename="{}"'.format(self.filename)
|
||||
self.headers.setdefault("content-disposition", content_disposition)
|
||||
|
||||
async def __call__(self, receive: Receive, send: Send) -> None:
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": self.status_code,
|
||||
"headers": self.raw_headers,
|
||||
}
|
||||
)
|
||||
async with aiofiles.open(self.path, mode="rb") as file:
|
||||
more_body = True
|
||||
while more_body:
|
||||
chunk = await file.read(self.chunk_size)
|
||||
more_body = len(chunk) == self.chunk_size
|
||||
await send(
|
||||
{"type": "http.response.body", "body": chunk, "more_body": False}
|
||||
)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from starlette.datastructures import Headers, QueryParams, URL
|
||||
from starlette.datastructures import Headers, MutableHeaders, QueryParams, URL
|
||||
|
||||
|
||||
def test_url():
|
||||
|
@ -25,7 +25,7 @@ def test_headers():
|
|||
assert h["a"] == "123"
|
||||
assert h.get("a") == "123"
|
||||
assert h.get("nope", default=None) is None
|
||||
assert h.get_list("a") == ["123", "456"]
|
||||
assert h.getlist("a") == ["123", "456"]
|
||||
assert h.keys() == ["a", "a", "b"]
|
||||
assert h.values() == ["123", "456", "789"]
|
||||
assert h.items() == [("a", "123"), ("a", "456"), ("b", "789")]
|
||||
|
@ -34,8 +34,21 @@ def test_headers():
|
|||
assert repr(h) == "Headers([('a', '123'), ('a', '456'), ('b', '789')])"
|
||||
assert h == Headers([(b"a", b"123"), (b"b", b"789"), (b"a", b"456")])
|
||||
assert h != [(b"a", b"123"), (b"A", b"456"), (b"b", b"789")]
|
||||
h = Headers()
|
||||
assert not h.items()
|
||||
|
||||
|
||||
def test_mutable_headers():
|
||||
h = MutableHeaders()
|
||||
assert dict(h) == {}
|
||||
h["a"] = "1"
|
||||
assert dict(h) == {"a": "1"}
|
||||
h["a"] = "2"
|
||||
assert dict(h) == {"a": "2"}
|
||||
h.setdefault("a", "3")
|
||||
assert dict(h) == {"a": "2"}
|
||||
h.setdefault("b", "4")
|
||||
assert dict(h) == {"a": "2", "b": "4"}
|
||||
del h["a"]
|
||||
assert dict(h) == {"b": "4"}
|
||||
|
||||
|
||||
def test_queryparams():
|
||||
|
@ -46,7 +59,7 @@ def test_queryparams():
|
|||
assert q["a"] == "123"
|
||||
assert q.get("a") == "123"
|
||||
assert q.get("nope", default=None) is None
|
||||
assert q.get_list("a") == ["123", "456"]
|
||||
assert q.getlist("a") == ["123", "456"]
|
||||
assert q.keys() == ["a", "a", "b"]
|
||||
assert q.values() == ["123", "456", "789"]
|
||||
assert q.items() == [("a", "123"), ("a", "456"), ("b", "789")]
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from starlette import Response, StreamingResponse, TestClient
|
||||
from starlette import FileResponse, Response, StreamingResponse, TestClient
|
||||
import asyncio
|
||||
|
||||
|
||||
|
@ -67,22 +67,18 @@ def test_response_headers():
|
|||
assert response.headers["x-header-2"] == "789"
|
||||
|
||||
|
||||
def test_streaming_response_headers():
|
||||
def test_file_response(tmpdir):
|
||||
with open("xyz", "wb") as file:
|
||||
file.write(b"<file content>")
|
||||
|
||||
def app(scope):
|
||||
async def asgi(receive, send):
|
||||
async def stream(msg):
|
||||
yield "hello, world"
|
||||
|
||||
headers = {"x-header-1": "123", "x-header-2": "456"}
|
||||
response = StreamingResponse(
|
||||
stream("hello, world"), media_type="text/plain", headers=headers
|
||||
)
|
||||
response.headers["x-header-2"] = "789"
|
||||
await response(receive, send)
|
||||
|
||||
return asgi
|
||||
return FileResponse(path="xyz", filename="example.png")
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/")
|
||||
assert response.headers["x-header-1"] == "123"
|
||||
assert response.headers["x-header-2"] == "789"
|
||||
assert response.status_code == 200
|
||||
assert response.content == b"<file content>"
|
||||
assert response.headers["content-type"] == "image/png"
|
||||
assert (
|
||||
response.headers["content-disposition"] == 'attachment; filename="example.png"'
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue