import os import typing from contextlib import nullcontext as does_not_raise import pytest from starlette.applications import Starlette from starlette.formparsers import MultiPartException, UploadFile, _user_safe_decode from starlette.requests import Request from starlette.responses import JSONResponse from starlette.routing import Mount class ForceMultipartDict(dict): def __bool__(self): return True # FORCE_MULTIPART is an empty dict that boolean-evaluates as `True`. FORCE_MULTIPART = ForceMultipartDict() async def app(scope, receive, send): request = Request(scope, receive) data = await request.form() output: typing.Dict[str, typing.Any] = {} for key, value in data.items(): if isinstance(value, UploadFile): content = await value.read() output[key] = { "filename": value.filename, "content": content.decode(), "content_type": value.content_type, } else: output[key] = value await request.close() response = JSONResponse(output) await response(scope, receive, send) async def multi_items_app(scope, receive, send): request = Request(scope, receive) data = await request.form() output: typing.Dict[str, list] = {} for key, value in data.multi_items(): if key not in output: output[key] = [] if isinstance(value, UploadFile): content = await value.read() output[key].append( { "filename": value.filename, "content": content.decode(), "content_type": value.content_type, } ) else: output[key].append(value) await request.close() response = JSONResponse(output) await response(scope, receive, send) async def app_with_headers(scope, receive, send): request = Request(scope, receive) data = await request.form() output: typing.Dict[str, typing.Any] = {} for key, value in data.items(): if isinstance(value, UploadFile): content = await value.read() output[key] = { "filename": value.filename, "content": content.decode(), "content_type": value.content_type, "headers": list(value.headers.items()), } else: output[key] = value await request.close() response = JSONResponse(output) await response(scope, receive, send) async def app_read_body(scope, receive, send): request = Request(scope, receive) # Read bytes, to force request.stream() to return the already parsed body await request.body() data = await request.form() output = {} for key, value in data.items(): output[key] = value await request.close() response = JSONResponse(output) await response(scope, receive, send) def test_multipart_request_data(tmpdir, test_client_factory): client = test_client_factory(app) response = client.post("/", data={"some": "data"}, files=FORCE_MULTIPART) assert response.json() == {"some": "data"} def test_multipart_request_files(tmpdir, test_client_factory): path = os.path.join(tmpdir, "test.txt") with open(path, "wb") as file: file.write(b"") client = test_client_factory(app) with open(path, "rb") as f: response = client.post("/", files={"test": f}) assert response.json() == { "test": { "filename": "test.txt", "content": "", "content_type": "text/plain", } } def test_multipart_request_files_with_content_type(tmpdir, test_client_factory): path = os.path.join(tmpdir, "test.txt") with open(path, "wb") as file: file.write(b"") client = test_client_factory(app) with open(path, "rb") as f: response = client.post("/", files={"test": ("test.txt", f, "text/plain")}) assert response.json() == { "test": { "filename": "test.txt", "content": "", "content_type": "text/plain", } } def test_multipart_request_multiple_files(tmpdir, test_client_factory): path1 = os.path.join(tmpdir, "test1.txt") with open(path1, "wb") as file: file.write(b"") path2 = os.path.join(tmpdir, "test2.txt") with open(path2, "wb") as file: file.write(b"") client = test_client_factory(app) with open(path1, "rb") as f1, open(path2, "rb") as f2: response = client.post( "/", files={"test1": f1, "test2": ("test2.txt", f2, "text/plain")} ) assert response.json() == { "test1": { "filename": "test1.txt", "content": "", "content_type": "text/plain", }, "test2": { "filename": "test2.txt", "content": "", "content_type": "text/plain", }, } def test_multipart_request_multiple_files_with_headers(tmpdir, test_client_factory): path1 = os.path.join(tmpdir, "test1.txt") with open(path1, "wb") as file: file.write(b"") path2 = os.path.join(tmpdir, "test2.txt") with open(path2, "wb") as file: file.write(b"") client = test_client_factory(app_with_headers) with open(path1, "rb") as f1, open(path2, "rb") as f2: response = client.post( "/", files=[ ("test1", (None, f1)), ("test2", ("test2.txt", f2, "text/plain", {"x-custom": "f2"})), ], ) assert response.json() == { "test1": "", "test2": { "filename": "test2.txt", "content": "", "content_type": "text/plain", "headers": [ [ "content-disposition", 'form-data; name="test2"; filename="test2.txt"', ], ["x-custom", "f2"], ["content-type", "text/plain"], ], }, } def test_multi_items(tmpdir, test_client_factory): path1 = os.path.join(tmpdir, "test1.txt") with open(path1, "wb") as file: file.write(b"") path2 = os.path.join(tmpdir, "test2.txt") with open(path2, "wb") as file: file.write(b"") client = test_client_factory(multi_items_app) with open(path1, "rb") as f1, open(path2, "rb") as f2: response = client.post( "/", data={"test1": "abc"}, files=[("test1", f1), ("test1", ("test2.txt", f2, "text/plain"))], ) assert response.json() == { "test1": [ "abc", { "filename": "test1.txt", "content": "", "content_type": "text/plain", }, { "filename": "test2.txt", "content": "", "content_type": "text/plain", }, ] } def test_multipart_request_mixed_files_and_data(tmpdir, test_client_factory): client = test_client_factory(app) response = client.post( "/", data=( # data b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n" b'Content-Disposition: form-data; name="field0"\r\n\r\n' b"value0\r\n" # file b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n" b'Content-Disposition: form-data; name="file"; filename="file.txt"\r\n' b"Content-Type: text/plain\r\n\r\n" b"\r\n" # data b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n" b'Content-Disposition: form-data; name="field1"\r\n\r\n' b"value1\r\n" b"--a7f7ac8d4e2e437c877bb7b8d7cc549c--\r\n" ), headers={ "Content-Type": ( "multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c" ) }, ) assert response.json() == { "file": { "filename": "file.txt", "content": "", "content_type": "text/plain", }, "field0": "value0", "field1": "value1", } def test_multipart_request_with_charset_for_filename(tmpdir, test_client_factory): client = test_client_factory(app) response = client.post( "/", data=( # file b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n" b'Content-Disposition: form-data; name="file"; filename="\xe6\x96\x87\xe6\x9b\xb8.txt"\r\n' # noqa: E501 b"Content-Type: text/plain\r\n\r\n" b"\r\n" b"--a7f7ac8d4e2e437c877bb7b8d7cc549c--\r\n" ), headers={ "Content-Type": ( "multipart/form-data; charset=utf-8; " "boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c" ) }, ) assert response.json() == { "file": { "filename": "文書.txt", "content": "", "content_type": "text/plain", } } def test_multipart_request_without_charset_for_filename(tmpdir, test_client_factory): client = test_client_factory(app) response = client.post( "/", data=( # file b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n" b'Content-Disposition: form-data; name="file"; filename="\xe7\x94\xbb\xe5\x83\x8f.jpg"\r\n' # noqa: E501 b"Content-Type: image/jpeg\r\n\r\n" b"\r\n" b"--a7f7ac8d4e2e437c877bb7b8d7cc549c--\r\n" ), headers={ "Content-Type": ( "multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c" ) }, ) assert response.json() == { "file": { "filename": "画像.jpg", "content": "", "content_type": "image/jpeg", } } def test_multipart_request_with_encoded_value(tmpdir, test_client_factory): client = test_client_factory(app) response = client.post( "/", data=( b"--20b303e711c4ab8c443184ac833ab00f\r\n" b"Content-Disposition: form-data; " b'name="value"\r\n\r\n' b"Transf\xc3\xa9rer\r\n" b"--20b303e711c4ab8c443184ac833ab00f--\r\n" ), headers={ "Content-Type": ( "multipart/form-data; charset=utf-8; " "boundary=20b303e711c4ab8c443184ac833ab00f" ) }, ) assert response.json() == {"value": "Transférer"} def test_urlencoded_request_data(tmpdir, test_client_factory): client = test_client_factory(app) response = client.post("/", data={"some": "data"}) assert response.json() == {"some": "data"} def test_no_request_data(tmpdir, test_client_factory): client = test_client_factory(app) response = client.post("/") assert response.json() == {} def test_urlencoded_percent_encoding(tmpdir, test_client_factory): client = test_client_factory(app) response = client.post("/", data={"some": "da ta"}) assert response.json() == {"some": "da ta"} def test_urlencoded_percent_encoding_keys(tmpdir, test_client_factory): client = test_client_factory(app) response = client.post("/", data={"so me": "data"}) assert response.json() == {"so me": "data"} def test_urlencoded_multi_field_app_reads_body(tmpdir, test_client_factory): client = test_client_factory(app_read_body) response = client.post("/", data={"some": "data", "second": "key pair"}) assert response.json() == {"some": "data", "second": "key pair"} def test_multipart_multi_field_app_reads_body(tmpdir, test_client_factory): client = test_client_factory(app_read_body) response = client.post( "/", data={"some": "data", "second": "key pair"}, files=FORCE_MULTIPART ) assert response.json() == {"some": "data", "second": "key pair"} def test_user_safe_decode_helper(): result = _user_safe_decode(b"\xc4\x99\xc5\xbc\xc4\x87", "utf-8") assert result == "ężć" def test_user_safe_decode_ignores_wrong_charset(): result = _user_safe_decode(b"abc", "latin-8") assert result == "abc" @pytest.mark.parametrize( "app,expectation", [ (app, pytest.raises(MultiPartException)), (Starlette(routes=[Mount("/", app=app)]), does_not_raise()), ], ) def test_missing_boundary_parameter(app, expectation, test_client_factory) -> None: client = test_client_factory(app) with expectation: res = client.post( "/", data=( # file b'Content-Disposition: form-data; name="file"; filename="\xe6\x96\x87\xe6\x9b\xb8.txt"\r\n' # noqa: E501 b"Content-Type: text/plain\r\n\r\n" b"\r\n" ), headers={"Content-Type": "multipart/form-data; charset=utf-8"}, ) assert res.status_code == 400 assert res.text == "Missing boundary in multipart." @pytest.mark.parametrize( "app,expectation", [ (app, pytest.raises(MultiPartException)), (Starlette(routes=[Mount("/", app=app)]), does_not_raise()), ], ) def test_missing_name_parameter_on_content_disposition( app, expectation, test_client_factory ): client = test_client_factory(app) with expectation: res = client.post( "/", data=( # data b"--a7f7ac8d4e2e437c877bb7b8d7cc549c\r\n" b'Content-Disposition: form-data; ="field0"\r\n\r\n' b"value0\r\n" ), headers={ "Content-Type": ( "multipart/form-data; boundary=a7f7ac8d4e2e437c877bb7b8d7cc549c" ) }, ) assert res.status_code == 400 assert ( res.text == 'The Content-Disposition header field "name" must be provided.' )