mirror of https://github.com/encode/starlette.git
Configuration checks for StaticFiles
This commit is contained in:
parent
2ed6768352
commit
6259df6e7a
|
@ -5,46 +5,85 @@ import stat
|
|||
|
||||
|
||||
class StaticFile:
|
||||
def __init__(self, *, path):
|
||||
def __init__(self, path):
|
||||
self.path = path
|
||||
|
||||
def __call__(self, scope):
|
||||
if scope['method'] not in ('GET', 'HEAD'):
|
||||
return PlainTextResponse('Method not allowed', status_code=406)
|
||||
return _StaticFileResponder(scope, path=self.path, allow_404=False)
|
||||
return _StaticFileResponder(scope, path=self.path)
|
||||
|
||||
|
||||
class StaticFiles:
|
||||
def __init__(self, *, directory):
|
||||
self.directory = directory
|
||||
self.config_checked = False
|
||||
|
||||
def __call__(self, scope):
|
||||
if scope['method'] not in ('GET', 'HEAD'):
|
||||
return PlainTextResponse('Method not allowed', status_code=406)
|
||||
split_path = scope['path'].split('/')
|
||||
path = os.path.join(self.directory, *split_path)
|
||||
return _StaticFileResponder(scope, path=path, allow_404=True)
|
||||
if self.config_checked:
|
||||
check_directory = None
|
||||
else:
|
||||
check_directory = self.directory
|
||||
self.config_checked = True
|
||||
return _StaticFilesResponder(scope, path=path, check_directory=check_directory)
|
||||
|
||||
|
||||
class _StaticFileResponder:
|
||||
def __init__(self, scope, path, allow_404):
|
||||
def __init__(self, scope, path):
|
||||
self.scope = scope
|
||||
self.path = path
|
||||
self.allow_404 = allow_404
|
||||
|
||||
async def __call__(self, receive, send):
|
||||
try:
|
||||
stat_result = await aio_stat(self.path)
|
||||
except FileNotFoundError:
|
||||
if not self.allow_404:
|
||||
raise RuntimeError("StaticFile at path '%s' does not exist." % self.path)
|
||||
raise RuntimeError("StaticFile at path '%s' does not exist." % self.path)
|
||||
else:
|
||||
mode = stat_result.st_mode
|
||||
if not stat.S_ISREG(mode):
|
||||
raise RuntimeError("StaticFile at path '%s' is not a file." % self.path)
|
||||
|
||||
response = FileResponse(self.path, stat_result=stat_result)
|
||||
await response(receive, send)
|
||||
|
||||
|
||||
class _StaticFilesResponder:
|
||||
def __init__(self, scope, path, check_directory=None):
|
||||
self.scope = scope
|
||||
self.path = path
|
||||
self.check_directory = check_directory
|
||||
|
||||
async def check_directory_configured_correctly(self):
|
||||
"""
|
||||
Perform a one-off configuration check that StaticFiles is actually
|
||||
pointed at a directory, so that we can raise loud errors rather than
|
||||
just returning 404 responses.
|
||||
"""
|
||||
directory = self.check_directory
|
||||
try:
|
||||
stat_result = await aio_stat(directory)
|
||||
except FileNotFoundError:
|
||||
raise RuntimeError("StaticFiles directory '%s' does not exist." % directory)
|
||||
if not stat.S_ISDIR(stat_result.st_mode):
|
||||
raise RuntimeError("StaticFiles path '%s' is not a directory." % directory)
|
||||
|
||||
async def __call__(self, receive, send):
|
||||
if self.check_directory is not None:
|
||||
await self.check_directory_configured_correctly()
|
||||
|
||||
try:
|
||||
stat_result = await aio_stat(self.path)
|
||||
except FileNotFoundError:
|
||||
response = PlainTextResponse('Not found', status_code=404)
|
||||
else:
|
||||
mode = stat_result.st_mode
|
||||
if stat.S_ISREG(mode) or stat.S_ISLNK(mode):
|
||||
response = FileResponse(self.path, stat_result=stat_result)
|
||||
else:
|
||||
if not self.allow_404:
|
||||
raise RuntimeError("StaticFile at path '%s' is not a file." % self.path)
|
||||
if not stat.S_ISREG(mode):
|
||||
response = PlainTextResponse('Not found', status_code=404)
|
||||
else:
|
||||
response = FileResponse(self.path, stat_result=stat_result)
|
||||
|
||||
await response(receive, send)
|
||||
|
|
|
@ -91,3 +91,34 @@ def test_staticfiles_with_missing_file_returns_404(tmpdir):
|
|||
response = client.get("/404.txt")
|
||||
assert response.status_code == 404
|
||||
assert response.text == 'Not found'
|
||||
|
||||
|
||||
def test_staticfiles_configured_with_missing_directory(tmpdir):
|
||||
path = os.path.join(tmpdir, "no_such_directory")
|
||||
app = StaticFiles(directory=path)
|
||||
client = TestClient(app)
|
||||
with pytest.raises(RuntimeError) as exc:
|
||||
response = client.get("/example.txt")
|
||||
assert 'does not exist' in str(exc)
|
||||
|
||||
|
||||
def test_staticfiles_configured_with_file_instead_of_directory(tmpdir):
|
||||
path = os.path.join(tmpdir, "example.txt")
|
||||
with open(path, "w") as file:
|
||||
file.write("<file content>")
|
||||
|
||||
app = StaticFiles(directory=path)
|
||||
client = TestClient(app)
|
||||
with pytest.raises(RuntimeError) as exc:
|
||||
response = client.get("/example.txt")
|
||||
assert 'is not a directory' in str(exc)
|
||||
|
||||
|
||||
def test_staticfiles_config_check_occurs_only_once(tmpdir):
|
||||
app = StaticFiles(directory=tmpdir)
|
||||
client = TestClient(app)
|
||||
assert not app.config_checked
|
||||
response = client.get("/")
|
||||
assert app.config_checked
|
||||
response = client.get("/")
|
||||
assert app.config_checked
|
||||
|
|
Loading…
Reference in New Issue