Add set_default

This commit is contained in:
Tom Christie 2018-07-11 13:16:03 +01:00
parent 5d66b939e2
commit f985e49629
2 changed files with 65 additions and 53 deletions

View File

@ -195,3 +195,20 @@ class MutableHeaders(Headers):
del self._list[idx]
self._list.append((set_key, set_value))
def set_default(self, key: str, value: str):
set_key = key.lower().encode("latin-1")
set_value = value.encode("latin-1")
is_set = False
pop_indexes = []
for idx, (item_key, item_value) in enumerate(self._list):
if item_key == set_key:
if not is_set:
is_set = True
self._list[idx] = set_value
else:
pop_indexes.append(idx)
for idx in reversed(pop_indexes):
del self._list[idx]

View File

@ -2,6 +2,7 @@ from starlette.datastructures import MutableHeaders
from starlette.types import Receive, Send
import json
import typing
import os
class Response:
@ -17,46 +18,36 @@ class Response:
) -> None:
self.body = self.render(content)
self.status_code = status_code
if media_type is not None:
self.media_type = media_type
self.set_default_headers(headers)
self.media_type = self.media_type if media_type is None else media_type
self.raw_headers = [] if headers is None else [
(k.lower().encode("latin-1"), v.encode("latin-1"))
for k, v in headers.items()
]
self.headers = MutableHeaders(self.raw_headers)
self.set_default_headers()
def render(self, content: typing.Any) -> bytes:
if isinstance(content, bytes):
return content
return content.encode(self.charset)
def set_default_headers(self, headers: dict = None):
if headers is None:
raw_headers = []
missing_content_length = True
missing_content_type = True
else:
raw_headers = [
(k.lower().encode("latin-1"), v.encode("latin-1"))
for k, v in headers.items()
]
missing_content_length = "content-length" not in headers
missing_content_type = "content-type" not in headers
def set_default_headers(self):
content_length = str(len(self.body)) if hasattr(self, 'body') else None
content_type = self.default_content_type
if missing_content_length:
content_length = str(len(self.body)).encode()
raw_headers.append((b"content-length", content_length))
if self.media_type is not None and missing_content_type:
content_type = self.media_type
if content_type.startswith("text/") and self.charset is not None:
content_type += "; charset=%s" % self.charset
content_type_value = content_type.encode("latin-1")
raw_headers.append((b"content-type", content_type_value))
self.raw_headers = raw_headers
if content_length is not None:
self.headers.set_default("content-length", content_length)
if content_type is not None:
self.headers.set_default("content-type", content_type)
@property
def headers(self):
if not hasattr(self, "_headers"):
self._headers = MutableHeaders(self.raw_headers)
return self._headers
def default_content_type(self):
if self.media_type is None:
return None
if self.media_type.startswith('text/') and self.charset is not None:
return '%s; charset=%s' % (self.media_type, self.charset)
return self.media_type
async def __call__(self, receive: Receive, send: Send) -> None:
await send(
@ -100,9 +91,13 @@ class StreamingResponse(Response):
) -> None:
self.body_iterator = content
self.status_code = status_code
if media_type is not None:
self.media_type = media_type
self.set_default_headers(headers)
self.media_type = self.media_type if media_type is None else media_type
self.raw_headers = [] if headers is None else [
(k.lower().encode("latin-1"), v.encode("latin-1"))
for k, v in headers.items()
]
self.headers = MutableHeaders(self.raw_headers)
self.set_default_headers()
async def __call__(self, receive: Receive, send: Send) -> None:
await send(
@ -120,22 +115,22 @@ class StreamingResponse(Response):
await send({"type": "http.response.body", "body": chunk, "more_body": True})
await send({"type": "http.response.body", "body": b"", "more_body": False})
def set_default_headers(self, headers: dict = None):
if headers is None:
raw_headers = []
missing_content_type = True
else:
raw_headers = [
(k.lower().encode("latin-1"), v.encode("latin-1"))
for k, v in headers.items()
]
missing_content_type = "content-type" not in headers
if self.media_type is not None and missing_content_type:
content_type = self.media_type
if content_type.startswith("text/") and self.charset is not None:
content_type += "; charset=%s" % self.charset
content_type_value = content_type.encode("latin-1")
raw_headers.append((b"content-type", content_type_value))
self.raw_headers = raw_headers
#
# class FileResponse:
# def __init__(
# self,
# path: str,
# headers: dict = None,
# media_type: str = None,
# filename: str = None
# ) -> None:
# self.path = path
# self.status_code = 200
# if media_type is not None:
# self.media_type = media_type
# if filename is not None:
# self.filename = filename
# else:
# self.filename = os.path.basename(path)
#
# self.set_default_headers(headers)