diff --git a/docs/release-notes.md b/docs/release-notes.md index 72e657a5..05aa13cf 100644 --- a/docs/release-notes.md +++ b/docs/release-notes.md @@ -1,3 +1,7 @@ +## 0.9.7 + +* Ensure that `AuthenticationMiddleware` handles lifespan messages correctly. + ## 0.9.6 * Add `AuthenticationMiddleware`, and `@requires()` decorator. diff --git a/starlette/__init__.py b/starlette/__init__.py index 50533e30..f5b77301 100644 --- a/starlette/__init__.py +++ b/starlette/__init__.py @@ -1 +1 @@ -__version__ = "0.9.6" +__version__ = "0.9.7" diff --git a/starlette/middleware/authentication.py b/starlette/middleware/authentication.py index 58ea93cd..b9d90e8b 100644 --- a/starlette/middleware/authentication.py +++ b/starlette/middleware/authentication.py @@ -17,7 +17,9 @@ class AuthenticationMiddleware: self.backend = backend def __call__(self, scope: Scope) -> ASGIInstance: - return functools.partial(self.asgi, scope=scope) + if scope["type"] in ["http", "websockets"]: + return functools.partial(self.asgi, scope=scope) + return self.app(scope) async def asgi(self, receive: Receive, send: Send, scope: Scope) -> None: request = Request(scope, receive=receive) diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 39d70e49..9341c71f 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -89,52 +89,52 @@ def admin(request): ) -client = TestClient(app) - - def test_user_interface(): - response = client.get("/") - assert response.status_code == 200 - assert response.json() == {"authenticated": False, "user": ""} + with TestClient(app) as client: + response = client.get("/") + assert response.status_code == 200 + assert response.json() == {"authenticated": False, "user": ""} - response = client.get("/", auth=("tomchristie", "example")) - assert response.status_code == 200 - assert response.json() == {"authenticated": True, "user": "tomchristie"} + response = client.get("/", auth=("tomchristie", "example")) + assert response.status_code == 200 + assert response.json() == {"authenticated": True, "user": "tomchristie"} def test_authentication_required(): - response = client.get("/dashboard") - assert response.status_code == 403 + with TestClient(app) as client: + response = client.get("/dashboard") + assert response.status_code == 403 - response = client.get("/dashboard", auth=("tomchristie", "example")) - assert response.status_code == 200 - assert response.json() == {"authenticated": True, "user": "tomchristie"} + response = client.get("/dashboard", auth=("tomchristie", "example")) + assert response.status_code == 200 + assert response.json() == {"authenticated": True, "user": "tomchristie"} - response = client.get("/dashboard/sync") - assert response.status_code == 403 + response = client.get("/dashboard/sync") + assert response.status_code == 403 - response = client.get("/dashboard/sync", auth=("tomchristie", "example")) - assert response.status_code == 200 - assert response.json() == {"authenticated": True, "user": "tomchristie"} + response = client.get("/dashboard/sync", auth=("tomchristie", "example")) + assert response.status_code == 200 + assert response.json() == {"authenticated": True, "user": "tomchristie"} - response = client.get("/dashboard", headers={"Authorization": "basic foobar"}) - assert response.status_code == 400 - assert response.text == "Invalid basic auth credentials" + response = client.get("/dashboard", headers={"Authorization": "basic foobar"}) + assert response.status_code == 400 + assert response.text == "Invalid basic auth credentials" def test_authentication_redirect(): - response = client.get("/admin") - assert response.status_code == 200 - assert response.url == "http://testserver/" + with TestClient(app) as client: + response = client.get("/admin") + assert response.status_code == 200 + assert response.url == "http://testserver/" - response = client.get("/admin", auth=("tomchristie", "example")) - assert response.status_code == 200 - assert response.json() == {"authenticated": True, "user": "tomchristie"} + response = client.get("/admin", auth=("tomchristie", "example")) + assert response.status_code == 200 + assert response.json() == {"authenticated": True, "user": "tomchristie"} - response = client.get("/admin/sync") - assert response.status_code == 200 - assert response.url == "http://testserver/" + response = client.get("/admin/sync") + assert response.status_code == 200 + assert response.url == "http://testserver/" - response = client.get("/admin/sync", auth=("tomchristie", "example")) - assert response.status_code == 200 - assert response.json() == {"authenticated": True, "user": "tomchristie"} + response = client.get("/admin/sync", auth=("tomchristie", "example")) + assert response.status_code == 200 + assert response.json() == {"authenticated": True, "user": "tomchristie"}