diff --git a/docs/staticfiles.md b/docs/staticfiles.md index 03d9a581..f1b9d0a5 100644 --- a/docs/staticfiles.md +++ b/docs/staticfiles.md @@ -1,20 +1,17 @@ -As well as the `FileResponse` class, Starlette also includes ASGI applications -for serving a specific file or directory: +Starlette also includes an `StaticFiles` class for serving a specific directory: -* `StaticFile(path)` - Serve a single file, given by `path`. * `StaticFiles(directory)` - Serve any files in the given `directory`. -You can combine these ASGI applications with Starlette's routing to provide +You can combine this ASGI application with Starlette's routing to provide comprehensive static file serving. ```python from starlette.routing import Router, Path, PathPrefix -from starlette.staticfiles import StaticFile, StaticFiles +from starlette.staticfiles import StaticFiles app = Router(routes=[ - Path('/', app=StaticFile(path='index.html')), PathPrefix('/static', app=StaticFiles(directory='static')), ]) ``` diff --git a/starlette/responses.py b/starlette/responses.py index 3471fbf2..09ac58f5 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -1,15 +1,16 @@ -import hashlib import os -import typing import json - +import stat +import typing +import hashlib +import http.cookies from email.utils import formatdate from mimetypes import guess_type +from urllib.parse import quote_plus + from starlette.background import BackgroundTask from starlette.datastructures import MutableHeaders, URL from starlette.types import Receive, Send -from urllib.parse import quote_plus -import http.cookies try: import aiofiles @@ -227,8 +228,15 @@ class FileResponse(Response): async def __call__(self, receive: Receive, send: Send) -> None: if self.stat_result is None: - stat_result = await aio_stat(self.path) - self.set_stat_headers(stat_result) + try: + stat_result = await aio_stat(self.path) + self.set_stat_headers(stat_result) + except FileNotFoundError: + raise RuntimeError(f"File at path {self.path} does not exist.") + else: + mode = stat_result.st_mode + if not stat.S_ISREG(mode): + raise RuntimeError(f"File at path {self.path} is not a file.") await send( { "type": "http.response.start", diff --git a/starlette/staticfiles.py b/starlette/staticfiles.py index eb8f7483..b9b54cfe 100644 --- a/starlette/staticfiles.py +++ b/starlette/staticfiles.py @@ -7,17 +7,6 @@ from starlette.responses import PlainTextResponse, FileResponse, Response from starlette.types import Send, Receive, Scope, ASGIInstance -class StaticFile: - def __init__(self, *, path: str) -> None: - self.path = path - - def __call__(self, scope: Scope) -> ASGIInstance: - assert scope["type"] == "http" - if scope["method"] not in ("GET", "HEAD"): - return PlainTextResponse("Method Not Allowed", status_code=405) - return _StaticFileResponder(scope, path=self.path) - - class StaticFiles: def __init__(self, *, directory: str) -> None: self.directory = directory @@ -39,25 +28,6 @@ class StaticFiles: return _StaticFilesResponder(scope, path=path, check_directory=check_directory) -class _StaticFileResponder: - def __init__(self, scope: Scope, path: str) -> None: - self.scope = scope - self.path = path - - async def __call__(self, receive: Receive, send: Send) -> None: - try: - stat_result = await aio_stat(self.path) - except FileNotFoundError: - raise RuntimeError("StaticFile at path '%s' does not exist." % self.path) - else: - mode = stat_result.st_mode - if not stat.S_ISREG(mode): - raise RuntimeError("StaticFile at path '%s' is not a file." % self.path) - - response = FileResponse(self.path, stat_result=stat_result) - await response(receive, send) - - class _StaticFilesResponder: def __init__(self, scope: Scope, path: str, check_directory: str = None) -> None: self.scope = scope diff --git a/tests/test_responses.py b/tests/test_responses.py index 467b3601..63d50767 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -9,6 +9,7 @@ from starlette.requests import Request from starlette.testclient import TestClient from starlette import status import asyncio +import pytest import os @@ -144,6 +145,28 @@ def test_file_response(tmpdir): assert "etag" in response.headers +def test_file_response_with_directory_raises_error(tmpdir): + def app(scope): + return FileResponse(path=tmpdir, filename="example.png") + + client = TestClient(app) + with pytest.raises(RuntimeError) as exc: + client.get("/") + assert "is not a file" in str(exc) + + +def test_file_response_with_missing_file_raises_error(tmpdir): + path = os.path.join(tmpdir, "404.txt") + + def app(scope): + return FileResponse(path=path, filename="404.txt") + + client = TestClient(app) + with pytest.raises(RuntimeError) as exc: + client.get("/") + assert "does not exist" in str(exc) + + def test_set_cookie(): def app(scope): async def asgi(receive, send): diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py index e21ce602..bc7ef0fb 100644 --- a/tests/test_staticfiles.py +++ b/tests/test_staticfiles.py @@ -2,63 +2,7 @@ import os import pytest from starlette.testclient import TestClient -from starlette.staticfiles import StaticFile, StaticFiles - - -def test_staticfile(tmpdir): - path = os.path.join(tmpdir, "example.txt") - with open(path, "w") as file: - file.write("") - - app = StaticFile(path=path) - client = TestClient(app) - response = client.get("/") - assert response.status_code == 200 - assert response.text == "" - - -def test_large_staticfile(tmpdir): - path = os.path.join(tmpdir, "example.txt") - content = "this is a lot of content" * 200 - print("content len = ", len(content)) - with open(path, "w") as file: - file.write(content) - - app = StaticFile(path=path) - client = TestClient(app) - response = client.get("/") - assert response.status_code == 200 - assert len(content) == len(response.text) - assert content == response.text - - -def test_staticfile_post(tmpdir): - path = os.path.join(tmpdir, "example.txt") - with open(path, "w") as file: - file.write("") - - app = StaticFile(path=path) - client = TestClient(app) - response = client.post("/") - assert response.status_code == 405 - assert response.text == "Method Not Allowed" - - -def test_staticfile_with_directory_raises_error(tmpdir): - app = StaticFile(path=tmpdir) - client = TestClient(app) - with pytest.raises(RuntimeError) as exc: - client.get("/") - assert "is not a file" in str(exc) - - -def test_staticfile_with_missing_file_raises_error(tmpdir): - path = os.path.join(tmpdir, "404.txt") - app = StaticFile(path=path) - client = TestClient(app) - with pytest.raises(RuntimeError) as exc: - client.get("/") - assert "does not exist" in str(exc) +from starlette.staticfiles import StaticFiles def test_staticfiles(tmpdir):