diff --git a/starlette/datastructures.py b/starlette/datastructures.py index cfb24bf9..ba601764 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -195,3 +195,20 @@ class MutableHeaders(Headers): del self._list[idx] self._list.append((set_key, set_value)) + + def set_default(self, key: str, value: str): + set_key = key.lower().encode("latin-1") + set_value = value.encode("latin-1") + + is_set = False + pop_indexes = [] + for idx, (item_key, item_value) in enumerate(self._list): + if item_key == set_key: + if not is_set: + is_set = True + self._list[idx] = set_value + else: + pop_indexes.append(idx) + + for idx in reversed(pop_indexes): + del self._list[idx] diff --git a/starlette/response.py b/starlette/response.py index 28d0efd8..a3bc5a4b 100644 --- a/starlette/response.py +++ b/starlette/response.py @@ -2,6 +2,7 @@ from starlette.datastructures import MutableHeaders from starlette.types import Receive, Send import json import typing +import os class Response: @@ -17,46 +18,36 @@ class Response: ) -> None: self.body = self.render(content) self.status_code = status_code - if media_type is not None: - self.media_type = media_type - self.set_default_headers(headers) + self.media_type = self.media_type if media_type is None else media_type + self.raw_headers = [] if headers is None else [ + (k.lower().encode("latin-1"), v.encode("latin-1")) + for k, v in headers.items() + ] + self.headers = MutableHeaders(self.raw_headers) + self.set_default_headers() def render(self, content: typing.Any) -> bytes: if isinstance(content, bytes): return content return content.encode(self.charset) - def set_default_headers(self, headers: dict = None): - if headers is None: - raw_headers = [] - missing_content_length = True - missing_content_type = True - else: - raw_headers = [ - (k.lower().encode("latin-1"), v.encode("latin-1")) - for k, v in headers.items() - ] - missing_content_length = "content-length" not in headers - missing_content_type = "content-type" not in headers + def set_default_headers(self): + content_length = str(len(self.body)) if hasattr(self, 'body') else None + content_type = self.default_content_type - if missing_content_length: - content_length = str(len(self.body)).encode() - raw_headers.append((b"content-length", content_length)) - - if self.media_type is not None and missing_content_type: - content_type = self.media_type - if content_type.startswith("text/") and self.charset is not None: - content_type += "; charset=%s" % self.charset - content_type_value = content_type.encode("latin-1") - raw_headers.append((b"content-type", content_type_value)) - - self.raw_headers = raw_headers + if content_length is not None: + self.headers.set_default("content-length", content_length) + if content_type is not None: + self.headers.set_default("content-type", content_type) @property - def headers(self): - if not hasattr(self, "_headers"): - self._headers = MutableHeaders(self.raw_headers) - return self._headers + def default_content_type(self): + if self.media_type is None: + return None + + if self.media_type.startswith('text/') and self.charset is not None: + return '%s; charset=%s' % (self.media_type, self.charset) + return self.media_type async def __call__(self, receive: Receive, send: Send) -> None: await send( @@ -100,9 +91,13 @@ class StreamingResponse(Response): ) -> None: self.body_iterator = content self.status_code = status_code - if media_type is not None: - self.media_type = media_type - self.set_default_headers(headers) + self.media_type = self.media_type if media_type is None else media_type + self.raw_headers = [] if headers is None else [ + (k.lower().encode("latin-1"), v.encode("latin-1")) + for k, v in headers.items() + ] + self.headers = MutableHeaders(self.raw_headers) + self.set_default_headers() async def __call__(self, receive: Receive, send: Send) -> None: await send( @@ -120,22 +115,22 @@ class StreamingResponse(Response): await send({"type": "http.response.body", "body": chunk, "more_body": True}) await send({"type": "http.response.body", "body": b"", "more_body": False}) - def set_default_headers(self, headers: dict = None): - if headers is None: - raw_headers = [] - missing_content_type = True - else: - raw_headers = [ - (k.lower().encode("latin-1"), v.encode("latin-1")) - for k, v in headers.items() - ] - missing_content_type = "content-type" not in headers - - if self.media_type is not None and missing_content_type: - content_type = self.media_type - if content_type.startswith("text/") and self.charset is not None: - content_type += "; charset=%s" % self.charset - content_type_value = content_type.encode("latin-1") - raw_headers.append((b"content-type", content_type_value)) - - self.raw_headers = raw_headers +# +# class FileResponse: +# def __init__( +# self, +# path: str, +# headers: dict = None, +# media_type: str = None, +# filename: str = None +# ) -> None: +# self.path = path +# self.status_code = 200 +# if media_type is not None: +# self.media_type = media_type +# if filename is not None: +# self.filename = filename +# else: +# self.filename = os.path.basename(path) +# +# self.set_default_headers(headers)