From f985e4962926f7c8ffe14ec06e6e976e6d8c0688 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 11 Jul 2018 13:16:03 +0100 Subject: [PATCH 1/6] Add set_default --- starlette/datastructures.py | 17 ++++++ starlette/response.py | 101 +++++++++++++++++------------------- 2 files changed, 65 insertions(+), 53 deletions(-) diff --git a/starlette/datastructures.py b/starlette/datastructures.py index cfb24bf9..ba601764 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -195,3 +195,20 @@ class MutableHeaders(Headers): del self._list[idx] self._list.append((set_key, set_value)) + + def set_default(self, key: str, value: str): + set_key = key.lower().encode("latin-1") + set_value = value.encode("latin-1") + + is_set = False + pop_indexes = [] + for idx, (item_key, item_value) in enumerate(self._list): + if item_key == set_key: + if not is_set: + is_set = True + self._list[idx] = set_value + else: + pop_indexes.append(idx) + + for idx in reversed(pop_indexes): + del self._list[idx] diff --git a/starlette/response.py b/starlette/response.py index 28d0efd8..a3bc5a4b 100644 --- a/starlette/response.py +++ b/starlette/response.py @@ -2,6 +2,7 @@ from starlette.datastructures import MutableHeaders from starlette.types import Receive, Send import json import typing +import os class Response: @@ -17,46 +18,36 @@ class Response: ) -> None: self.body = self.render(content) self.status_code = status_code - if media_type is not None: - self.media_type = media_type - self.set_default_headers(headers) + self.media_type = self.media_type if media_type is None else media_type + self.raw_headers = [] if headers is None else [ + (k.lower().encode("latin-1"), v.encode("latin-1")) + for k, v in headers.items() + ] + self.headers = MutableHeaders(self.raw_headers) + self.set_default_headers() def render(self, content: typing.Any) -> bytes: if isinstance(content, bytes): return content return content.encode(self.charset) - def set_default_headers(self, headers: dict = None): - if headers is None: - raw_headers = [] - missing_content_length = True - missing_content_type = True - else: - raw_headers = [ - (k.lower().encode("latin-1"), v.encode("latin-1")) - for k, v in headers.items() - ] - missing_content_length = "content-length" not in headers - missing_content_type = "content-type" not in headers + def set_default_headers(self): + content_length = str(len(self.body)) if hasattr(self, 'body') else None + content_type = self.default_content_type - if missing_content_length: - content_length = str(len(self.body)).encode() - raw_headers.append((b"content-length", content_length)) - - if self.media_type is not None and missing_content_type: - content_type = self.media_type - if content_type.startswith("text/") and self.charset is not None: - content_type += "; charset=%s" % self.charset - content_type_value = content_type.encode("latin-1") - raw_headers.append((b"content-type", content_type_value)) - - self.raw_headers = raw_headers + if content_length is not None: + self.headers.set_default("content-length", content_length) + if content_type is not None: + self.headers.set_default("content-type", content_type) @property - def headers(self): - if not hasattr(self, "_headers"): - self._headers = MutableHeaders(self.raw_headers) - return self._headers + def default_content_type(self): + if self.media_type is None: + return None + + if self.media_type.startswith('text/') and self.charset is not None: + return '%s; charset=%s' % (self.media_type, self.charset) + return self.media_type async def __call__(self, receive: Receive, send: Send) -> None: await send( @@ -100,9 +91,13 @@ class StreamingResponse(Response): ) -> None: self.body_iterator = content self.status_code = status_code - if media_type is not None: - self.media_type = media_type - self.set_default_headers(headers) + self.media_type = self.media_type if media_type is None else media_type + self.raw_headers = [] if headers is None else [ + (k.lower().encode("latin-1"), v.encode("latin-1")) + for k, v in headers.items() + ] + self.headers = MutableHeaders(self.raw_headers) + self.set_default_headers() async def __call__(self, receive: Receive, send: Send) -> None: await send( @@ -120,22 +115,22 @@ class StreamingResponse(Response): await send({"type": "http.response.body", "body": chunk, "more_body": True}) await send({"type": "http.response.body", "body": b"", "more_body": False}) - def set_default_headers(self, headers: dict = None): - if headers is None: - raw_headers = [] - missing_content_type = True - else: - raw_headers = [ - (k.lower().encode("latin-1"), v.encode("latin-1")) - for k, v in headers.items() - ] - missing_content_type = "content-type" not in headers - - if self.media_type is not None and missing_content_type: - content_type = self.media_type - if content_type.startswith("text/") and self.charset is not None: - content_type += "; charset=%s" % self.charset - content_type_value = content_type.encode("latin-1") - raw_headers.append((b"content-type", content_type_value)) - - self.raw_headers = raw_headers +# +# class FileResponse: +# def __init__( +# self, +# path: str, +# headers: dict = None, +# media_type: str = None, +# filename: str = None +# ) -> None: +# self.path = path +# self.status_code = 200 +# if media_type is not None: +# self.media_type = media_type +# if filename is not None: +# self.filename = filename +# else: +# self.filename = os.path.basename(path) +# +# self.set_default_headers(headers) From 5f194f73bf9d789cb2b1eed3c48243d13f9b0a54 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 11 Jul 2018 16:08:51 +0100 Subject: [PATCH 2/6] Add FileResponse --- requirements.txt | 1 + setup.py | 1 + starlette/__init__.py | 2 + starlette/datastructures.py | 53 +++-- starlette/middleware/__init__.py | 4 + starlette/middleware/database.py | 27 +++ starlette/multipart.py | 365 +++++++++++++++++++++++++++++++ starlette/response.py | 123 ++++++----- 8 files changed, 509 insertions(+), 67 deletions(-) create mode 100644 starlette/middleware/__init__.py create mode 100644 starlette/middleware/database.py create mode 100644 starlette/multipart.py diff --git a/requirements.txt b/requirements.txt index 96b8130d..fbec2e51 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +aiofiles requests # Testing diff --git a/setup.py b/setup.py index 3777b4ec..646c7613 100644 --- a/setup.py +++ b/setup.py @@ -44,6 +44,7 @@ setup( author_email='tom@tomchristie.com', packages=get_packages('starlette'), install_requires=[ + 'aiofiles', 'requests', ], classifiers=[ diff --git a/starlette/__init__.py b/starlette/__init__.py index 22ae9352..33d95e7c 100644 --- a/starlette/__init__.py +++ b/starlette/__init__.py @@ -1,5 +1,6 @@ from starlette.decorators import asgi_application from starlette.response import ( + FileResponse, HTMLResponse, JSONResponse, Response, @@ -13,6 +14,7 @@ from starlette.testclient import TestClient __all__ = ( "asgi_application", + "FileResponse", "HTMLResponse", "JSONResponse", "Path", diff --git a/starlette/datastructures.py b/starlette/datastructures.py index ba601764..71c52c04 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -119,16 +119,15 @@ class Headers(typing.Mapping[str, str]): An immutable, case-insensitive multidict. """ - def __init__(self, value: typing.Union[StrDict, StrPairs] = None) -> None: - if value is None: + def __init__(self, raw_headers=None) -> None: + if raw_headers is None: self._list = [] else: - assert isinstance(value, list) - for header_key, header_value in value: + for header_key, header_value in raw_headers: assert isinstance(header_key, bytes) assert isinstance(header_value, bytes) assert header_key == header_key.lower() - self._list = value + self._list = raw_headers def keys(self): return [key.decode("latin-1") for key, value in self._list] @@ -148,7 +147,7 @@ class Headers(typing.Mapping[str, str]): except KeyError: return default - def get_list(self, key: str) -> typing.List[str]: + def getlist(self, key: str) -> typing.List[str]: get_header_key = key.lower().encode("latin-1") return [ item_value.decode("latin-1") @@ -164,7 +163,11 @@ class Headers(typing.Mapping[str, str]): raise KeyError(key) def __contains__(self, key: str): - return key.lower() in self.keys() + get_header_key = key.lower().encode("latin-1") + for header_key, header_value in self._list: + if header_key == get_header_key: + return True + return False def __iter__(self): return iter(self.items()) @@ -183,6 +186,9 @@ class Headers(typing.Mapping[str, str]): class MutableHeaders(Headers): def __setitem__(self, key: str, value: str): + """ + Set the header `key` to `value`, removing any duplicate entries. + """ set_key = key.lower().encode("latin-1") set_value = value.encode("latin-1") @@ -196,19 +202,30 @@ class MutableHeaders(Headers): self._list.append((set_key, set_value)) - def set_default(self, key: str, value: str): + def __delitem__(self, key: str): + """ + Remove the header `key`. + """ + del_key = key.lower().encode("latin-1") + + pop_indexes = [] + for idx, (item_key, item_value) in enumerate(self._list): + if item_key == del_key: + pop_indexes.append(idx) + + for idx in reversed(pop_indexes): + del (self._list[idx]) + + def setdefault(self, key: str, value: str): + """ + If the header `key` does not exist, then set it to `value`. + Returns the header value. + """ set_key = key.lower().encode("latin-1") set_value = value.encode("latin-1") - is_set = False - pop_indexes = [] for idx, (item_key, item_value) in enumerate(self._list): if item_key == set_key: - if not is_set: - is_set = True - self._list[idx] = set_value - else: - pop_indexes.append(idx) - - for idx in reversed(pop_indexes): - del self._list[idx] + return item_value.decode("latin-1") + self._list.append((set_key, set_value)) + return value diff --git a/starlette/middleware/__init__.py b/starlette/middleware/__init__.py new file mode 100644 index 00000000..bdd37f2d --- /dev/null +++ b/starlette/middleware/__init__.py @@ -0,0 +1,4 @@ +from starlette.middleware.database import DatabaseMiddleware + + +__all__ = ["DatabaseMiddleware"] diff --git a/starlette/middleware/database.py b/starlette/middleware/database.py new file mode 100644 index 00000000..9fb5303b --- /dev/null +++ b/starlette/middleware/database.py @@ -0,0 +1,27 @@ +from functools import partial +from urllib.parse import urlparse +import asyncio +import asyncpg + + +class DatabaseMiddleware: + def __init__(self, app, database_url=None, database_config=None): + self.app = app + if database_config is None: + parsed = urlparse(database_url) + database_config = { + "user": parsed.user, + "password": parsed.password, + "database": parsed.database, + "host": parsed.host, + "port": parsed.port, + } + loop = asyncio.get_event_loop() + loop.run_until_complete(self.create_pool(database_config)) + + async def create_pool(self, database_config): + self.pool = await asyncpg.create_pool(**database_config) + + def __call__(self, scope): + scope["database"] = self.pool + return self.app(scope) diff --git a/starlette/multipart.py b/starlette/multipart.py new file mode 100644 index 00000000..c8965c10 --- /dev/null +++ b/starlette/multipart.py @@ -0,0 +1,365 @@ +_begin_form = "begin_form" +_begin_file = "begin_file" +_cont = "cont" +_end = "end" + + +class MultiPartParser(object): + def __init__( + self, + stream_factory=None, + charset="utf-8", + errors="replace", + max_form_memory_size=None, + cls=None, + buffer_size=64 * 1024, + ): + self.charset = charset + self.errors = errors + self.max_form_memory_size = max_form_memory_size + self.stream_factory = ( + default_stream_factory if stream_factory is None else stream_factory + ) + self.cls = MultiDict if cls is None else cls + + # make sure the buffer size is divisible by four so that we can base64 + # decode chunk by chunk + assert buffer_size % 4 == 0, "buffer size has to be divisible by 4" + # also the buffer size has to be at least 1024 bytes long or long headers + # will freak out the system + assert buffer_size >= 1024, "buffer size has to be at least 1KB" + + self.buffer_size = buffer_size + + def _fix_ie_filename(self, filename): + """Internet Explorer 6 transmits the full file name if a file is + uploaded. This function strips the full path if it thinks the + filename is Windows-like absolute. + """ + if filename[1:3] == ":\\" or filename[:2] == "\\\\": + return filename.split("\\")[-1] + return filename + + def _find_terminator(self, iterator): + """The terminator might have some additional newlines before it. + There is at least one application that sends additional newlines + before headers (the python setuptools package). + """ + for line in iterator: + if not line: + break + line = line.strip() + if line: + return line + return b"" + + def fail(self, message): + raise ValueError(message) + + def get_part_encoding(self, headers): + transfer_encoding = headers.get("content-transfer-encoding") + if ( + transfer_encoding is not None + and transfer_encoding in _supported_multipart_encodings + ): + return transfer_encoding + + def get_part_charset(self, headers): + # Figure out input charset for current part + content_type = headers.get("content-type") + if content_type: + mimetype, ct_params = parse_options_header(content_type) + return ct_params.get("charset", self.charset) + return self.charset + + def start_file_streaming(self, filename, headers, total_content_length): + if isinstance(filename, bytes): + filename = filename.decode(self.charset, self.errors) + filename = self._fix_ie_filename(filename) + content_type = headers.get("content-type") + try: + content_length = int(headers["content-length"]) + except (KeyError, ValueError): + content_length = 0 + container = self.stream_factory( + total_content_length, content_type, filename, content_length + ) + return filename, container + + def in_memory_threshold_reached(self, bytes): + raise exceptions.RequestEntityTooLarge() + + def validate_boundary(self, boundary): + if not boundary: + self.fail("Missing boundary") + if not is_valid_multipart_boundary(boundary): + self.fail("Invalid boundary: %s" % boundary) + if len(boundary) > self.buffer_size: # pragma: no cover + # this should never happen because we check for a minimum size + # of 1024 and boundaries may not be longer than 200. The only + # situation when this happens is for non debug builds where + # the assert is skipped. + self.fail("Boundary longer than buffer size") + + class LineSplitter(object): + def __init__(self, cap=None): + self.buffer = b"" + self.cap = cap + + def _splitlines(self, pre, post): + buf = pre + post + rv = [] + if not buf: + return rv, b"" + lines = buf.splitlines(True) + iv = b"" + for line in lines: + iv += line + while self.cap and len(iv) >= self.cap: + rv.append(iv[: self.cap]) + iv = iv[self.cap :] + if line[-1:] in b"\r\n": + rv.append(iv) + iv = b"" + # If this isn't the very end of the stream and what we got ends + # with \r, we need to hold on to it in case an \n comes next + if post and rv and not iv and rv[-1][-1:] == b"\r": + iv = rv[-1] + del rv[-1] + return rv, iv + + def feed(self, data): + lines, self.buffer = self._splitlines(self.buffer, data) + if not data: + lines += [self.buffer] + if self.buffer: + lines += [b""] + return lines + + class LineParser(object): + def __init__(self, parent, boundary): + self.parent = parent + self.boundary = boundary + self._next_part = b"--" + boundary + self._last_part = self._next_part + b"--" + self._state = self._state_pre_term + self._output = [] + self._headers = [] + self._tail = b"" + self._codec = None + + def _start_content(self): + disposition = self._headers.get("content-disposition") + if disposition is None: + raise ValueError("Missing Content-Disposition header") + self.disposition, extra = parse_options_header(disposition) + transfer_encoding = self.parent.get_part_encoding(self._headers) + if transfer_encoding is not None: + if transfer_encoding == "base64": + transfer_encoding = "base64_codec" + try: + self._codec = codecs.lookup(transfer_encoding) + except Exception: + raise ValueError( + "Cannot decode transfer-encoding: %r" % transfer_encoding + ) + self.name = extra.get("name") + self.filename = extra.get("filename") + if self.filename is not None: + self._output.append( + ("begin_file", (self._headers, self.name, self.filename)) + ) + else: + self._output.append(("begin_form", (self._headers, self.name))) + return self._state_output + + def _state_done(self, line): + return self._state_done + + def _state_output(self, line): + if not line: + raise ValueError("Unexpected end of file") + sline = line.rstrip() + if sline == self._last_part: + self._tail = b"" + self._output.append(("end", None)) + return self._state_done + elif sline == self._next_part: + self._tail = b"" + self._output.append(("end", None)) + self._headers = [] + return self._state_headers + + if self._codec: + try: + line, _ = self._codec.decode(line) + except Exception: + raise ValueError("Could not decode transfer-encoded chunk") + + # We don't know yet whether we can output the final newline, so + # we'll save it in self._tail and output it next time. + tail = self._tail + if line[-2:] == b"\r\n": + self._output.append(("cont", tail + line[:-2])) + self._tail = line[-2:] + elif line[-1:] in b"\r\n": + self._output.append(("cont", tail + line[:-1])) + self._tail = line[-1:] + else: + self._output.append(("cont", tail + line)) + self._tail = b"" + return self._state_output + + def _state_pre_term(self, line): + if not line: + raise ValueError("Unexpected end of file") + return self._state_pre_term + line = line.rstrip(b"\r\n") + if not line: + return self._state_pre_term + if line == self._last_part: + return self._state_done + elif line == self._next_part: + self._headers = [] + return self._state_headers + raise ValueError("Expected boundary at start of multipart data") + + def _state_headers(self, line): + if line is None: + raise ValueError("Unexpected end of file during headers") + line = to_native(line) + line, line_terminated = _line_parse(line) + if not line_terminated: + raise ValueError("Unexpected end of line in multipart header") + if not line: + self._headers = Headers(self._headers) + return self._start_content() + if line[0] in " \t" and self._headers: + key, value = self._headers[-1] + self._headers[-1] = (key, value + "\n " + line[1:]) + else: + parts = line.split(":", 1) + if len(parts) == 2: + self._headers.append((parts[0].strip(), parts[1].strip())) + else: + raise ValueError("Malformed header") + return self._state_headers + + def feed(self, lines): + self._output = [] + s = self._state + for line in lines: + s = s(line) + self._state = s + return self._output + + class PartParser(object): + def __init__(self, parent, content_length): + self.parent = parent + self.content_length = content_length + self._write = None + self._in_memory = 0 + self._guard_memory = False + + def _feed_one(self, event): + ev, data = event + p = self.parent + if ev == "begin_file": + self._headers, self._name, filename = data + self._filename, self._container = p.start_file_streaming( + filename, self._headers, self.content_length + ) + self._write = self._container.write + self._is_file = True + self._guard_memory = False + elif ev == "begin_form": + self._headers, self._name = data + self._container = [] + self._write = self._container.append + self._is_file = False + self._guard_memory = p.max_form_memory_size is not None + elif ev == "cont": + self._write(data) + if self._guard_memory: + self._in_memory += len(data) + if self._in_memory > p.max_form_memory_size: + p.in_memory_threshold_reached(self._in_memory) + elif ev == "end": + if self._is_file: + self._container.seek(0) + return ( + "file", + ( + self._name, + FileStorage( + self._container, + self._filename, + self._name, + headers=self._headers, + ), + ), + ) + else: + part_charset = p.get_part_charset(self._headers) + return ( + "form", + ( + self._name, + b"".join(self._container).decode(part_charset, p.errors), + ), + ) + + def feed(self, events): + rv = [] + for event in events: + v = self._feed_one(event) + if v is not None: + rv.append(v) + return rv + + def parse_lines(self, file, boundary, content_length, cap_at_buffer=True): + """Generate parts of + ``('begin_form', (headers, name))`` + ``('begin_file', (headers, name, filename))`` + ``('cont', bytestring)`` + ``('end', None)`` + Always obeys the grammar + parts = ( begin_form cont* end | + begin_file cont* end )* + """ + + line_splitter = self.LineSplitter(self.buffer_size if cap_at_buffer else None) + line_parser = self.LineParser(self, boundary) + while True: + buf = file.read(self.buffer_size) + lines = line_splitter.feed(buf) + parts = line_parser.feed(lines) + for part in parts: + yield part + if buf == b"": + break + + def parse_parts(self, file, boundary, content_length): + """Generate ``('file', (name, val))`` and + ``('form', (name, val))`` parts. + """ + line_splitter = self.LineSplitter() + line_parser = self.LineParser(self, boundary) + part_parser = self.PartParser(self, content_length) + while True: + buf = file.read(self.buffer_size) + lines = line_splitter.feed(buf) + parts = line_parser.feed(lines) + events = part_parser.feed(parts) + for event in events: + yield event + if buf == b"": + break + + def parse(self, file, boundary, content_length): + formstream, filestream = tee( + self.parse_parts(file, boundary, content_length), 2 + ) + form = (p[1] for p in formstream if p[0] == "form") + files = (p[1] for p in filestream if p[0] == "file") + return self.cls(form), self.cls(files) diff --git a/starlette/response.py b/starlette/response.py index a3bc5a4b..c3c98197 100644 --- a/starlette/response.py +++ b/starlette/response.py @@ -1,5 +1,7 @@ +from mimetypes import guess_type from starlette.datastructures import MutableHeaders from starlette.types import Receive, Send +import aiofiles import json import typing import os @@ -18,36 +20,47 @@ class Response: ) -> None: self.body = self.render(content) self.status_code = status_code - self.media_type = self.media_type if media_type is None else media_type - self.raw_headers = [] if headers is None else [ - (k.lower().encode("latin-1"), v.encode("latin-1")) - for k, v in headers.items() - ] - self.headers = MutableHeaders(self.raw_headers) - self.set_default_headers() + if media_type is not None: + self.media_type = media_type + self.init_headers(headers) def render(self, content: typing.Any) -> bytes: if isinstance(content, bytes): return content return content.encode(self.charset) - def set_default_headers(self): - content_length = str(len(self.body)) if hasattr(self, 'body') else None - content_type = self.default_content_type + def init_headers(self, headers): + if headers is None: + raw_headers = [] + populate_content_length = True + populate_content_type = True + else: + raw_headers = [ + (k.lower().encode("latin-1"), v.encode("latin-1")) + for k, v in headers.items() + ] + keys = [h[0] for h in raw_headers] + populate_content_length = b"content-length" in keys + populate_content_type = b"content-type" in keys - if content_length is not None: - self.headers.set_default("content-length", content_length) - if content_type is not None: - self.headers.set_default("content-type", content_type) + body = getattr(self, "body", None) + if body is not None and populate_content_length: + content_length = str(len(body)) + raw_headers.append((b"content-length", content_length.encode("latin-1"))) + + content_type = self.media_type + if content_type is not None and populate_content_type: + if content_type.startswith("text/"): + content_type += "; charset=" + self.charset + raw_headers.append((b"content-type", content_type.encode("latin-1"))) + + self.raw_headers = raw_headers @property - def default_content_type(self): - if self.media_type is None: - return None - - if self.media_type.startswith('text/') and self.charset is not None: - return '%s; charset=%s' % (self.media_type, self.charset) - return self.media_type + def headers(self): + if not hasattr(self, "_headers"): + self._headers = MutableHeaders(self.raw_headers) + return self._headers async def __call__(self, receive: Receive, send: Send) -> None: await send( @@ -92,21 +105,14 @@ class StreamingResponse(Response): self.body_iterator = content self.status_code = status_code self.media_type = self.media_type if media_type is None else media_type - self.raw_headers = [] if headers is None else [ - (k.lower().encode("latin-1"), v.encode("latin-1")) - for k, v in headers.items() - ] - self.headers = MutableHeaders(self.raw_headers) - self.set_default_headers() + self.init_headers(headers) async def __call__(self, receive: Receive, send: Send) -> None: await send( { "type": "http.response.start", "status": self.status_code, - "headers": [ - [key.encode(), value.encode()] for key, value in self.headers - ], + "headers": self.raw_headers, } ) async for chunk in self.body_iterator: @@ -115,22 +121,41 @@ class StreamingResponse(Response): await send({"type": "http.response.body", "body": chunk, "more_body": True}) await send({"type": "http.response.body", "body": b"", "more_body": False}) -# -# class FileResponse: -# def __init__( -# self, -# path: str, -# headers: dict = None, -# media_type: str = None, -# filename: str = None -# ) -> None: -# self.path = path -# self.status_code = 200 -# if media_type is not None: -# self.media_type = media_type -# if filename is not None: -# self.filename = filename -# else: -# self.filename = os.path.basename(path) -# -# self.set_default_headers(headers) + +class FileResponse(Response): + chunk_size = 4096 + + def __init__( + self, + path: str, + headers: dict = None, + media_type: str = None, + filename: str = None, + ) -> None: + self.path = path + self.status_code = 200 + self.filename = filename + if media_type is None: + media_type = guess_type(filename or path)[0] or "text/plain" + self.media_type = media_type + self.init_headers(headers) + if self.filename is not None: + content_disposition = 'attachment; filename="{}"'.format(self.filename) + self.headers.setdefault("content-disposition", content_disposition) + + async def __call__(self, receive: Receive, send: Send) -> None: + await send( + { + "type": "http.response.start", + "status": self.status_code, + "headers": self.raw_headers, + } + ) + async with aiofiles.open(self.path, mode="rb") as file: + more_body = True + while more_body: + chunk = await file.read(self.chunk_size) + more_body = len(chunk) == self.chunk_size + await send( + {"type": "http.response.body", "body": chunk, "more_body": False} + ) From 2e0bd33f59c9ffa1559bfe84cdfa8513f4b3ab41 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 11 Jul 2018 16:30:40 +0100 Subject: [PATCH 3/6] Use getlist, instead of get_list --- starlette/datastructures.py | 2 +- tests/test_datastructures.py | 23 ++++++++++++++++++----- tests/test_response.py | 28 ++++++++++++---------------- 3 files changed, 31 insertions(+), 22 deletions(-) diff --git a/starlette/datastructures.py b/starlette/datastructures.py index 71c52c04..247157ca 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -75,7 +75,7 @@ class QueryParams(typing.Mapping[str, str]): self._dict = {k: v for k, v in reversed(items)} self._list = items - def get_list(self, key: str) -> typing.List[str]: + def getlist(self, key: str) -> typing.List[str]: return [item_value for item_key, item_value in self._list if item_key == key] def keys(self): diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index c510d003..0a5a8441 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -1,4 +1,4 @@ -from starlette.datastructures import Headers, QueryParams, URL +from starlette.datastructures import Headers, MutableHeaders, QueryParams, URL def test_url(): @@ -25,7 +25,7 @@ def test_headers(): assert h["a"] == "123" assert h.get("a") == "123" assert h.get("nope", default=None) is None - assert h.get_list("a") == ["123", "456"] + assert h.getlist("a") == ["123", "456"] assert h.keys() == ["a", "a", "b"] assert h.values() == ["123", "456", "789"] assert h.items() == [("a", "123"), ("a", "456"), ("b", "789")] @@ -34,8 +34,21 @@ def test_headers(): assert repr(h) == "Headers([('a', '123'), ('a', '456'), ('b', '789')])" assert h == Headers([(b"a", b"123"), (b"b", b"789"), (b"a", b"456")]) assert h != [(b"a", b"123"), (b"A", b"456"), (b"b", b"789")] - h = Headers() - assert not h.items() + + +def test_mutable_headers(): + h = MutableHeaders() + assert dict(h) == {} + h["a"] = "1" + assert dict(h) == {"a": "1"} + h["a"] = "2" + assert dict(h) == {"a": "2"} + h.setdefault("a", "3") + assert dict(h) == {"a": "2"} + h.setdefault("b", "4") + assert dict(h) == {"a": "2", "b": "4"} + del h["a"] + assert dict(h) == {"b": "4"} def test_queryparams(): @@ -46,7 +59,7 @@ def test_queryparams(): assert q["a"] == "123" assert q.get("a") == "123" assert q.get("nope", default=None) is None - assert q.get_list("a") == ["123", "456"] + assert q.getlist("a") == ["123", "456"] assert q.keys() == ["a", "a", "b"] assert q.values() == ["123", "456", "789"] assert q.items() == [("a", "123"), ("a", "456"), ("b", "789")] diff --git a/tests/test_response.py b/tests/test_response.py index edcabb84..d4850fc8 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -1,4 +1,4 @@ -from starlette import Response, StreamingResponse, TestClient +from starlette import FileResponse, Response, StreamingResponse, TestClient import asyncio @@ -67,22 +67,18 @@ def test_response_headers(): assert response.headers["x-header-2"] == "789" -def test_streaming_response_headers(): +def test_file_response(tmpdir): + with open("xyz", "wb") as file: + file.write(b"") + def app(scope): - async def asgi(receive, send): - async def stream(msg): - yield "hello, world" - - headers = {"x-header-1": "123", "x-header-2": "456"} - response = StreamingResponse( - stream("hello, world"), media_type="text/plain", headers=headers - ) - response.headers["x-header-2"] = "789" - await response(receive, send) - - return asgi + return FileResponse(path="xyz", filename="example.png") client = TestClient(app) response = client.get("/") - assert response.headers["x-header-1"] == "123" - assert response.headers["x-header-2"] == "789" + assert response.status_code == 200 + assert response.content == b"" + assert response.headers["content-type"] == "image/png" + assert ( + response.headers["content-disposition"] == 'attachment; filename="example.png"' + ) From fd5647eae9b5f3705ec414baefad029b0f3c842c Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 11 Jul 2018 16:32:48 +0100 Subject: [PATCH 4/6] Drop incomplete multipart and DatabaseMiddleware --- starlette/middleware/__init__.py | 4 - starlette/middleware/database.py | 27 --- starlette/multipart.py | 365 ------------------------------- 3 files changed, 396 deletions(-) delete mode 100644 starlette/middleware/__init__.py delete mode 100644 starlette/middleware/database.py delete mode 100644 starlette/multipart.py diff --git a/starlette/middleware/__init__.py b/starlette/middleware/__init__.py deleted file mode 100644 index bdd37f2d..00000000 --- a/starlette/middleware/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from starlette.middleware.database import DatabaseMiddleware - - -__all__ = ["DatabaseMiddleware"] diff --git a/starlette/middleware/database.py b/starlette/middleware/database.py deleted file mode 100644 index 9fb5303b..00000000 --- a/starlette/middleware/database.py +++ /dev/null @@ -1,27 +0,0 @@ -from functools import partial -from urllib.parse import urlparse -import asyncio -import asyncpg - - -class DatabaseMiddleware: - def __init__(self, app, database_url=None, database_config=None): - self.app = app - if database_config is None: - parsed = urlparse(database_url) - database_config = { - "user": parsed.user, - "password": parsed.password, - "database": parsed.database, - "host": parsed.host, - "port": parsed.port, - } - loop = asyncio.get_event_loop() - loop.run_until_complete(self.create_pool(database_config)) - - async def create_pool(self, database_config): - self.pool = await asyncpg.create_pool(**database_config) - - def __call__(self, scope): - scope["database"] = self.pool - return self.app(scope) diff --git a/starlette/multipart.py b/starlette/multipart.py deleted file mode 100644 index c8965c10..00000000 --- a/starlette/multipart.py +++ /dev/null @@ -1,365 +0,0 @@ -_begin_form = "begin_form" -_begin_file = "begin_file" -_cont = "cont" -_end = "end" - - -class MultiPartParser(object): - def __init__( - self, - stream_factory=None, - charset="utf-8", - errors="replace", - max_form_memory_size=None, - cls=None, - buffer_size=64 * 1024, - ): - self.charset = charset - self.errors = errors - self.max_form_memory_size = max_form_memory_size - self.stream_factory = ( - default_stream_factory if stream_factory is None else stream_factory - ) - self.cls = MultiDict if cls is None else cls - - # make sure the buffer size is divisible by four so that we can base64 - # decode chunk by chunk - assert buffer_size % 4 == 0, "buffer size has to be divisible by 4" - # also the buffer size has to be at least 1024 bytes long or long headers - # will freak out the system - assert buffer_size >= 1024, "buffer size has to be at least 1KB" - - self.buffer_size = buffer_size - - def _fix_ie_filename(self, filename): - """Internet Explorer 6 transmits the full file name if a file is - uploaded. This function strips the full path if it thinks the - filename is Windows-like absolute. - """ - if filename[1:3] == ":\\" or filename[:2] == "\\\\": - return filename.split("\\")[-1] - return filename - - def _find_terminator(self, iterator): - """The terminator might have some additional newlines before it. - There is at least one application that sends additional newlines - before headers (the python setuptools package). - """ - for line in iterator: - if not line: - break - line = line.strip() - if line: - return line - return b"" - - def fail(self, message): - raise ValueError(message) - - def get_part_encoding(self, headers): - transfer_encoding = headers.get("content-transfer-encoding") - if ( - transfer_encoding is not None - and transfer_encoding in _supported_multipart_encodings - ): - return transfer_encoding - - def get_part_charset(self, headers): - # Figure out input charset for current part - content_type = headers.get("content-type") - if content_type: - mimetype, ct_params = parse_options_header(content_type) - return ct_params.get("charset", self.charset) - return self.charset - - def start_file_streaming(self, filename, headers, total_content_length): - if isinstance(filename, bytes): - filename = filename.decode(self.charset, self.errors) - filename = self._fix_ie_filename(filename) - content_type = headers.get("content-type") - try: - content_length = int(headers["content-length"]) - except (KeyError, ValueError): - content_length = 0 - container = self.stream_factory( - total_content_length, content_type, filename, content_length - ) - return filename, container - - def in_memory_threshold_reached(self, bytes): - raise exceptions.RequestEntityTooLarge() - - def validate_boundary(self, boundary): - if not boundary: - self.fail("Missing boundary") - if not is_valid_multipart_boundary(boundary): - self.fail("Invalid boundary: %s" % boundary) - if len(boundary) > self.buffer_size: # pragma: no cover - # this should never happen because we check for a minimum size - # of 1024 and boundaries may not be longer than 200. The only - # situation when this happens is for non debug builds where - # the assert is skipped. - self.fail("Boundary longer than buffer size") - - class LineSplitter(object): - def __init__(self, cap=None): - self.buffer = b"" - self.cap = cap - - def _splitlines(self, pre, post): - buf = pre + post - rv = [] - if not buf: - return rv, b"" - lines = buf.splitlines(True) - iv = b"" - for line in lines: - iv += line - while self.cap and len(iv) >= self.cap: - rv.append(iv[: self.cap]) - iv = iv[self.cap :] - if line[-1:] in b"\r\n": - rv.append(iv) - iv = b"" - # If this isn't the very end of the stream and what we got ends - # with \r, we need to hold on to it in case an \n comes next - if post and rv and not iv and rv[-1][-1:] == b"\r": - iv = rv[-1] - del rv[-1] - return rv, iv - - def feed(self, data): - lines, self.buffer = self._splitlines(self.buffer, data) - if not data: - lines += [self.buffer] - if self.buffer: - lines += [b""] - return lines - - class LineParser(object): - def __init__(self, parent, boundary): - self.parent = parent - self.boundary = boundary - self._next_part = b"--" + boundary - self._last_part = self._next_part + b"--" - self._state = self._state_pre_term - self._output = [] - self._headers = [] - self._tail = b"" - self._codec = None - - def _start_content(self): - disposition = self._headers.get("content-disposition") - if disposition is None: - raise ValueError("Missing Content-Disposition header") - self.disposition, extra = parse_options_header(disposition) - transfer_encoding = self.parent.get_part_encoding(self._headers) - if transfer_encoding is not None: - if transfer_encoding == "base64": - transfer_encoding = "base64_codec" - try: - self._codec = codecs.lookup(transfer_encoding) - except Exception: - raise ValueError( - "Cannot decode transfer-encoding: %r" % transfer_encoding - ) - self.name = extra.get("name") - self.filename = extra.get("filename") - if self.filename is not None: - self._output.append( - ("begin_file", (self._headers, self.name, self.filename)) - ) - else: - self._output.append(("begin_form", (self._headers, self.name))) - return self._state_output - - def _state_done(self, line): - return self._state_done - - def _state_output(self, line): - if not line: - raise ValueError("Unexpected end of file") - sline = line.rstrip() - if sline == self._last_part: - self._tail = b"" - self._output.append(("end", None)) - return self._state_done - elif sline == self._next_part: - self._tail = b"" - self._output.append(("end", None)) - self._headers = [] - return self._state_headers - - if self._codec: - try: - line, _ = self._codec.decode(line) - except Exception: - raise ValueError("Could not decode transfer-encoded chunk") - - # We don't know yet whether we can output the final newline, so - # we'll save it in self._tail and output it next time. - tail = self._tail - if line[-2:] == b"\r\n": - self._output.append(("cont", tail + line[:-2])) - self._tail = line[-2:] - elif line[-1:] in b"\r\n": - self._output.append(("cont", tail + line[:-1])) - self._tail = line[-1:] - else: - self._output.append(("cont", tail + line)) - self._tail = b"" - return self._state_output - - def _state_pre_term(self, line): - if not line: - raise ValueError("Unexpected end of file") - return self._state_pre_term - line = line.rstrip(b"\r\n") - if not line: - return self._state_pre_term - if line == self._last_part: - return self._state_done - elif line == self._next_part: - self._headers = [] - return self._state_headers - raise ValueError("Expected boundary at start of multipart data") - - def _state_headers(self, line): - if line is None: - raise ValueError("Unexpected end of file during headers") - line = to_native(line) - line, line_terminated = _line_parse(line) - if not line_terminated: - raise ValueError("Unexpected end of line in multipart header") - if not line: - self._headers = Headers(self._headers) - return self._start_content() - if line[0] in " \t" and self._headers: - key, value = self._headers[-1] - self._headers[-1] = (key, value + "\n " + line[1:]) - else: - parts = line.split(":", 1) - if len(parts) == 2: - self._headers.append((parts[0].strip(), parts[1].strip())) - else: - raise ValueError("Malformed header") - return self._state_headers - - def feed(self, lines): - self._output = [] - s = self._state - for line in lines: - s = s(line) - self._state = s - return self._output - - class PartParser(object): - def __init__(self, parent, content_length): - self.parent = parent - self.content_length = content_length - self._write = None - self._in_memory = 0 - self._guard_memory = False - - def _feed_one(self, event): - ev, data = event - p = self.parent - if ev == "begin_file": - self._headers, self._name, filename = data - self._filename, self._container = p.start_file_streaming( - filename, self._headers, self.content_length - ) - self._write = self._container.write - self._is_file = True - self._guard_memory = False - elif ev == "begin_form": - self._headers, self._name = data - self._container = [] - self._write = self._container.append - self._is_file = False - self._guard_memory = p.max_form_memory_size is not None - elif ev == "cont": - self._write(data) - if self._guard_memory: - self._in_memory += len(data) - if self._in_memory > p.max_form_memory_size: - p.in_memory_threshold_reached(self._in_memory) - elif ev == "end": - if self._is_file: - self._container.seek(0) - return ( - "file", - ( - self._name, - FileStorage( - self._container, - self._filename, - self._name, - headers=self._headers, - ), - ), - ) - else: - part_charset = p.get_part_charset(self._headers) - return ( - "form", - ( - self._name, - b"".join(self._container).decode(part_charset, p.errors), - ), - ) - - def feed(self, events): - rv = [] - for event in events: - v = self._feed_one(event) - if v is not None: - rv.append(v) - return rv - - def parse_lines(self, file, boundary, content_length, cap_at_buffer=True): - """Generate parts of - ``('begin_form', (headers, name))`` - ``('begin_file', (headers, name, filename))`` - ``('cont', bytestring)`` - ``('end', None)`` - Always obeys the grammar - parts = ( begin_form cont* end | - begin_file cont* end )* - """ - - line_splitter = self.LineSplitter(self.buffer_size if cap_at_buffer else None) - line_parser = self.LineParser(self, boundary) - while True: - buf = file.read(self.buffer_size) - lines = line_splitter.feed(buf) - parts = line_parser.feed(lines) - for part in parts: - yield part - if buf == b"": - break - - def parse_parts(self, file, boundary, content_length): - """Generate ``('file', (name, val))`` and - ``('form', (name, val))`` parts. - """ - line_splitter = self.LineSplitter() - line_parser = self.LineParser(self, boundary) - part_parser = self.PartParser(self, content_length) - while True: - buf = file.read(self.buffer_size) - lines = line_splitter.feed(buf) - parts = line_parser.feed(lines) - events = part_parser.feed(parts) - for event in events: - yield event - if buf == b"": - break - - def parse(self, file, boundary, content_length): - formstream, filestream = tee( - self.parse_parts(file, boundary, content_length), 2 - ) - form = (p[1] for p in formstream if p[0] == "form") - files = (p[1] for p in filestream if p[0] == "file") - return self.cls(form), self.cls(files) From c4cc8303d37dc70c42f3d509c1c9293315a84e17 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 11 Jul 2018 16:46:44 +0100 Subject: [PATCH 5/6] Add FileResponse to README --- README.md | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/README.md b/README.md index c65f6a06..3f8211af 100644 --- a/README.md +++ b/README.md @@ -161,6 +161,30 @@ class App: await response(receive, send) ``` +### FileResponse + +Asynchronously streams a file as the response. + +Takes a different set of arguments to instantiate than the other response types: + +* `path` - The filepath to the file to stream. +* `headers` - Any custom headers to include, as a dictionary. +* `media_type` - A string giving the media type. If unset, the filename or path will be used to infer a media type. +* `filename` - If set, this will be included in the response `Content-Disposition`. + +```python +from starlette import FileResponse + + +class App: + def __init__(self, scope): + self.scope = scope + + async def __call__(self, receive, send): + response = FileResponse('/statics/favicon.ico') + await response(receive, send) +``` + --- ## Requests From 9303ffe6a338bf03d898cdaf9ff169eeb5fb6ca7 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 11 Jul 2018 16:47:18 +0100 Subject: [PATCH 6/6] Version 0.1.7 --- starlette/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/starlette/__init__.py b/starlette/__init__.py index 33d95e7c..26ac565a 100644 --- a/starlette/__init__.py +++ b/starlette/__init__.py @@ -26,4 +26,4 @@ __all__ = ( "Request", "TestClient", ) -__version__ = "0.1.6" +__version__ = "0.1.7"