From 32970172f654b76f8d200683558cd15d13c16c67 Mon Sep 17 00:00:00 2001 From: Martin Chang Date: Tue, 18 May 2021 19:20:15 +0800 Subject: [PATCH] Make AsyncTask only destruct when the coroutine reaches end of execution (#857) --- examples/simple_example/api_v1_CoroTest.cc | 5 ++ .../simple_example_test/WebSocketCoroTest.cc | 4 +- lib/inc/drogon/utils/coroutine.h | 65 ++++++++++--------- 3 files changed, 42 insertions(+), 32 deletions(-) diff --git a/examples/simple_example/api_v1_CoroTest.cc b/examples/simple_example/api_v1_CoroTest.cc index 05d3dffb..df05503b 100644 --- a/examples/simple_example/api_v1_CoroTest.cc +++ b/examples/simple_example/api_v1_CoroTest.cc @@ -4,6 +4,11 @@ using namespace api::v1; Task<> CoroTest::get(HttpRequestPtr req, std::function callback) { + // Force co_await to test awaiting works + co_await drogon::sleepCoro( + trantor::EventLoop::getEventLoopOfCurrentThread(), + std::chrono::milliseconds(100)); + auto resp = HttpResponse::newHttpResponse(); resp->setBody("DEADBEEF"); callback(resp); diff --git a/examples/simple_example_test/WebSocketCoroTest.cc b/examples/simple_example_test/WebSocketCoroTest.cc index d9e2bca2..132e4dca 100644 --- a/examples/simple_example_test/WebSocketCoroTest.cc +++ b/examples/simple_example_test/WebSocketCoroTest.cc @@ -74,9 +74,7 @@ int main(int argc, char* argv[]) }); app().setLogLevel(trantor::Logger::kTrace); - auto test = [=]() -> AsyncTask { - co_await doTest(wsPtr, req, continually); - }(); + [=]() -> AsyncTask { co_await doTest(wsPtr, req, continually); }(); app().run(); } diff --git a/lib/inc/drogon/utils/coroutine.h b/lib/inc/drogon/utils/coroutine.h index f752bd2f..7833889e 100644 --- a/lib/inc/drogon/utils/coroutine.h +++ b/lib/inc/drogon/utils/coroutine.h @@ -392,23 +392,23 @@ struct AsyncTask struct promise_type; using handle_type = std::coroutine_handle; + AsyncTask() = default; + AsyncTask(handle_type h) : coro_(h) { + if (coro_) + coro_.promise().setSelf(coro_); } AsyncTask(const AsyncTask &) = delete; ~AsyncTask() { - if (coro_) - coro_.destroy(); } AsyncTask &operator=(const AsyncTask &) = delete; AsyncTask &operator=(AsyncTask &&other) { if (std::addressof(other) == this) return *this; - if (coro_) - coro_.destroy(); coro_ = other.coro_; other.coro_ = nullptr; @@ -418,6 +418,7 @@ struct AsyncTask struct promise_type { std::coroutine_handle<> continuation_; + handle_type self_; AsyncTask get_return_object() noexcept { @@ -444,10 +445,28 @@ struct AsyncTask continuation_ = handle; } + void setSelf(handle_type handle) + { + self_ = handle; + } + auto final_suspend() const noexcept { - struct awaiter + struct awaiter final { + awaiter(handle_type h) : self_(h) + { + } + + awaiter(const awaiter &) = delete; + awaiter &operator=(const awaiter &) = delete; + + ~awaiter() + { + if (self_) + self_.destroy(); + } + bool await_ready() const noexcept { return false; @@ -466,8 +485,11 @@ struct AsyncTask return std::noop_coroutine(); } + + handle_type self_; }; - return awaiter{}; + + return awaiter(self_); } }; bool await_ready() const noexcept @@ -479,7 +501,7 @@ struct AsyncTask { } - void await_suspend(std::coroutine_handle<> coroutine) const noexcept + void await_suspend(std::coroutine_handle<> coroutine) noexcept { coro_.promise().setContinuation(coroutine); } @@ -585,13 +607,9 @@ auto sync_wait(Await &&await) cv.notify_all(); }; - // HACK: Workarround coroutine frame destructing too early by enforcing - // manual lifetime - AsyncTask *taskPtr; - std::thread thr([&]() { taskPtr = new AsyncTask{task()}; }); + std::thread thr([&]() { task(); }); cv.wait(lk, [&]() { return (bool)flag; }); thr.join(); - delete taskPtr; if (exception_ptr) std::rethrow_exception(exception_ptr); } @@ -612,14 +630,11 @@ auto sync_wait(Await &&await) cv.notify_all(); }; - // HACK: Workarround coroutine frame destructing too early by enforcing - // manual lifetime - AsyncTask *taskPtr; - std::thread thr([&]() { taskPtr = new AsyncTask{task()}; }); + std::thread thr([&]() { task(); }); cv.wait(lk, [&]() { return (bool)flag; }); assert(value.has_value() == true || exception_ptr); thr.join(); - delete taskPtr; + if (exception_ptr) std::rethrow_exception(exception_ptr); return value.value(); @@ -634,11 +649,9 @@ inline auto co_future(Await await) noexcept using Result = await_result_t; std::promise prom; auto fut = prom.get_future(); - std::promise selfProm; - auto selfFut = selfProm.get_future(); - auto task = [](std::promise prom, - Await await, - std::future selfFut) -> AsyncTask { + [](std::promise prom, + Await await, + std::future selfFut) mutable -> AsyncTask { try { if constexpr (std::is_void_v) @@ -653,13 +666,7 @@ inline auto co_future(Await await) noexcept { prom.set_exception(std::current_exception()); } - - AsyncTask *self = selfFut.get(); - delete self; - }; - AsyncTask *taskPtr = new AsyncTask{ - task(std::move(prom), std::move(await), std::move(selfFut))}; - selfProm.set_value(taskPtr); + }(); return fut; }