starlette/tests/test_responses.py

549 lines
20 KiB
Python

from __future__ import annotations
import datetime as dt
import time
from http.cookies import SimpleCookie
from pathlib import Path
from typing import AsyncIterator, Iterator
import anyio
import pytest
from starlette import status
from starlette.background import BackgroundTask
from starlette.datastructures import Headers
from starlette.requests import Request
from starlette.responses import (
FileResponse,
JSONResponse,
RedirectResponse,
Response,
StreamingResponse,
)
from starlette.testclient import TestClient
from starlette.types import Message, Receive, Scope, Send
from tests.types import TestClientFactory
def test_text_response(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = Response("hello, world", media_type="text/plain")
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/")
assert response.text == "hello, world"
def test_bytes_response(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = Response(b"xxxxx", media_type="image/png")
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/")
assert response.content == b"xxxxx"
def test_json_none_response(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = JSONResponse(None)
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/")
assert response.json() is None
assert response.content == b"null"
def test_redirect_response(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
if scope["path"] == "/":
response = Response("hello, world", media_type="text/plain")
else:
response = RedirectResponse("/")
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/redirect")
assert response.text == "hello, world"
assert response.url == "http://testserver/"
def test_quoting_redirect_response(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
if scope["path"] == "/I ♥ Starlette/":
response = Response("hello, world", media_type="text/plain")
else:
response = RedirectResponse("/I ♥ Starlette/")
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/redirect")
assert response.text == "hello, world"
assert response.url == "http://testserver/I%20%E2%99%A5%20Starlette/"
def test_redirect_response_content_length_header(
test_client_factory: TestClientFactory,
) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
if scope["path"] == "/":
response = Response("hello", media_type="text/plain") # pragma: nocover
else:
response = RedirectResponse("/")
await response(scope, receive, send)
client: TestClient = test_client_factory(app)
response = client.request("GET", "/redirect", allow_redirects=False)
assert response.url == "http://testserver/redirect"
assert response.headers["content-length"] == "0"
def test_streaming_response(test_client_factory: TestClientFactory) -> None:
filled_by_bg_task = ""
async def app(scope: Scope, receive: Receive, send: Send) -> None:
async def numbers(minimum: int, maximum: int) -> AsyncIterator[str]:
for i in range(minimum, maximum + 1):
yield str(i)
if i != maximum:
yield ", "
await anyio.sleep(0)
async def numbers_for_cleanup(start: int = 1, stop: int = 5) -> None:
nonlocal filled_by_bg_task
async for thing in numbers(start, stop):
filled_by_bg_task = filled_by_bg_task + thing
cleanup_task = BackgroundTask(numbers_for_cleanup, start=6, stop=9)
generator = numbers(1, 5)
response = StreamingResponse(generator, media_type="text/plain", background=cleanup_task)
await response(scope, receive, send)
assert filled_by_bg_task == ""
client = test_client_factory(app)
response = client.get("/")
assert response.text == "1, 2, 3, 4, 5"
assert filled_by_bg_task == "6, 7, 8, 9"
def test_streaming_response_custom_iterator(
test_client_factory: TestClientFactory,
) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
class CustomAsyncIterator:
def __init__(self) -> None:
self._called = 0
def __aiter__(self) -> AsyncIterator[str]:
return self
async def __anext__(self) -> str:
if self._called == 5:
raise StopAsyncIteration()
self._called += 1
return str(self._called)
response = StreamingResponse(CustomAsyncIterator(), media_type="text/plain")
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/")
assert response.text == "12345"
def test_streaming_response_custom_iterable(
test_client_factory: TestClientFactory,
) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
class CustomAsyncIterable:
async def __aiter__(self) -> AsyncIterator[str | bytes]:
for i in range(5):
yield str(i + 1)
response = StreamingResponse(CustomAsyncIterable(), media_type="text/plain")
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/")
assert response.text == "12345"
def test_sync_streaming_response(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
def numbers(minimum: int, maximum: int) -> Iterator[str]:
for i in range(minimum, maximum + 1):
yield str(i)
if i != maximum:
yield ", "
generator = numbers(1, 5)
response = StreamingResponse(generator, media_type="text/plain")
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/")
assert response.text == "1, 2, 3, 4, 5"
def test_response_headers(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
headers = {"x-header-1": "123", "x-header-2": "456"}
response = Response("hello, world", media_type="text/plain", headers=headers)
response.headers["x-header-2"] = "789"
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/")
assert response.headers["x-header-1"] == "123"
assert response.headers["x-header-2"] == "789"
def test_response_phrase(test_client_factory: TestClientFactory) -> None:
app = Response(status_code=204)
client = test_client_factory(app)
response = client.get("/")
assert response.reason_phrase == "No Content"
app = Response(b"", status_code=123)
client = test_client_factory(app)
response = client.get("/")
assert response.reason_phrase == ""
def test_file_response(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
path = tmp_path / "xyz"
content = b"<file content>" * 1000
path.write_bytes(content)
filled_by_bg_task = ""
async def numbers(minimum: int, maximum: int) -> AsyncIterator[str]:
for i in range(minimum, maximum + 1):
yield str(i)
if i != maximum:
yield ", "
await anyio.sleep(0)
async def numbers_for_cleanup(start: int = 1, stop: int = 5) -> None:
nonlocal filled_by_bg_task
async for thing in numbers(start, stop):
filled_by_bg_task = filled_by_bg_task + thing
cleanup_task = BackgroundTask(numbers_for_cleanup, start=6, stop=9)
async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = FileResponse(path=path, filename="example.png", background=cleanup_task)
await response(scope, receive, send)
assert filled_by_bg_task == ""
client = test_client_factory(app)
response = client.get("/")
expected_disposition = 'attachment; filename="example.png"'
assert response.status_code == status.HTTP_200_OK
assert response.content == content
assert response.headers["content-type"] == "image/png"
assert response.headers["content-disposition"] == expected_disposition
assert "content-length" in response.headers
assert "last-modified" in response.headers
assert "etag" in response.headers
assert filled_by_bg_task == "6, 7, 8, 9"
@pytest.mark.anyio
async def test_file_response_on_head_method(tmp_path: Path) -> None:
path = tmp_path / "xyz"
content = b"<file content>" * 1000
path.write_bytes(content)
app = FileResponse(path=path, filename="example.png")
async def receive() -> Message: # type: ignore[empty-body]
... # pragma: no cover
async def send(message: Message) -> None:
if message["type"] == "http.response.start":
assert message["status"] == status.HTTP_200_OK
headers = Headers(raw=message["headers"])
assert headers["content-type"] == "image/png"
assert "content-length" in headers
assert "content-disposition" in headers
assert "last-modified" in headers
assert "etag" in headers
elif message["type"] == "http.response.body":
assert message["body"] == b""
assert message["more_body"] is False
# Since the TestClient drops the response body on HEAD requests, we need to test
# this directly.
await app({"type": "http", "method": "head"}, receive, send)
def test_file_response_set_media_type(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
path = tmp_path / "xyz"
path.write_bytes(b"<file content>")
# By default, FileResponse will determine the `content-type` based on
# the filename or path, unless a specific `media_type` is provided.
app = FileResponse(path=path, filename="example.png", media_type="image/jpeg")
client: TestClient = test_client_factory(app)
response = client.get("/")
assert response.headers["content-type"] == "image/jpeg"
def test_file_response_with_directory_raises_error(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
app = FileResponse(path=tmp_path, filename="example.png")
client = test_client_factory(app)
with pytest.raises(RuntimeError) as exc_info:
client.get("/")
assert "is not a file" in str(exc_info.value)
def test_file_response_with_missing_file_raises_error(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
path = tmp_path / "404.txt"
app = FileResponse(path=path, filename="404.txt")
client = test_client_factory(app)
with pytest.raises(RuntimeError) as exc_info:
client.get("/")
assert "does not exist" in str(exc_info.value)
def test_file_response_with_chinese_filename(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
content = b"file content"
filename = "你好.txt" # probably "Hello.txt" in Chinese
path = tmp_path / filename
path.write_bytes(content)
app = FileResponse(path=path, filename=filename)
client = test_client_factory(app)
response = client.get("/")
expected_disposition = "attachment; filename*=utf-8''%E4%BD%A0%E5%A5%BD.txt"
assert response.status_code == status.HTTP_200_OK
assert response.content == content
assert response.headers["content-disposition"] == expected_disposition
def test_file_response_with_inline_disposition(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
content = b"file content"
filename = "hello.txt"
path = tmp_path / filename
path.write_bytes(content)
app = FileResponse(path=path, filename=filename, content_disposition_type="inline")
client = test_client_factory(app)
response = client.get("/")
expected_disposition = 'inline; filename="hello.txt"'
assert response.status_code == status.HTTP_200_OK
assert response.content == content
assert response.headers["content-disposition"] == expected_disposition
def test_file_response_with_method_warns(tmp_path: Path) -> None:
with pytest.warns(DeprecationWarning):
FileResponse(path=tmp_path, filename="example.png", method="GET")
def test_set_cookie(test_client_factory: TestClientFactory, monkeypatch: pytest.MonkeyPatch) -> None:
# Mock time used as a reference for `Expires` by stdlib `SimpleCookie`.
mocked_now = dt.datetime(2037, 1, 22, 12, 0, 0, tzinfo=dt.timezone.utc)
monkeypatch.setattr(time, "time", lambda: mocked_now.timestamp())
async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = Response("Hello, world!", media_type="text/plain")
response.set_cookie(
"mycookie",
"myvalue",
max_age=10,
expires=10,
path="/",
domain="localhost",
secure=True,
httponly=True,
samesite="none",
)
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/")
assert response.text == "Hello, world!"
assert (
response.headers["set-cookie"] == "mycookie=myvalue; Domain=localhost; expires=Thu, 22 Jan 2037 12:00:10 GMT; "
"HttpOnly; Max-Age=10; Path=/; SameSite=none; Secure"
)
def test_set_cookie_path_none(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = Response("Hello, world!", media_type="text/plain")
response.set_cookie("mycookie", "myvalue", path=None)
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/")
assert response.text == "Hello, world!"
assert response.headers["set-cookie"] == "mycookie=myvalue; SameSite=lax"
@pytest.mark.parametrize(
"expires",
[
pytest.param(dt.datetime(2037, 1, 22, 12, 0, 10, tzinfo=dt.timezone.utc), id="datetime"),
pytest.param("Thu, 22 Jan 2037 12:00:10 GMT", id="str"),
pytest.param(10, id="int"),
],
)
def test_expires_on_set_cookie(
test_client_factory: TestClientFactory,
monkeypatch: pytest.MonkeyPatch,
expires: str,
) -> None:
# Mock time used as a reference for `Expires` by stdlib `SimpleCookie`.
mocked_now = dt.datetime(2037, 1, 22, 12, 0, 0, tzinfo=dt.timezone.utc)
monkeypatch.setattr(time, "time", lambda: mocked_now.timestamp())
async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = Response("Hello, world!", media_type="text/plain")
response.set_cookie("mycookie", "myvalue", expires=expires)
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/")
cookie = SimpleCookie(response.headers.get("set-cookie"))
assert cookie["mycookie"]["expires"] == "Thu, 22 Jan 2037 12:00:10 GMT"
def test_delete_cookie(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
response = Response("Hello, world!", media_type="text/plain")
if request.cookies.get("mycookie"):
response.delete_cookie("mycookie")
else:
response.set_cookie("mycookie", "myvalue")
await response(scope, receive, send)
client = test_client_factory(app)
response = client.get("/")
assert response.cookies["mycookie"]
response = client.get("/")
assert not response.cookies.get("mycookie")
def test_populate_headers(test_client_factory: TestClientFactory) -> None:
app = Response(content="hi", headers={}, media_type="text/html")
client = test_client_factory(app)
response = client.get("/")
assert response.text == "hi"
assert response.headers["content-length"] == "2"
assert response.headers["content-type"] == "text/html; charset=utf-8"
def test_head_method(test_client_factory: TestClientFactory) -> None:
app = Response("hello, world", media_type="text/plain")
client = test_client_factory(app)
response = client.head("/")
assert response.text == ""
def test_empty_response(test_client_factory: TestClientFactory) -> None:
app = Response()
client: TestClient = test_client_factory(app)
response = client.get("/")
assert response.content == b""
assert response.headers["content-length"] == "0"
assert "content-type" not in response.headers
def test_empty_204_response(test_client_factory: TestClientFactory) -> None:
app = Response(status_code=204)
client: TestClient = test_client_factory(app)
response = client.get("/")
assert "content-length" not in response.headers
def test_non_empty_response(test_client_factory: TestClientFactory) -> None:
app = Response(content="hi")
client: TestClient = test_client_factory(app)
response = client.get("/")
assert response.headers["content-length"] == "2"
def test_response_do_not_add_redundant_charset(
test_client_factory: TestClientFactory,
) -> None:
app = Response(media_type="text/plain; charset=utf-8")
client = test_client_factory(app)
response = client.get("/")
assert response.headers["content-type"] == "text/plain; charset=utf-8"
def test_file_response_known_size(tmp_path: Path, test_client_factory: TestClientFactory) -> None:
path = tmp_path / "xyz"
content = b"<file content>" * 1000
path.write_bytes(content)
app = FileResponse(path=path, filename="example.png")
client: TestClient = test_client_factory(app)
response = client.get("/")
assert response.headers["content-length"] == str(len(content))
def test_streaming_response_unknown_size(
test_client_factory: TestClientFactory,
) -> None:
app = StreamingResponse(content=iter(["hello", "world"]))
client: TestClient = test_client_factory(app)
response = client.get("/")
assert "content-length" not in response.headers
def test_streaming_response_known_size(test_client_factory: TestClientFactory) -> None:
app = StreamingResponse(content=iter(["hello", "world"]), headers={"content-length": "10"})
client: TestClient = test_client_factory(app)
response = client.get("/")
assert response.headers["content-length"] == "10"
def test_response_memoryview(test_client_factory: TestClientFactory) -> None:
app = Response(content=memoryview(b"\xc0"))
client: TestClient = test_client_factory(app)
response = client.get("/")
assert response.content == b"\xc0"
def test_streaming_response_memoryview(test_client_factory: TestClientFactory) -> None:
app = StreamingResponse(content=iter([memoryview(b"\xc0"), memoryview(b"\xf5")]))
client: TestClient = test_client_factory(app)
response = client.get("/")
assert response.content == b"\xc0\xf5"
@pytest.mark.anyio
async def test_streaming_response_stops_if_receiving_http_disconnect() -> None:
streamed = 0
disconnected = anyio.Event()
async def receive_disconnect() -> Message:
await disconnected.wait()
return {"type": "http.disconnect"}
async def send(message: Message) -> None:
nonlocal streamed
if message["type"] == "http.response.body":
streamed += len(message.get("body", b""))
# Simulate disconnection after download has started
if streamed >= 16:
disconnected.set()
async def stream_indefinitely() -> AsyncIterator[bytes]:
while True:
# Need a sleep for the event loop to switch to another task
await anyio.sleep(0)
yield b"chunk "
response = StreamingResponse(content=stream_indefinitely())
with anyio.move_on_after(1) as cancel_scope:
await response({}, receive_disconnect, send)
assert not cancel_scope.cancel_called, "Content streaming should stop itself."