mirror of https://github.com/encode/starlette.git
Add set_default
This commit is contained in:
parent
5d66b939e2
commit
f985e49629
|
@ -195,3 +195,20 @@ class MutableHeaders(Headers):
|
|||
del self._list[idx]
|
||||
|
||||
self._list.append((set_key, set_value))
|
||||
|
||||
def set_default(self, key: str, value: str):
|
||||
set_key = key.lower().encode("latin-1")
|
||||
set_value = value.encode("latin-1")
|
||||
|
||||
is_set = False
|
||||
pop_indexes = []
|
||||
for idx, (item_key, item_value) in enumerate(self._list):
|
||||
if item_key == set_key:
|
||||
if not is_set:
|
||||
is_set = True
|
||||
self._list[idx] = set_value
|
||||
else:
|
||||
pop_indexes.append(idx)
|
||||
|
||||
for idx in reversed(pop_indexes):
|
||||
del self._list[idx]
|
||||
|
|
|
@ -2,6 +2,7 @@ from starlette.datastructures import MutableHeaders
|
|||
from starlette.types import Receive, Send
|
||||
import json
|
||||
import typing
|
||||
import os
|
||||
|
||||
|
||||
class Response:
|
||||
|
@ -17,46 +18,36 @@ class Response:
|
|||
) -> None:
|
||||
self.body = self.render(content)
|
||||
self.status_code = status_code
|
||||
if media_type is not None:
|
||||
self.media_type = media_type
|
||||
self.set_default_headers(headers)
|
||||
self.media_type = self.media_type if media_type is None else media_type
|
||||
self.raw_headers = [] if headers is None else [
|
||||
(k.lower().encode("latin-1"), v.encode("latin-1"))
|
||||
for k, v in headers.items()
|
||||
]
|
||||
self.headers = MutableHeaders(self.raw_headers)
|
||||
self.set_default_headers()
|
||||
|
||||
def render(self, content: typing.Any) -> bytes:
|
||||
if isinstance(content, bytes):
|
||||
return content
|
||||
return content.encode(self.charset)
|
||||
|
||||
def set_default_headers(self, headers: dict = None):
|
||||
if headers is None:
|
||||
raw_headers = []
|
||||
missing_content_length = True
|
||||
missing_content_type = True
|
||||
else:
|
||||
raw_headers = [
|
||||
(k.lower().encode("latin-1"), v.encode("latin-1"))
|
||||
for k, v in headers.items()
|
||||
]
|
||||
missing_content_length = "content-length" not in headers
|
||||
missing_content_type = "content-type" not in headers
|
||||
def set_default_headers(self):
|
||||
content_length = str(len(self.body)) if hasattr(self, 'body') else None
|
||||
content_type = self.default_content_type
|
||||
|
||||
if missing_content_length:
|
||||
content_length = str(len(self.body)).encode()
|
||||
raw_headers.append((b"content-length", content_length))
|
||||
|
||||
if self.media_type is not None and missing_content_type:
|
||||
content_type = self.media_type
|
||||
if content_type.startswith("text/") and self.charset is not None:
|
||||
content_type += "; charset=%s" % self.charset
|
||||
content_type_value = content_type.encode("latin-1")
|
||||
raw_headers.append((b"content-type", content_type_value))
|
||||
|
||||
self.raw_headers = raw_headers
|
||||
if content_length is not None:
|
||||
self.headers.set_default("content-length", content_length)
|
||||
if content_type is not None:
|
||||
self.headers.set_default("content-type", content_type)
|
||||
|
||||
@property
|
||||
def headers(self):
|
||||
if not hasattr(self, "_headers"):
|
||||
self._headers = MutableHeaders(self.raw_headers)
|
||||
return self._headers
|
||||
def default_content_type(self):
|
||||
if self.media_type is None:
|
||||
return None
|
||||
|
||||
if self.media_type.startswith('text/') and self.charset is not None:
|
||||
return '%s; charset=%s' % (self.media_type, self.charset)
|
||||
return self.media_type
|
||||
|
||||
async def __call__(self, receive: Receive, send: Send) -> None:
|
||||
await send(
|
||||
|
@ -100,9 +91,13 @@ class StreamingResponse(Response):
|
|||
) -> None:
|
||||
self.body_iterator = content
|
||||
self.status_code = status_code
|
||||
if media_type is not None:
|
||||
self.media_type = media_type
|
||||
self.set_default_headers(headers)
|
||||
self.media_type = self.media_type if media_type is None else media_type
|
||||
self.raw_headers = [] if headers is None else [
|
||||
(k.lower().encode("latin-1"), v.encode("latin-1"))
|
||||
for k, v in headers.items()
|
||||
]
|
||||
self.headers = MutableHeaders(self.raw_headers)
|
||||
self.set_default_headers()
|
||||
|
||||
async def __call__(self, receive: Receive, send: Send) -> None:
|
||||
await send(
|
||||
|
@ -120,22 +115,22 @@ class StreamingResponse(Response):
|
|||
await send({"type": "http.response.body", "body": chunk, "more_body": True})
|
||||
await send({"type": "http.response.body", "body": b"", "more_body": False})
|
||||
|
||||
def set_default_headers(self, headers: dict = None):
|
||||
if headers is None:
|
||||
raw_headers = []
|
||||
missing_content_type = True
|
||||
else:
|
||||
raw_headers = [
|
||||
(k.lower().encode("latin-1"), v.encode("latin-1"))
|
||||
for k, v in headers.items()
|
||||
]
|
||||
missing_content_type = "content-type" not in headers
|
||||
|
||||
if self.media_type is not None and missing_content_type:
|
||||
content_type = self.media_type
|
||||
if content_type.startswith("text/") and self.charset is not None:
|
||||
content_type += "; charset=%s" % self.charset
|
||||
content_type_value = content_type.encode("latin-1")
|
||||
raw_headers.append((b"content-type", content_type_value))
|
||||
|
||||
self.raw_headers = raw_headers
|
||||
#
|
||||
# class FileResponse:
|
||||
# def __init__(
|
||||
# self,
|
||||
# path: str,
|
||||
# headers: dict = None,
|
||||
# media_type: str = None,
|
||||
# filename: str = None
|
||||
# ) -> None:
|
||||
# self.path = path
|
||||
# self.status_code = 200
|
||||
# if media_type is not None:
|
||||
# self.media_type = media_type
|
||||
# if filename is not None:
|
||||
# self.filename = filename
|
||||
# else:
|
||||
# self.filename = os.path.basename(path)
|
||||
#
|
||||
# self.set_default_headers(headers)
|
||||
|
|
Loading…
Reference in New Issue