diff --git a/starlette/staticfiles.py b/starlette/staticfiles.py index d5964158..b816fb1b 100644 --- a/starlette/staticfiles.py +++ b/starlette/staticfiles.py @@ -5,46 +5,85 @@ import stat class StaticFile: - def __init__(self, *, path): + def __init__(self, path): self.path = path def __call__(self, scope): if scope['method'] not in ('GET', 'HEAD'): return PlainTextResponse('Method not allowed', status_code=406) - return _StaticFileResponder(scope, path=self.path, allow_404=False) + return _StaticFileResponder(scope, path=self.path) class StaticFiles: def __init__(self, *, directory): self.directory = directory + self.config_checked = False def __call__(self, scope): if scope['method'] not in ('GET', 'HEAD'): return PlainTextResponse('Method not allowed', status_code=406) split_path = scope['path'].split('/') path = os.path.join(self.directory, *split_path) - return _StaticFileResponder(scope, path=path, allow_404=True) + if self.config_checked: + check_directory = None + else: + check_directory = self.directory + self.config_checked = True + return _StaticFilesResponder(scope, path=path, check_directory=check_directory) class _StaticFileResponder: - def __init__(self, scope, path, allow_404): + def __init__(self, scope, path): self.scope = scope self.path = path - self.allow_404 = allow_404 async def __call__(self, receive, send): try: stat_result = await aio_stat(self.path) except FileNotFoundError: - if not self.allow_404: - raise RuntimeError("StaticFile at path '%s' does not exist." % self.path) + raise RuntimeError("StaticFile at path '%s' does not exist." % self.path) + else: + mode = stat_result.st_mode + if not stat.S_ISREG(mode): + raise RuntimeError("StaticFile at path '%s' is not a file." % self.path) + + response = FileResponse(self.path, stat_result=stat_result) + await response(receive, send) + + +class _StaticFilesResponder: + def __init__(self, scope, path, check_directory=None): + self.scope = scope + self.path = path + self.check_directory = check_directory + + async def check_directory_configured_correctly(self): + """ + Perform a one-off configuration check that StaticFiles is actually + pointed at a directory, so that we can raise loud errors rather than + just returning 404 responses. + """ + directory = self.check_directory + try: + stat_result = await aio_stat(directory) + except FileNotFoundError: + raise RuntimeError("StaticFiles directory '%s' does not exist." % directory) + if not stat.S_ISDIR(stat_result.st_mode): + raise RuntimeError("StaticFiles path '%s' is not a directory." % directory) + + async def __call__(self, receive, send): + if self.check_directory is not None: + await self.check_directory_configured_correctly() + + try: + stat_result = await aio_stat(self.path) + except FileNotFoundError: response = PlainTextResponse('Not found', status_code=404) else: mode = stat_result.st_mode - if stat.S_ISREG(mode) or stat.S_ISLNK(mode): - response = FileResponse(self.path, stat_result=stat_result) - else: - if not self.allow_404: - raise RuntimeError("StaticFile at path '%s' is not a file." % self.path) + if not stat.S_ISREG(mode): response = PlainTextResponse('Not found', status_code=404) + else: + response = FileResponse(self.path, stat_result=stat_result) + await response(receive, send) diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py index 5565a432..75e072d4 100644 --- a/tests/test_staticfiles.py +++ b/tests/test_staticfiles.py @@ -91,3 +91,34 @@ def test_staticfiles_with_missing_file_returns_404(tmpdir): response = client.get("/404.txt") assert response.status_code == 404 assert response.text == 'Not found' + + +def test_staticfiles_configured_with_missing_directory(tmpdir): + path = os.path.join(tmpdir, "no_such_directory") + app = StaticFiles(directory=path) + client = TestClient(app) + with pytest.raises(RuntimeError) as exc: + response = client.get("/example.txt") + assert 'does not exist' in str(exc) + + +def test_staticfiles_configured_with_file_instead_of_directory(tmpdir): + path = os.path.join(tmpdir, "example.txt") + with open(path, "w") as file: + file.write("") + + app = StaticFiles(directory=path) + client = TestClient(app) + with pytest.raises(RuntimeError) as exc: + response = client.get("/example.txt") + assert 'is not a directory' in str(exc) + + +def test_staticfiles_config_check_occurs_only_once(tmpdir): + app = StaticFiles(directory=tmpdir) + client = TestClient(app) + assert not app.config_checked + response = client.get("/") + assert app.config_checked + response = client.get("/") + assert app.config_checked