starlette/starlette/responses.py

363 lines
13 KiB
Python

import http.cookies
import json
import os
import stat
import sys
import typing
from email.utils import formatdate
from functools import partial
from mimetypes import guess_type as mimetypes_guess_type
from urllib.parse import quote
import anyio
from starlette._compat import md5_hexdigest
from starlette.background import BackgroundTask
from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import URL, MutableHeaders
from starlette.types import Receive, Scope, Send
if sys.version_info >= (3, 8): # pragma: no cover
from typing import Literal
else: # pragma: no cover
from typing_extensions import Literal
# Workaround for adding samesite support to pre 3.8 python: register
# "samesite" in Morsel's reserved-attribute table so that
# `cookie[key]["samesite"] = ...` (used by `Response.set_cookie`) is accepted.
http.cookies.Morsel._reserved["samesite"] = "SameSite" # type: ignore
# Compatibility wrapper for `mimetypes.guess_type` to support `os.PathLike` on <py3.8
def guess_type(
    url: typing.Union[str, "os.PathLike[str]"], strict: bool = True
) -> typing.Tuple[typing.Optional[str], typing.Optional[str]]:
    """Guess the MIME type and encoding of *url*.

    Thin wrapper around ``mimetypes.guess_type``.  On interpreters older than
    3.8 the argument is first converted with ``os.fspath`` so that path-like
    objects are accepted there as well.

    Returns the ``(type, encoding)`` tuple produced by ``mimetypes``.
    """
    legacy_interpreter = sys.version_info < (3, 8)  # pragma: no cover
    target = os.fspath(url) if legacy_interpreter else url
    return mimetypes_guess_type(target, strict)
class Response:
    """ASGI HTTP response that sends a fully rendered body in one message.

    Class attributes (overridable per subclass):
        media_type: default ``content-type`` value; ``None`` means no
            ``content-type`` header is added automatically.
        charset: encoding used to render ``str`` content, and the charset
            advertised for ``text/*`` media types.
    """

    media_type = None
    charset = "utf-8"

    def __init__(
        self,
        content: typing.Any = None,
        status_code: int = 200,
        headers: typing.Optional[typing.Mapping[str, str]] = None,
        media_type: typing.Optional[str] = None,
        background: typing.Optional[BackgroundTask] = None,
    ) -> None:
        """Render *content* and build the raw header list.

        Args:
            content: body payload, passed through :meth:`render`.
            status_code: HTTP status code sent in the response start message.
            headers: optional extra headers; keys are lower-cased.
            media_type: overrides the class-level ``media_type`` when given.
            background: optional task awaited after the body has been sent.
        """
        self.status_code = status_code
        if media_type is not None:
            self.media_type = media_type
        self.background = background
        # The body must exist before the headers are built, because
        # `init_headers` derives content-length from it.
        self.body = self.render(content)
        self.init_headers(headers)

    def render(self, content: typing.Any) -> bytes:
        """Convert *content* to the bytes sent as the response body.

        ``None`` becomes an empty body, ``bytes`` is returned unchanged,
        ``bytearray``/``memoryview`` are copied into ``bytes``, and anything
        else is expected to offer ``.encode`` (i.e. ``str``) and is encoded
        with ``self.charset``.
        """
        if content is None:
            return b""
        if isinstance(content, bytes):
            return content
        if isinstance(content, (bytearray, memoryview)):
            # Other binary buffers have no `.encode`; hand back real bytes.
            return bytes(content)
        return content.encode(self.charset)

    def init_headers(
        self, headers: typing.Optional[typing.Mapping[str, str]] = None
    ) -> None:
        """Build ``self.raw_headers`` from *headers* plus computed defaults.

        ``content-length`` and ``content-type`` are only added when the
        caller did not supply them.  ``content-length`` is additionally
        skipped when there is no ``body`` attribute or when the status code
        is below 200, 204 or 304.
        """
        if headers is None:
            raw_headers: typing.List[typing.Tuple[bytes, bytes]] = []
            populate_content_length = True
            populate_content_type = True
        else:
            raw_headers = [
                (k.lower().encode("latin-1"), v.encode("latin-1"))
                for k, v in headers.items()
            ]
            keys = {name for name, _ in raw_headers}
            populate_content_length = b"content-length" not in keys
            populate_content_type = b"content-type" not in keys

        # Subclasses that stream never set `body`, hence the getattr.
        body = getattr(self, "body", None)
        if (
            body is not None
            and populate_content_length
            and not (self.status_code < 200 or self.status_code in (204, 304))
        ):
            content_length = str(len(body))
            raw_headers.append((b"content-length", content_length.encode("latin-1")))

        content_type = self.media_type
        if content_type is not None and populate_content_type:
            # Only advertise the default charset when the media type does not
            # already carry one; otherwise the parameter would be duplicated.
            if (
                content_type.startswith("text/")
                and "charset=" not in content_type.lower()
            ):
                content_type += "; charset=" + self.charset
            raw_headers.append((b"content-type", content_type.encode("latin-1")))

        self.raw_headers = raw_headers

    @property
    def headers(self) -> MutableHeaders:
        """Return a ``MutableHeaders`` built from ``raw_headers``.

        The object is created on first access and cached on the instance.
        """
        if not hasattr(self, "_headers"):
            self._headers = MutableHeaders(raw=self.raw_headers)
        return self._headers

    def set_cookie(
        self,
        key: str,
        value: str = "",
        max_age: typing.Optional[int] = None,
        expires: typing.Optional[int] = None,
        path: str = "/",
        domain: typing.Optional[str] = None,
        secure: bool = False,
        httponly: bool = False,
        samesite: typing.Optional[Literal["lax", "strict", "none"]] = "lax",
    ) -> None:
        """Append a ``set-cookie`` header for *key*.

        Only attributes that are not ``None`` (or, for the ``secure`` and
        ``httponly`` flags, truthy) are written to the cookie.  *samesite*
        is checked case-insensitively against ``strict``/``lax``/``none``
        with an ``assert``.
        """
        cookie: "http.cookies.BaseCookie[str]" = http.cookies.SimpleCookie()
        cookie[key] = value
        if max_age is not None:
            cookie[key]["max-age"] = max_age
        if expires is not None:
            cookie[key]["expires"] = expires
        if path is not None:
            cookie[key]["path"] = path
        if domain is not None:
            cookie[key]["domain"] = domain
        if secure:
            cookie[key]["secure"] = True
        if httponly:
            cookie[key]["httponly"] = True
        if samesite is not None:
            assert samesite.lower() in [
                "strict",
                "lax",
                "none",
            ], "samesite must be either 'strict', 'lax' or 'none'"
            cookie[key]["samesite"] = samesite
        # `output(header="")` drops the "Set-Cookie:" prefix; strip the
        # whitespace that remains around the value.
        cookie_val = cookie.output(header="").strip()
        self.raw_headers.append((b"set-cookie", cookie_val.encode("latin-1")))

    def delete_cookie(
        self,
        key: str,
        path: str = "/",
        domain: typing.Optional[str] = None,
        secure: bool = False,
        httponly: bool = False,
        samesite: typing.Optional[Literal["lax", "strict", "none"]] = "lax",
    ) -> None:
        """Expire cookie *key* by re-setting it with ``max_age=0, expires=0``."""
        self.set_cookie(
            key,
            max_age=0,
            expires=0,
            path=path,
            domain=domain,
            secure=secure,
            httponly=httponly,
            samesite=samesite,
        )

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Send the start message and the body, then await the background task."""
        await send(
            {
                "type": "http.response.start",
                "status": self.status_code,
                "headers": self.raw_headers,
            }
        )
        await send({"type": "http.response.body", "body": self.body})

        if self.background is not None:
            await self.background()
class HTMLResponse(Response):
    """``Response`` whose default media type is ``text/html``."""

    media_type = "text/html"
class PlainTextResponse(Response):
    """``Response`` whose default media type is ``text/plain``."""

    media_type = "text/plain"
class JSONResponse(Response):
    """``Response`` that serialises its content to compact JSON."""

    media_type = "application/json"

    def __init__(
        self,
        content: typing.Any,
        status_code: int = 200,
        # Accept any mapping, matching `Response` and the other subclasses.
        headers: typing.Optional[typing.Mapping[str, str]] = None,
        media_type: typing.Optional[str] = None,
        background: typing.Optional[BackgroundTask] = None,
    ) -> None:
        """Same as ``Response.__init__`` except that *content* is required."""
        super().__init__(content, status_code, headers, media_type, background)

    def render(self, content: typing.Any) -> bytes:
        """Serialise *content* with ``json.dumps`` and encode it as UTF-8.

        Output is compact (no indentation, no spaces after separators) and
        non-ASCII characters are emitted as-is.  ``allow_nan=False`` makes
        ``json.dumps`` raise ``ValueError`` for NaN and infinite floats.
        """
        return json.dumps(
            content,
            ensure_ascii=False,
            allow_nan=False,
            indent=None,
            separators=(",", ":"),
        ).encode("utf-8")
class RedirectResponse(Response):
    """Empty-bodied response carrying a ``location`` header for *url*."""

    def __init__(
        self,
        url: typing.Union[str, URL],
        status_code: int = 307,
        headers: typing.Optional[typing.Mapping[str, str]] = None,
        background: typing.Optional[BackgroundTask] = None,
    ) -> None:
        """Initialise with an empty body, then set the ``location`` header.

        *url* is converted with ``str`` and percent-quoted; the characters
        listed in ``safe`` below are left untouched.
        """
        super().__init__(
            content=b"",
            status_code=status_code,
            headers=headers,
            background=background,
        )
        target = str(url)
        location = quote(target, safe=":/%#?=@[]!$&'()*+,;")
        self.headers["location"] = location
# A single chunk of streamed body content.
Content = typing.Union[str, bytes]
# Synchronous iterator of chunks.
SyncContentStream = typing.Iterator[Content]
# Asynchronous iterable of chunks.
AsyncContentStream = typing.AsyncIterable[Content]
# Either kind of stream; this is the type `StreamingResponse` accepts as content.
ContentStream = typing.Union[AsyncContentStream, SyncContentStream]
class StreamingResponse(Response):
    """Response whose body is sent chunk by chunk from an iterator.

    The response is streamed while the ``receive`` channel is watched for an
    ``http.disconnect`` message; whichever of the two finishes first cancels
    the other.
    """

    body_iterator: AsyncContentStream

    def __init__(
        self,
        content: ContentStream,
        status_code: int = 200,
        headers: typing.Optional[typing.Mapping[str, str]] = None,
        media_type: typing.Optional[str] = None,
        background: typing.Optional[BackgroundTask] = None,
    ) -> None:
        """Store the chunk source and build headers (no ``body`` is rendered).

        Content that is not an ``AsyncIterable`` is wrapped with
        ``iterate_in_threadpool``.
        """
        already_async = isinstance(content, typing.AsyncIterable)
        self.body_iterator = (
            content if already_async else iterate_in_threadpool(content)
        )
        self.status_code = status_code
        if media_type is None:
            media_type = self.media_type
        self.media_type = media_type
        self.background = background
        self.init_headers(headers)

    async def listen_for_disconnect(self, receive: Receive) -> None:
        """Consume incoming messages until an ``http.disconnect`` arrives."""
        message = await receive()
        while message["type"] != "http.disconnect":
            message = await receive()

    async def stream_response(self, send: Send) -> None:
        """Send the start message, every chunk, then a closing empty body."""
        start_message = {
            "type": "http.response.start",
            "status": self.status_code,
            "headers": self.raw_headers,
        }
        await send(start_message)

        async for piece in self.body_iterator:
            # Non-bytes chunks are encoded with the response charset.
            payload = piece if isinstance(piece, bytes) else piece.encode(self.charset)
            await send(
                {"type": "http.response.body", "body": payload, "more_body": True}
            )

        await send({"type": "http.response.body", "body": b"", "more_body": False})

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Run streaming and disconnect-listening side by side."""
        async with anyio.create_task_group() as task_group:

            async def run_then_cancel(
                func: "typing.Callable[[], typing.Awaitable[None]]",
            ) -> None:
                # Once either job completes, cancel the whole group so the
                # other one stops too.
                await func()
                task_group.cancel_scope.cancel()

            streamer = partial(self.stream_response, send)
            watcher = partial(self.listen_for_disconnect, receive)
            task_group.start_soon(run_then_cancel, streamer)
            await run_then_cancel(watcher)

        if self.background is not None:
            await self.background()
class FileResponse(Response):
    """Response that streams a file from disk in ``chunk_size`` pieces."""

    # Number of bytes read from the file per body message.
    chunk_size = 64 * 1024

    def __init__(
        self,
        path: typing.Union[str, "os.PathLike[str]"],
        status_code: int = 200,
        headers: typing.Optional[typing.Mapping[str, str]] = None,
        media_type: typing.Optional[str] = None,
        background: typing.Optional[BackgroundTask] = None,
        filename: typing.Optional[str] = None,
        stat_result: typing.Optional[os.stat_result] = None,
        method: typing.Optional[str] = None,
        content_disposition_type: str = "attachment",
    ) -> None:
        """Prepare headers for serving the file at *path*.

        Args:
            path: location of the file to send.
            status_code: HTTP status code of the response.
            headers: optional extra headers.
            media_type: explicit media type; when omitted it is guessed from
                *filename* (or *path*), falling back to ``text/plain``.
            background: optional task awaited after the response is sent.
            filename: when given, a ``content-disposition`` header is added
                unless one is already present.
            stat_result: pre-computed ``os.stat`` result; when given, the
                stat-derived headers are set immediately and no stat call is
                made at send time.
            method: request method; ``HEAD`` (any case) suppresses the body.
            content_disposition_type: disposition type placed before the
                file name in the ``content-disposition`` header.
        """
        self.path = path
        self.status_code = status_code
        self.filename = filename
        self.send_header_only = method is not None and method.upper() == "HEAD"
        if media_type is None:
            media_type = guess_type(filename or path)[0] or "text/plain"
        self.media_type = media_type
        self.background = background
        self.init_headers(headers)
        if self.filename is not None:
            content_disposition_filename = quote(self.filename)
            if content_disposition_filename != self.filename:
                # Quoting changed the name, so use the extended
                # `filename*=utf-8''...` form with the percent-encoded value.
                content_disposition = "{}; filename*=utf-8''{}".format(
                    content_disposition_type, content_disposition_filename
                )
            else:
                content_disposition = '{}; filename="{}"'.format(
                    content_disposition_type, self.filename
                )
            self.headers.setdefault("content-disposition", content_disposition)
        self.stat_result = stat_result
        if stat_result is not None:
            self.set_stat_headers(stat_result)

    def set_stat_headers(self, stat_result: os.stat_result) -> None:
        """Set ``content-length``, ``last-modified`` and ``etag`` defaults.

        Headers that are already present are left untouched.  The ETag is an
        MD5 digest of the modification time and size, wrapped in double
        quotes because an entity-tag must be a quoted string.
        """
        content_length = str(stat_result.st_size)
        last_modified = formatdate(stat_result.st_mtime, usegmt=True)
        etag_base = str(stat_result.st_mtime) + "-" + str(stat_result.st_size)
        digest = md5_hexdigest(etag_base.encode(), usedforsecurity=False)
        etag = f'"{digest}"'

        self.headers.setdefault("content-length", content_length)
        self.headers.setdefault("last-modified", last_modified)
        self.headers.setdefault("etag", etag)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Send the file (or only its headers for HEAD), then run background.

        Raises:
            RuntimeError: when no ``stat_result`` was supplied and the path
                does not exist or is not a regular file.
        """
        if self.stat_result is None:
            try:
                # Stat in a worker thread so the call does not block the
                # event loop.
                stat_result = await anyio.to_thread.run_sync(os.stat, self.path)
            except FileNotFoundError as exc:
                raise RuntimeError(
                    f"File at path {self.path} does not exist."
                ) from exc
            if not stat.S_ISREG(stat_result.st_mode):
                raise RuntimeError(f"File at path {self.path} is not a file.")
            self.set_stat_headers(stat_result)

        await send(
            {
                "type": "http.response.start",
                "status": self.status_code,
                "headers": self.raw_headers,
            }
        )
        if self.send_header_only:
            await send({"type": "http.response.body", "body": b"", "more_body": False})
        else:
            async with await anyio.open_file(self.path, mode="rb") as file:
                more_body = True
                while more_body:
                    chunk = await file.read(self.chunk_size)
                    # A short (or empty) read means the end of the file.
                    more_body = len(chunk) == self.chunk_size
                    await send(
                        {
                            "type": "http.response.body",
                            "body": chunk,
                            "more_body": more_body,
                        }
                    )

        if self.background is not None:
            await self.background()