From 96d50862e4468addc736ec0c4665055d6564ef95 Mon Sep 17 00:00:00 2001 From: Hiroyuki Tanaka <36621029+roy-freee@users.noreply.github.com> Date: Fri, 1 Feb 2019 18:06:33 +0900 Subject: [PATCH] Add content_type attribute to UploadFile (#371) * Add content_type attribute to UploadFile * Linting * Change: default value of UploadFile.content_type is from None to empty string --- starlette/datastructures.py | 5 ++- starlette/formparsers.py | 3 +- tests/test_formparsers.py | 73 ++++++++++++++++++++++++++++++++----- 3 files changed, 69 insertions(+), 12 deletions(-) diff --git a/starlette/datastructures.py b/starlette/datastructures.py index e8ea797d..b493ded5 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -422,8 +422,11 @@ class UploadFile: An uploaded file included as part of the request data. """ - def __init__(self, filename: str, file: typing.IO = None) -> None: + def __init__( + self, filename: str, file: typing.IO = None, content_type: str = "" + ) -> None: self.filename = filename + self.content_type = content_type if file is None: file = tempfile.SpooledTemporaryFile() self.file = file diff --git a/starlette/formparsers.py b/starlette/formparsers.py index 27161a83..cbbe8949 100644 --- a/starlette/formparsers.py +++ b/starlette/formparsers.py @@ -200,11 +200,12 @@ class MultiPartParser: elif message_type == MultiPartMessage.HEADERS_FINISHED: headers = Headers(raw=raw_headers) content_disposition = headers.get("Content-Disposition") + content_type = headers.get("Content-Type", "") disposition, options = parse_options_header(content_disposition) field_name = options[b"name"].decode("latin-1") if b"filename" in options: filename = options[b"filename"].decode("latin-1") - file = UploadFile(filename=filename) + file = UploadFile(filename=filename, content_type=content_type) else: file = None elif message_type == MultiPartMessage.PART_DATA: diff --git a/tests/test_formparsers.py b/tests/test_formparsers.py index 2f263fde..e22ebc89 100644 --- a/tests/test_formparsers.py +++ b/tests/test_formparsers.py @@ -23,7 +23,11 @@ def app(scope): for key, value in data.items(): if isinstance(value, UploadFile): content = await value.read() - output[key] = {"filename": value.filename, "content": content.decode()} + output[key] = { + "filename": value.filename, + "content": content.decode(), + "content_type": value.content_type, + } else: output[key] = value await request.close() @@ -44,7 +48,11 @@ def multi_items_app(scope): if isinstance(value, UploadFile): content = await value.read() output[key].append( - {"filename": value.filename, "content": content.decode()} + { + "filename": value.filename, + "content": content.decode(), + "content_type": value.content_type, + } ) else: output[key].append(value) @@ -86,7 +94,28 @@ def test_multipart_request_files(tmpdir): with open(path, "rb") as f: response = client.post("/", files={"test": f}) assert response.json() == { - "test": {"filename": "test.txt", "content": ""} + "test": { + "filename": "test.txt", + "content": "", + "content_type": "", + } + } + + +def test_multipart_request_files_with_content_type(tmpdir): + path = os.path.join(tmpdir, "test.txt") + with open(path, "wb") as file: + file.write(b"") + + client = TestClient(app) + with open(path, "rb") as f: + response = client.post("/", files={"test": ("test.txt", f, "text/plain")}) + assert response.json() == { + "test": { + "filename": "test.txt", + "content": "", + "content_type": "text/plain", + } } @@ -101,10 +130,20 @@ def test_multipart_request_multiple_files(tmpdir): client = TestClient(app) with open(path1, "rb") as f1, open(path2, "rb") as f2: - response = client.post("/", files={"test1": f1, "test2": f2}) + response = client.post( + "/", files={"test1": f1, "test2": ("test2.txt", f2, "text/plain")} + ) assert response.json() == { - "test1": {"filename": "test1.txt", "content": ""}, - "test2": {"filename": "test2.txt", "content": ""}, + "test1": { + "filename": "test1.txt", + "content": "", + "content_type": "", + }, + "test2": { + "filename": "test2.txt", + "content": "", + "content_type": "text/plain", + }, } @@ -120,13 +159,23 @@ def test_multi_items(tmpdir): client = TestClient(multi_items_app) with open(path1, "rb") as f1, open(path2, "rb") as f2: response = client.post( - "/", data=[("test1", "abc")], files=[("test1", f1), ("test1", f2)] + "/", + data=[("test1", "abc")], + files=[("test1", f1), ("test1", ("test2.txt", f2, "text/plain"))], ) assert response.json() == { "test1": [ "abc", - {"filename": "test1.txt", "content": ""}, - {"filename": "test2.txt", "content": ""}, + { + "filename": "test1.txt", + "content": "", + "content_type": "", + }, + { + "filename": "test2.txt", + "content": "", + "content_type": "text/plain", + }, ] } @@ -156,7 +205,11 @@ def test_multipart_request_mixed_files_and_data(tmpdir): }, ) assert response.json() == { - "file": {"filename": "file.txt", "content": ""}, + "file": { + "filename": "file.txt", + "content": "", + "content_type": "text/plain", + }, "field0": "value0", "field1": "value1", }