From ed16b7df21174892095a0665c1e816a2a2ed306e Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 23 Nov 2018 11:11:16 +0000 Subject: [PATCH] Asyncio cleanups (#236) --- starlette/middleware/wsgi.py | 21 +++++++++++++-------- starlette/testclient.py | 6 +++++- tests/test_testclient.py | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 50 insertions(+), 9 deletions(-) create mode 100644 tests/test_testclient.py diff --git a/starlette/middleware/wsgi.py b/starlette/middleware/wsgi.py index b8006dac..eedb144b 100644 --- a/starlette/middleware/wsgi.py +++ b/starlette/middleware/wsgi.py @@ -81,14 +81,19 @@ class WSGIResponder: body += message.get("body", b"") more_body = message.get("more_body", False) environ = build_environ(self.scope, body) - wsgi = run_in_threadpool(self.wsgi, environ, self.start_response) - sender = self.loop.create_task(self.sender(send)) - await asyncio.wait_for(wsgi, None) - self.send_queue.append(None) - self.send_event.set() - await asyncio.wait_for(sender, None) - if self.exc_info is not None: - raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2]) + try: + sender = self.loop.create_task(self.sender(send)) + await run_in_threadpool(self.wsgi, environ, self.start_response) + self.send_queue.append(None) + self.send_event.set() + await asyncio.wait_for(sender, None) + if self.exc_info is not None: + raise self.exc_info[0].with_traceback( + self.exc_info[1], self.exc_info[2] + ) + finally: + if not sender.done(): + sender.cancel() # pragma: no cover async def sender(self, send: Send) -> None: while True: diff --git a/starlette/testclient.py b/starlette/testclient.py index ee575407..faac9839 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -181,7 +181,11 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter): response_complete = False raw_kwargs = {"body": io.BytesIO()} # type: typing.Dict[str, typing.Any] - loop = asyncio.get_event_loop() + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) try: connection = self.app(scope) diff --git a/tests/test_testclient.py b/tests/test_testclient.py new file mode 100644 index 00000000..e6a81988 --- /dev/null +++ b/tests/test_testclient.py @@ -0,0 +1,32 @@ +from starlette.applications import Starlette +from starlette.responses import JSONResponse +from starlette.testclient import TestClient + +mock_service = Starlette() + + +@mock_service.route("/") +def mock_service_endpoint(request): + return JSONResponse({"mock": "example"}) + + +app = Starlette() + + +@app.route("/") +def homepage(request): + client = TestClient(mock_service) + response = client.get("/") + return JSONResponse(response.json()) + + +def test_use_testclient_in_endpoint(): + """ + We should be able to use the test client within applications. + + This is useful if we need to mock out other services, + during tests or in development. + """ + client = TestClient(app) + response = client.get("/") + assert response.json() == {"mock": "example"}