# starlette/tests/middleware/test_wsgi.py
#
# 165 lines
# 4.9 KiB
# Python

import sys
from typing import Any, Callable, Dict, Iterable
import pytest
from starlette._utils import collapse_excgroups
from starlette.middleware.wsgi import WSGIMiddleware, build_environ
from tests.types import TestClientFactory
# Type aliases describing the WSGI callable signature used by the test apps.
Environment = Dict[str, Any]  # the WSGI ``environ`` mapping
StartResponse = Callable[..., Any]  # the ``start_response`` callable
WSGIResponse = Iterable[bytes]  # the iterable of body chunks an app returns
def hello_world(
    environ: Environment,
    start_response: StartResponse,
) -> WSGIResponse:
    """WSGI app that always replies ``200 OK`` with a fixed plain-text body."""
    body = b"Hello World!\n"
    start_response(
        "200 OK",
        [
            ("Content-Type", "text/plain; charset=utf-8"),
            ("Content-Length", str(len(body))),
        ],
    )
    return [body]
def echo_body(
    environ: Environment,
    start_response: StartResponse,
) -> WSGIResponse:
    """WSGI app that replies ``200 OK`` with the bytes read from ``wsgi.input``."""
    request_body = environ["wsgi.input"].read()
    response_headers = [
        ("Content-Type", "text/plain; charset=utf-8"),
        ("Content-Length", str(len(request_body))),
    ]
    start_response("200 OK", response_headers)
    return [request_body]
def raise_exception(
    environ: Environment,
    start_response: StartResponse,
) -> WSGIResponse:
    """WSGI app that fails before ever calling ``start_response``.

    Raises:
        RuntimeError: always, with the message ``"Something went wrong"``.
    """
    message = "Something went wrong"
    raise RuntimeError(message)
def return_exc_info(
    environ: Environment,
    start_response: StartResponse,
) -> WSGIResponse:
    """WSGI app that reports its own error as a 500 and passes ``exc_info``.

    A ``RuntimeError`` is raised and caught locally. ``start_response`` is then
    called with a 500 status and ``exc_info=sys.exc_info()``, and the app
    returns a plain-text error body.
    """
    try:
        raise RuntimeError("Something went wrong")
    except RuntimeError:
        body = b"Internal Server Error"
        # Must run inside the ``except`` so sys.exc_info() still refers to
        # the RuntimeError raised above.
        start_response(
            "500 Internal Server Error",
            [
                ("Content-Type", "text/plain; charset=utf-8"),
                ("Content-Length", str(len(body))),
            ],
            exc_info=sys.exc_info(),
        )
        return [body]
def test_wsgi_get(test_client_factory: TestClientFactory) -> None:
    """A GET through ``WSGIMiddleware`` returns the WSGI app's status and body."""
    response = test_client_factory(WSGIMiddleware(hello_world)).get("/")
    assert response.status_code == 200
    assert response.text == "Hello World!\n"
def test_wsgi_post(test_client_factory: TestClientFactory) -> None:
    """A POST body reaches the WSGI app via ``wsgi.input`` and is echoed back."""
    client = test_client_factory(WSGIMiddleware(echo_body))
    payload = {"example": 123}
    response = client.post("/", json=payload)
    assert response.status_code == 200
    # The echoed text is the client's JSON serialization of ``payload``.
    assert response.text == '{"example":123}'
def test_wsgi_exception(test_client_factory: TestClientFactory) -> None:
    """An exception raised by the WSGI app propagates out of the test client."""
    # Note that we're testing the WSGI app directly here.
    # The HTTP protocol implementations would catch this error and return 500.
    client = test_client_factory(WSGIMiddleware(raise_exception))
    # collapse_excgroups() is entered inside pytest.raises, so it sees the
    # error first (presumably unwrapping an exception group into the bare
    # RuntimeError -- see starlette._utils).
    with pytest.raises(RuntimeError), collapse_excgroups():
        client.get("/")
def test_wsgi_exc_info(test_client_factory: TestClientFactory) -> None:
    """An app that passes ``exc_info`` to ``start_response`` surfaces its error.

    With server exceptions raised (the client default) the ``RuntimeError``
    reaches the caller. With ``raise_server_exceptions=False`` the caller
    instead gets the 500 response the WSGI app produced.
    """
    # Note that we're testing the WSGI app directly here.
    # The HTTP protocol implementations would catch this error and return 500.
    app = WSGIMiddleware(return_exc_info)
    client = test_client_factory(app)
    with pytest.raises(RuntimeError):
        # No binding: the call raises, so a result name would never be set.
        client.get("/")

    app = WSGIMiddleware(return_exc_info)
    client = test_client_factory(app, raise_server_exceptions=False)
    response = client.get("/")
    assert response.status_code == 500
    assert response.text == "Internal Server Error"
def test_build_environ() -> None:
    """``build_environ`` maps an HTTP scope plus request body onto a WSGI environ."""
    body = b'{"example":"body"}'
    headers = [
        (b"host", b"www.example.org"),
        (b"content-type", b"application/json"),
        (b"content-length", b"18"),
        (b"accept", b"application/json"),
        (b"accept", b"text/plain"),
    ]
    scope = {
        "type": "http",
        "http_version": "1.1",
        "method": "GET",
        "scheme": "https",
        "path": "/sub/",
        "root_path": "/sub",
        "query_string": b"a=123&b=456",
        "headers": headers,
        "client": ("134.56.78.4", 1453),
        "server": ("www.example.org", 443),
    }

    environ = build_environ(scope, body)

    # ``wsgi.input`` is a readable stream, so it is removed and its contents
    # checked on their own before comparing the rest of the environ.
    assert environ.pop("wsgi.input").read() == body

    expected = {
        "CONTENT_LENGTH": "18",
        "CONTENT_TYPE": "application/json",
        "HTTP_ACCEPT": "application/json,text/plain",
        "HTTP_HOST": "www.example.org",
        "PATH_INFO": "/",
        "QUERY_STRING": "a=123&b=456",
        "REMOTE_ADDR": "134.56.78.4",
        "REQUEST_METHOD": "GET",
        "SCRIPT_NAME": "/sub",
        "SERVER_NAME": "www.example.org",
        "SERVER_PORT": 443,
        "SERVER_PROTOCOL": "HTTP/1.1",
        "wsgi.errors": sys.stdout,
        "wsgi.multiprocess": True,
        "wsgi.multithread": True,
        "wsgi.run_once": False,
        "wsgi.url_scheme": "https",
        "wsgi.version": (1, 0),
    }
    assert environ == expected
def test_build_environ_encoding() -> None:
    """Non-ASCII ``root_path`` and ``path`` are exposed as UTF-8 bytes read as latin-1."""
    root_path = "/中国"
    path = "/小星"
    scope = {
        "type": "http",
        "http_version": "1.1",
        "method": "GET",
        "path": path,
        "root_path": root_path,
        "query_string": b"a=123&b=456",
        "headers": [],
    }

    environ = build_environ(scope, b"")

    # Expected values: the UTF-8 encoding of each path, decoded as latin-1.
    assert environ["SCRIPT_NAME"] == root_path.encode().decode("latin-1")
    assert environ["PATH_INFO"] == path.encode().decode("latin-1")