2018-09-05 10:39:38 +00:00
|
|
|
from starlette.applications import Starlette
|
2018-10-05 11:04:11 +00:00
|
|
|
from starlette.datastructures import Headers
|
2018-09-04 10:52:29 +00:00
|
|
|
from starlette.exceptions import HTTPException
|
2018-09-05 09:29:04 +00:00
|
|
|
from starlette.responses import JSONResponse, PlainTextResponse
|
2018-08-28 13:34:18 +00:00
|
|
|
from starlette.staticfiles import StaticFiles
|
|
|
|
from starlette.testclient import TestClient
|
2018-09-05 09:29:04 +00:00
|
|
|
from starlette.endpoints import HTTPEndpoint
|
2018-08-28 13:34:18 +00:00
|
|
|
import os
|
|
|
|
|
|
|
|
|
2018-10-05 11:04:11 +00:00
|
|
|
class TrustedHostMiddleware:
    """Minimal ASGI middleware that only admits requests for one host name.

    Requests whose ``Host`` header matches neither ``hostname`` nor, once any
    ``:port`` suffix is removed, ``hostname`` receive a 400 plain-text
    response instead of reaching the wrapped application.
    """

    def __init__(self, app, hostname):
        # app: the wrapped ASGI application.
        # hostname: the only host name accepted, e.g. "testserver".
        self.app = app
        self.hostname = hostname

    @staticmethod
    def _host_without_port(host):
        """Return *host* with any trailing ``:port`` removed.

        Bracketed IPv6 literals keep their brackets, so ``"[::1]:8000"``
        becomes ``"[::1]"``. A value without a port is returned unchanged.
        """
        if host.startswith("["):
            closing = host.find("]")
            return host if closing == -1 else host[: closing + 1]
        return host.partition(":")[0]

    def __call__(self, scope):
        """Return a 400 PlainTextResponse, or the result of ``self.app(scope)``."""
        headers = Headers(scope=scope)
        host = headers.get("host")
        # A Host header may legally carry a port (RFC 7230 section 5.4), so
        # compare both the raw value and its port-less form. Checking the raw
        # value too keeps a hostname configured as "host:port" working.
        # A missing Host header is always rejected.
        if host is None or self.hostname not in (host, self._host_without_port(host)):
            return PlainTextResponse("Invalid host header", status_code=400)
        return self.app(scope)
|
|
|
|
|
|
|
|
|
2018-09-05 09:29:04 +00:00
|
|
|
# Shared application instance: the exception handlers and routes in this
# module are registered on it via the @app decorators.
app = Starlette()
|
2018-08-28 13:34:18 +00:00
|
|
|
|
|
|
|
|
2018-10-05 11:04:11 +00:00
|
|
|
# Wrap the app so that a request whose Host header differs from "testserver"
# gets a 400 "Invalid host header" response.
# NOTE(review): relies on "testserver" being TestClient's default host — confirm
# against the TestClient version in use.
app.add_middleware(TrustedHostMiddleware, hostname="testserver")
|
|
|
|
|
|
|
|
|
2018-09-04 10:52:29 +00:00
|
|
|
@app.exception_handler(Exception)
async def error_500(request, exc):
    """Render any otherwise-unhandled exception as a generic JSON 500 body."""
    body = {"detail": "Server Error"}
    return JSONResponse(body, status_code=500)
|
|
|
|
|
|
|
|
|
|
|
|
@app.exception_handler(HTTPException)
async def handler(request, exc):
    """Render an HTTPException as JSON, echoing its detail and status code."""
    status = exc.status_code
    payload = {"detail": exc.detail}
    return JSONResponse(payload, status_code=status)
|
|
|
|
|
|
|
|
|
2018-08-28 13:34:18 +00:00
|
|
|
@app.route("/func")
def func_homepage(request):
    """Plain synchronous function endpoint returning a fixed greeting."""
    greeting = "Hello, world!"
    return PlainTextResponse(greeting)
|
|
|
|
|
|
|
|
|
|
|
|
@app.route("/async")
async def async_homepage(request):
    """Coroutine endpoint returning the same fixed greeting as /func."""
    greeting = "Hello, world!"
    return PlainTextResponse(greeting)
|
|
|
|
|
|
|
|
|
2018-09-04 10:52:29 +00:00
|
|
|
@app.route("/class")
class Homepage(HTTPEndpoint):
    """Class-based endpoint that implements only the GET method."""

    def get(self, request):
        greeting = "Hello, world!"
        return PlainTextResponse(greeting)
|
|
|
|
|
|
|
|
|
2018-08-28 13:34:18 +00:00
|
|
|
@app.route("/user/{username}")
def user_page(request, username):
    """Greet the user named by the ``{username}`` path parameter."""
    return PlainTextResponse(f"Hello, {username}!")
|
|
|
|
|
|
|
|
|
2018-09-04 10:52:29 +00:00
|
|
|
@app.route("/500")
def runtime_error(request):
    """Always fail, so that the Exception handler's 500 response is exercised."""
    raise RuntimeError
|
|
|
|
|
|
|
|
|
2018-08-28 13:34:18 +00:00
|
|
|
@app.websocket_route("/ws")
async def websocket_endpoint(session):
    """Accept the connection, send a single text greeting, then close it."""
    greeting = "Hello, world!"
    await session.accept()
    await session.send_text(greeting)
    await session.close()
|
|
|
|
|
|
|
|
|
|
|
|
# Module-wide test client bound to the shared application.
client = TestClient(app=app)
|
|
|
|
|
|
|
|
|
|
|
|
def test_func_route():
    """GET /func is served by the plain function endpoint."""
    resp = client.get("/func")
    assert (resp.status_code, resp.text) == (200, "Hello, world!")
|
|
|
|
|
|
|
|
|
|
|
|
def test_async_route():
    """GET /async is served by the coroutine endpoint."""
    resp = client.get("/async")
    assert (resp.status_code, resp.text) == (200, "Hello, world!")
|
|
|
|
|
|
|
|
|
2018-09-04 10:52:29 +00:00
|
|
|
def test_class_route():
    """GET /class dispatches to the HTTPEndpoint's ``get`` method."""
    resp = client.get("/class")
    assert (resp.status_code, resp.text) == (200, "Hello, world!")
|
|
|
|
|
|
|
|
|
2018-08-28 13:34:18 +00:00
|
|
|
def test_route_kwargs():
    """The ``{username}`` path segment is passed through to the endpoint."""
    name = "tomchristie"
    resp = client.get("/user/" + name)
    assert (resp.status_code, resp.text) == (200, "Hello, tomchristie!")
|
|
|
|
|
|
|
|
|
|
|
|
def test_websocket_route():
    """The /ws endpoint sends exactly one greeting over the socket."""
    with client.websocket_connect("/ws") as ws:
        assert ws.receive_text() == "Hello, world!"
|
|
|
|
|
|
|
|
|
|
|
|
def test_400():
    """An unknown path yields a JSON 404 from the HTTPException handler.

    NOTE: despite its name, this test covers the 404 Not Found case.
    """
    missing = client.get("/404")
    assert missing.status_code == 404
    assert missing.json() == {"detail": "Not Found"}
|
|
|
|
|
|
|
|
|
|
|
|
def test_405():
    """POST to GET-only function and class endpoints yields a JSON 405."""
    for path in ("/func", "/class"):
        response = client.post(path)
        assert response.status_code == 405
        assert response.json() == {"detail": "Method Not Allowed"}
|
|
|
|
|
|
|
|
|
|
|
|
def test_500():
    """An unhandled error is rendered by the generic handler as a JSON 500."""
    # Separate client configured with raise_server_exceptions=False so the
    # error response is returned to the test rather than raised into it.
    lenient_client = TestClient(app, raise_server_exceptions=False)
    response = lenient_client.get("/500")
    assert response.status_code == 500
    assert response.json() == {"detail": "Server Error"}
|
2018-08-28 13:34:18 +00:00
|
|
|
|
|
|
|
|
2018-10-05 11:04:11 +00:00
|
|
|
def test_middleware():
    """A request carrying a Host other than "testserver" is rejected with 400."""
    wrong_host_client = TestClient(app, base_url="http://incorrecthost")
    rejected = wrong_host_client.get("/func")
    assert rejected.status_code == 400
    assert rejected.text == "Invalid host header"
|
|
|
|
|
|
|
|
|
2018-08-28 13:34:18 +00:00
|
|
|
def test_app_mount(tmpdir):
    """A StaticFiles app mounted at /static serves GET and refuses POST.

    ``tmpdir`` is pytest's per-test temporary directory fixture; the mounted
    StaticFiles app serves files from it.
    """
    path = os.path.join(tmpdir, "example.txt")
    # Explicit encoding: the default for open() depends on the platform
    # locale, so spell it out to make the fixture file deterministic.
    with open(path, "w", encoding="utf-8") as file:
        file.write("<file content>")

    # Local names distinct from the module-level `app`/`client` to avoid
    # shadowing them.
    static_app = Starlette()
    static_app.mount("/static", StaticFiles(directory=tmpdir), methods=["GET", "HEAD"])
    static_client = TestClient(static_app)

    response = static_client.get("/static/example.txt")
    assert response.status_code == 200
    assert response.text == "<file content>"

    # POST is not among the mount's allowed methods.
    response = static_client.post("/static/example.txt")
    assert response.status_code == 405
    assert response.text == "Method Not Allowed"
|
|
|
|
|
|
|
|
|
|
|
|
def test_app_debug():
    """With debug enabled, an unhandled error gives a 500 naming the exception."""
    debug_app = Starlette()
    debug_app.debug = True

    @debug_app.route("/")
    async def failing_homepage(request):
        raise RuntimeError()

    debug_client = TestClient(debug_app, raise_server_exceptions=False)
    response = debug_client.get("/")
    assert response.status_code == 500
    assert "RuntimeError" in response.text
    assert debug_app.debug
|