Configuration checks for StaticFiles

This commit is contained in:
Tom Christie 2018-07-12 13:41:46 +01:00
parent 2ed6768352
commit 6259df6e7a
2 changed files with 82 additions and 12 deletions

View File

@ -5,46 +5,85 @@ import stat
class StaticFile:
def __init__(self, *, path):
def __init__(self, path):
self.path = path
def __call__(self, scope):
if scope['method'] not in ('GET', 'HEAD'):
return PlainTextResponse('Method not allowed', status_code=406)
return _StaticFileResponder(scope, path=self.path, allow_404=False)
return _StaticFileResponder(scope, path=self.path)
class StaticFiles:
def __init__(self, *, directory):
self.directory = directory
self.config_checked = False
def __call__(self, scope):
if scope['method'] not in ('GET', 'HEAD'):
return PlainTextResponse('Method not allowed', status_code=406)
split_path = scope['path'].split('/')
path = os.path.join(self.directory, *split_path)
return _StaticFileResponder(scope, path=path, allow_404=True)
if self.config_checked:
check_directory = None
else:
check_directory = self.directory
self.config_checked = True
return _StaticFilesResponder(scope, path=path, check_directory=check_directory)
class _StaticFileResponder:
def __init__(self, scope, path, allow_404):
def __init__(self, scope, path):
self.scope = scope
self.path = path
self.allow_404 = allow_404
async def __call__(self, receive, send):
try:
stat_result = await aio_stat(self.path)
except FileNotFoundError:
if not self.allow_404:
raise RuntimeError("StaticFile at path '%s' does not exist." % self.path)
raise RuntimeError("StaticFile at path '%s' does not exist." % self.path)
else:
mode = stat_result.st_mode
if not stat.S_ISREG(mode):
raise RuntimeError("StaticFile at path '%s' is not a file." % self.path)
response = FileResponse(self.path, stat_result=stat_result)
await response(receive, send)
class _StaticFilesResponder:
def __init__(self, scope, path, check_directory=None):
self.scope = scope
self.path = path
self.check_directory = check_directory
async def check_directory_configured_correctly(self):
"""
Perform a one-off configuration check that StaticFiles is actually
pointed at a directory, so that we can raise loud errors rather than
just returning 404 responses.
"""
directory = self.check_directory
try:
stat_result = await aio_stat(directory)
except FileNotFoundError:
raise RuntimeError("StaticFiles directory '%s' does not exist." % directory)
if not stat.S_ISDIR(stat_result.st_mode):
raise RuntimeError("StaticFiles path '%s' is not a directory." % directory)
async def __call__(self, receive, send):
if self.check_directory is not None:
await self.check_directory_configured_correctly()
try:
stat_result = await aio_stat(self.path)
except FileNotFoundError:
response = PlainTextResponse('Not found', status_code=404)
else:
mode = stat_result.st_mode
if stat.S_ISREG(mode) or stat.S_ISLNK(mode):
response = FileResponse(self.path, stat_result=stat_result)
else:
if not self.allow_404:
raise RuntimeError("StaticFile at path '%s' is not a file." % self.path)
if not stat.S_ISREG(mode):
response = PlainTextResponse('Not found', status_code=404)
else:
response = FileResponse(self.path, stat_result=stat_result)
await response(receive, send)

View File

@ -91,3 +91,34 @@ def test_staticfiles_with_missing_file_returns_404(tmpdir):
response = client.get("/404.txt")
assert response.status_code == 404
assert response.text == 'Not found'
def test_staticfiles_configured_with_missing_directory(tmpdir):
path = os.path.join(tmpdir, "no_such_directory")
app = StaticFiles(directory=path)
client = TestClient(app)
with pytest.raises(RuntimeError) as exc:
response = client.get("/example.txt")
assert 'does not exist' in str(exc)
def test_staticfiles_configured_with_file_instead_of_directory(tmpdir):
path = os.path.join(tmpdir, "example.txt")
with open(path, "w") as file:
file.write("<file content>")
app = StaticFiles(directory=path)
client = TestClient(app)
with pytest.raises(RuntimeError) as exc:
response = client.get("/example.txt")
assert 'is not a directory' in str(exc)
def test_staticfiles_config_check_occurs_only_once(tmpdir):
app = StaticFiles(directory=tmpdir)
client = TestClient(app)
assert not app.config_checked
response = client.get("/")
assert app.config_checked
response = client.get("/")
assert app.config_checked