mirror of https://github.com/encode/starlette.git
Don't omit `Content-Length` header for `Content-Length: 0` cases (#1395)
* Add content-length header by default * Add test for #1099 * Revert changes and add tests * Check if is StreamingResponse or FileResponse before adding content-length headers * Change conditional logic to check if body is present
This commit is contained in:
parent
9d686a7125
commit
4633427816
|
@ -70,8 +70,8 @@ class Response:
|
||||||
populate_content_length = b"content-length" not in keys
|
populate_content_length = b"content-length" not in keys
|
||||||
populate_content_type = b"content-type" not in keys
|
populate_content_type = b"content-type" not in keys
|
||||||
|
|
||||||
body = getattr(self, "body", b"")
|
body = getattr(self, "body", None)
|
||||||
if body and populate_content_length:
|
if body is not None and populate_content_length:
|
||||||
content_length = str(len(body))
|
content_length = str(len(body))
|
||||||
raw_headers.append((b"content-length", content_length.encode("latin-1")))
|
raw_headers.append((b"content-length", content_length.encode("latin-1")))
|
||||||
|
|
||||||
|
|
|
@ -100,7 +100,7 @@ class StaticFiles:
|
||||||
def get_path(self, scope: Scope) -> str:
|
def get_path(self, scope: Scope) -> str:
|
||||||
"""
|
"""
|
||||||
Given the ASGI scope, return the `path` string to serve up,
|
Given the ASGI scope, return the `path` string to serve up,
|
||||||
with OS specific path seperators, and any '..', '.' components removed.
|
with OS specific path separators, and any '..', '.' components removed.
|
||||||
"""
|
"""
|
||||||
return os.path.normpath(os.path.join(*scope["path"].split("/")))
|
return os.path.normpath(os.path.join(*scope["path"].split("/")))
|
||||||
|
|
||||||
|
|
|
@ -13,6 +13,7 @@ from starlette.responses import (
|
||||||
Response,
|
Response,
|
||||||
StreamingResponse,
|
StreamingResponse,
|
||||||
)
|
)
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
|
|
||||||
def test_text_response(test_client_factory):
|
def test_text_response(test_client_factory):
|
||||||
|
@ -73,6 +74,20 @@ def test_quoting_redirect_response(test_client_factory):
|
||||||
assert response.url == "http://testserver/I%20%E2%99%A5%20Starlette/"
|
assert response.url == "http://testserver/I%20%E2%99%A5%20Starlette/"
|
||||||
|
|
||||||
|
|
||||||
|
def test_redirect_response_content_length_header(test_client_factory):
|
||||||
|
async def app(scope, receive, send):
|
||||||
|
if scope["path"] == "/":
|
||||||
|
response = Response("hello", media_type="text/plain") # pragma: nocover
|
||||||
|
else:
|
||||||
|
response = RedirectResponse("/")
|
||||||
|
await response(scope, receive, send)
|
||||||
|
|
||||||
|
client: TestClient = test_client_factory(app)
|
||||||
|
response = client.request("GET", "/redirect", allow_redirects=False)
|
||||||
|
assert response.url == "http://testserver/redirect"
|
||||||
|
assert response.headers["content-length"] == "0"
|
||||||
|
|
||||||
|
|
||||||
def test_streaming_response(test_client_factory):
|
def test_streaming_response(test_client_factory):
|
||||||
filled_by_bg_task = ""
|
filled_by_bg_task = ""
|
||||||
|
|
||||||
|
@ -309,3 +324,45 @@ def test_head_method(test_client_factory):
|
||||||
client = test_client_factory(app)
|
client = test_client_factory(app)
|
||||||
response = client.head("/")
|
response = client.head("/")
|
||||||
assert response.text == ""
|
assert response.text == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_response(test_client_factory):
|
||||||
|
app = Response()
|
||||||
|
client: TestClient = test_client_factory(app)
|
||||||
|
response = client.get("/")
|
||||||
|
assert response.headers["content-length"] == "0"
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_empty_response(test_client_factory):
|
||||||
|
app = Response(content="hi")
|
||||||
|
client: TestClient = test_client_factory(app)
|
||||||
|
response = client.get("/")
|
||||||
|
assert response.headers["content-length"] == "2"
|
||||||
|
|
||||||
|
|
||||||
|
def test_file_response_known_size(tmpdir, test_client_factory):
|
||||||
|
path = os.path.join(tmpdir, "xyz")
|
||||||
|
content = b"<file content>" * 1000
|
||||||
|
with open(path, "wb") as file:
|
||||||
|
file.write(content)
|
||||||
|
|
||||||
|
app = FileResponse(path=path, filename="example.png")
|
||||||
|
client: TestClient = test_client_factory(app)
|
||||||
|
response = client.get("/")
|
||||||
|
assert response.headers["content-length"] == str(len(content))
|
||||||
|
|
||||||
|
|
||||||
|
def test_streaming_response_unknown_size(test_client_factory):
|
||||||
|
app = StreamingResponse(content=iter(["hello", "world"]))
|
||||||
|
client: TestClient = test_client_factory(app)
|
||||||
|
response = client.get("/")
|
||||||
|
assert "content-length" not in response.headers
|
||||||
|
|
||||||
|
|
||||||
|
def test_streaming_response_known_size(test_client_factory):
|
||||||
|
app = StreamingResponse(
|
||||||
|
content=iter(["hello", "world"]), headers={"content-length": "10"}
|
||||||
|
)
|
||||||
|
client: TestClient = test_client_factory(app)
|
||||||
|
response = client.get("/")
|
||||||
|
assert response.headers["content-length"] == "10"
|
||||||
|
|
Loading…
Reference in New Issue