from __future__ import annotations import datetime as dt import time from http.cookies import SimpleCookie from pathlib import Path from typing import AsyncIterator, Iterator import anyio import pytest from starlette import status from starlette.background import BackgroundTask from starlette.datastructures import Headers from starlette.requests import Request from starlette.responses import ( FileResponse, JSONResponse, RedirectResponse, Response, StreamingResponse, ) from starlette.testclient import TestClient from starlette.types import Message, Receive, Scope, Send from tests.types import TestClientFactory def test_text_response(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response("hello, world", media_type="text/plain") await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") assert response.text == "hello, world" def test_bytes_response(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response(b"xxxxx", media_type="image/png") await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") assert response.content == b"xxxxx" def test_json_none_response(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: response = JSONResponse(None) await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") assert response.json() is None assert response.content == b"null" def test_redirect_response(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: if scope["path"] == "/": response = Response("hello, world", media_type="text/plain") else: response = RedirectResponse("/") await response(scope, receive, send) client = test_client_factory(app) response = client.get("/redirect") assert response.text == "hello, world" assert response.url == "http://testserver/" def test_quoting_redirect_response(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: if scope["path"] == "/I ♥ Starlette/": response = Response("hello, world", media_type="text/plain") else: response = RedirectResponse("/I ♥ Starlette/") await response(scope, receive, send) client = test_client_factory(app) response = client.get("/redirect") assert response.text == "hello, world" assert response.url == "http://testserver/I%20%E2%99%A5%20Starlette/" def test_redirect_response_content_length_header( test_client_factory: TestClientFactory, ) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: if scope["path"] == "/": response = Response("hello", media_type="text/plain") # pragma: nocover else: response = RedirectResponse("/") await response(scope, receive, send) client: TestClient = test_client_factory(app) response = client.request("GET", "/redirect", allow_redirects=False) assert response.url == "http://testserver/redirect" assert response.headers["content-length"] == "0" def test_streaming_response(test_client_factory: TestClientFactory) -> None: filled_by_bg_task = "" async def app(scope: Scope, receive: Receive, send: Send) -> None: async def numbers(minimum: int, maximum: int) -> AsyncIterator[str]: for i in range(minimum, maximum + 1): yield str(i) if i != maximum: yield ", " await anyio.sleep(0) async def numbers_for_cleanup(start: int = 1, stop: int = 5) -> None: nonlocal filled_by_bg_task async for thing in numbers(start, stop): filled_by_bg_task = filled_by_bg_task + thing cleanup_task = BackgroundTask(numbers_for_cleanup, start=6, stop=9) generator = numbers(1, 5) response = StreamingResponse(generator, media_type="text/plain", background=cleanup_task) await response(scope, receive, send) assert filled_by_bg_task == "" client = test_client_factory(app) response = client.get("/") assert response.text == "1, 2, 3, 4, 5" assert filled_by_bg_task == "6, 7, 8, 9" def test_streaming_response_custom_iterator( test_client_factory: TestClientFactory, ) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: class CustomAsyncIterator: def __init__(self) -> None: self._called = 0 def __aiter__(self) -> AsyncIterator[str]: return self async def __anext__(self) -> str: if self._called == 5: raise StopAsyncIteration() self._called += 1 return str(self._called) response = StreamingResponse(CustomAsyncIterator(), media_type="text/plain") await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") assert response.text == "12345" def test_streaming_response_custom_iterable( test_client_factory: TestClientFactory, ) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: class CustomAsyncIterable: async def __aiter__(self) -> AsyncIterator[str | bytes]: for i in range(5): yield str(i + 1) response = StreamingResponse(CustomAsyncIterable(), media_type="text/plain") await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") assert response.text == "12345" def test_sync_streaming_response(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: def numbers(minimum: int, maximum: int) -> Iterator[str]: for i in range(minimum, maximum + 1): yield str(i) if i != maximum: yield ", " generator = numbers(1, 5) response = StreamingResponse(generator, media_type="text/plain") await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") assert response.text == "1, 2, 3, 4, 5" def test_response_headers(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: headers = {"x-header-1": "123", "x-header-2": "456"} response = Response("hello, world", media_type="text/plain", headers=headers) response.headers["x-header-2"] = "789" await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") assert response.headers["x-header-1"] == "123" assert response.headers["x-header-2"] == "789" def test_response_phrase(test_client_factory: TestClientFactory) -> None: app = Response(status_code=204) client = test_client_factory(app) response = client.get("/") assert response.reason_phrase == "No Content" app = Response(b"", status_code=123) client = test_client_factory(app) response = client.get("/") assert response.reason_phrase == "" def test_file_response(tmp_path: Path, test_client_factory: TestClientFactory) -> None: path = tmp_path / "xyz" content = b"" * 1000 path.write_bytes(content) filled_by_bg_task = "" async def numbers(minimum: int, maximum: int) -> AsyncIterator[str]: for i in range(minimum, maximum + 1): yield str(i) if i != maximum: yield ", " await anyio.sleep(0) async def numbers_for_cleanup(start: int = 1, stop: int = 5) -> None: nonlocal filled_by_bg_task async for thing in numbers(start, stop): filled_by_bg_task = filled_by_bg_task + thing cleanup_task = BackgroundTask(numbers_for_cleanup, start=6, stop=9) async def app(scope: Scope, receive: Receive, send: Send) -> None: response = FileResponse(path=path, filename="example.png", background=cleanup_task) await response(scope, receive, send) assert filled_by_bg_task == "" client = test_client_factory(app) response = client.get("/") expected_disposition = 'attachment; filename="example.png"' assert response.status_code == status.HTTP_200_OK assert response.content == content assert response.headers["content-type"] == "image/png" assert response.headers["content-disposition"] == expected_disposition assert "content-length" in response.headers assert "last-modified" in response.headers assert "etag" in response.headers assert filled_by_bg_task == "6, 7, 8, 9" @pytest.mark.anyio async def test_file_response_on_head_method(tmp_path: Path) -> None: path = tmp_path / "xyz" content = b"" * 1000 path.write_bytes(content) app = FileResponse(path=path, filename="example.png") async def receive() -> Message: # type: ignore[empty-body] ... # pragma: no cover async def send(message: Message) -> None: if message["type"] == "http.response.start": assert message["status"] == status.HTTP_200_OK headers = Headers(raw=message["headers"]) assert headers["content-type"] == "image/png" assert "content-length" in headers assert "content-disposition" in headers assert "last-modified" in headers assert "etag" in headers elif message["type"] == "http.response.body": assert message["body"] == b"" assert message["more_body"] is False # Since the TestClient drops the response body on HEAD requests, we need to test # this directly. await app({"type": "http", "method": "head"}, receive, send) def test_file_response_set_media_type(tmp_path: Path, test_client_factory: TestClientFactory) -> None: path = tmp_path / "xyz" path.write_bytes(b"") # By default, FileResponse will determine the `content-type` based on # the filename or path, unless a specific `media_type` is provided. app = FileResponse(path=path, filename="example.png", media_type="image/jpeg") client: TestClient = test_client_factory(app) response = client.get("/") assert response.headers["content-type"] == "image/jpeg" def test_file_response_with_directory_raises_error(tmp_path: Path, test_client_factory: TestClientFactory) -> None: app = FileResponse(path=tmp_path, filename="example.png") client = test_client_factory(app) with pytest.raises(RuntimeError) as exc_info: client.get("/") assert "is not a file" in str(exc_info.value) def test_file_response_with_missing_file_raises_error(tmp_path: Path, test_client_factory: TestClientFactory) -> None: path = tmp_path / "404.txt" app = FileResponse(path=path, filename="404.txt") client = test_client_factory(app) with pytest.raises(RuntimeError) as exc_info: client.get("/") assert "does not exist" in str(exc_info.value) def test_file_response_with_chinese_filename(tmp_path: Path, test_client_factory: TestClientFactory) -> None: content = b"file content" filename = "你好.txt" # probably "Hello.txt" in Chinese path = tmp_path / filename path.write_bytes(content) app = FileResponse(path=path, filename=filename) client = test_client_factory(app) response = client.get("/") expected_disposition = "attachment; filename*=utf-8''%E4%BD%A0%E5%A5%BD.txt" assert response.status_code == status.HTTP_200_OK assert response.content == content assert response.headers["content-disposition"] == expected_disposition def test_file_response_with_inline_disposition(tmp_path: Path, test_client_factory: TestClientFactory) -> None: content = b"file content" filename = "hello.txt" path = tmp_path / filename path.write_bytes(content) app = FileResponse(path=path, filename=filename, content_disposition_type="inline") client = test_client_factory(app) response = client.get("/") expected_disposition = 'inline; filename="hello.txt"' assert response.status_code == status.HTTP_200_OK assert response.content == content assert response.headers["content-disposition"] == expected_disposition def test_file_response_with_method_warns(tmp_path: Path) -> None: with pytest.warns(DeprecationWarning): FileResponse(path=tmp_path, filename="example.png", method="GET") def test_set_cookie(test_client_factory: TestClientFactory, monkeypatch: pytest.MonkeyPatch) -> None: # Mock time used as a reference for `Expires` by stdlib `SimpleCookie`. mocked_now = dt.datetime(2037, 1, 22, 12, 0, 0, tzinfo=dt.timezone.utc) monkeypatch.setattr(time, "time", lambda: mocked_now.timestamp()) async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response("Hello, world!", media_type="text/plain") response.set_cookie( "mycookie", "myvalue", max_age=10, expires=10, path="/", domain="localhost", secure=True, httponly=True, samesite="none", ) await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") assert response.text == "Hello, world!" assert ( response.headers["set-cookie"] == "mycookie=myvalue; Domain=localhost; expires=Thu, 22 Jan 2037 12:00:10 GMT; " "HttpOnly; Max-Age=10; Path=/; SameSite=none; Secure" ) def test_set_cookie_path_none(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response("Hello, world!", media_type="text/plain") response.set_cookie("mycookie", "myvalue", path=None) await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") assert response.text == "Hello, world!" assert response.headers["set-cookie"] == "mycookie=myvalue; SameSite=lax" @pytest.mark.parametrize( "expires", [ pytest.param(dt.datetime(2037, 1, 22, 12, 0, 10, tzinfo=dt.timezone.utc), id="datetime"), pytest.param("Thu, 22 Jan 2037 12:00:10 GMT", id="str"), pytest.param(10, id="int"), ], ) def test_expires_on_set_cookie( test_client_factory: TestClientFactory, monkeypatch: pytest.MonkeyPatch, expires: str, ) -> None: # Mock time used as a reference for `Expires` by stdlib `SimpleCookie`. mocked_now = dt.datetime(2037, 1, 22, 12, 0, 0, tzinfo=dt.timezone.utc) monkeypatch.setattr(time, "time", lambda: mocked_now.timestamp()) async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response("Hello, world!", media_type="text/plain") response.set_cookie("mycookie", "myvalue", expires=expires) await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") cookie = SimpleCookie(response.headers.get("set-cookie")) assert cookie["mycookie"]["expires"] == "Thu, 22 Jan 2037 12:00:10 GMT" def test_delete_cookie(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) response = Response("Hello, world!", media_type="text/plain") if request.cookies.get("mycookie"): response.delete_cookie("mycookie") else: response.set_cookie("mycookie", "myvalue") await response(scope, receive, send) client = test_client_factory(app) response = client.get("/") assert response.cookies["mycookie"] response = client.get("/") assert not response.cookies.get("mycookie") def test_populate_headers(test_client_factory: TestClientFactory) -> None: app = Response(content="hi", headers={}, media_type="text/html") client = test_client_factory(app) response = client.get("/") assert response.text == "hi" assert response.headers["content-length"] == "2" assert response.headers["content-type"] == "text/html; charset=utf-8" def test_head_method(test_client_factory: TestClientFactory) -> None: app = Response("hello, world", media_type="text/plain") client = test_client_factory(app) response = client.head("/") assert response.text == "" def test_empty_response(test_client_factory: TestClientFactory) -> None: app = Response() client: TestClient = test_client_factory(app) response = client.get("/") assert response.content == b"" assert response.headers["content-length"] == "0" assert "content-type" not in response.headers def test_empty_204_response(test_client_factory: TestClientFactory) -> None: app = Response(status_code=204) client: TestClient = test_client_factory(app) response = client.get("/") assert "content-length" not in response.headers def test_non_empty_response(test_client_factory: TestClientFactory) -> None: app = Response(content="hi") client: TestClient = test_client_factory(app) response = client.get("/") assert response.headers["content-length"] == "2" def test_response_do_not_add_redundant_charset( test_client_factory: TestClientFactory, ) -> None: app = Response(media_type="text/plain; charset=utf-8") client = test_client_factory(app) response = client.get("/") assert response.headers["content-type"] == "text/plain; charset=utf-8" def test_file_response_known_size(tmp_path: Path, test_client_factory: TestClientFactory) -> None: path = tmp_path / "xyz" content = b"" * 1000 path.write_bytes(content) app = FileResponse(path=path, filename="example.png") client: TestClient = test_client_factory(app) response = client.get("/") assert response.headers["content-length"] == str(len(content)) def test_streaming_response_unknown_size( test_client_factory: TestClientFactory, ) -> None: app = StreamingResponse(content=iter(["hello", "world"])) client: TestClient = test_client_factory(app) response = client.get("/") assert "content-length" not in response.headers def test_streaming_response_known_size(test_client_factory: TestClientFactory) -> None: app = StreamingResponse(content=iter(["hello", "world"]), headers={"content-length": "10"}) client: TestClient = test_client_factory(app) response = client.get("/") assert response.headers["content-length"] == "10" def test_response_memoryview(test_client_factory: TestClientFactory) -> None: app = Response(content=memoryview(b"\xc0")) client: TestClient = test_client_factory(app) response = client.get("/") assert response.content == b"\xc0" def test_streaming_response_memoryview(test_client_factory: TestClientFactory) -> None: app = StreamingResponse(content=iter([memoryview(b"\xc0"), memoryview(b"\xf5")])) client: TestClient = test_client_factory(app) response = client.get("/") assert response.content == b"\xc0\xf5" @pytest.mark.anyio async def test_streaming_response_stops_if_receiving_http_disconnect() -> None: streamed = 0 disconnected = anyio.Event() async def receive_disconnect() -> Message: await disconnected.wait() return {"type": "http.disconnect"} async def send(message: Message) -> None: nonlocal streamed if message["type"] == "http.response.body": streamed += len(message.get("body", b"")) # Simulate disconnection after download has started if streamed >= 16: disconnected.set() async def stream_indefinitely() -> AsyncIterator[bytes]: while True: # Need a sleep for the event loop to switch to another task await anyio.sleep(0) yield b"chunk " response = StreamingResponse(content=stream_indefinitely()) with anyio.move_on_after(1) as cancel_scope: await response({}, receive_disconnect, send) assert not cancel_scope.cancel_called, "Content streaming should stop itself."