Add coroutine mutex (#2095)
This commit is contained in:
parent
0546032edc
commit
c46f149c2c
|
@ -811,4 +811,180 @@ inline internal::EventLoopAwaiter<T> queueInLoopCoro(trantor::EventLoop *loop,
|
|||
return internal::EventLoopAwaiter<T>(std::move(task), loop);
|
||||
}
|
||||
|
||||
class Mutex final
|
||||
{
|
||||
class ScopedCoroMutexAwaiter;
|
||||
class CoroMutexAwaiter;
|
||||
|
||||
public:
|
||||
Mutex() noexcept : state_(unlockedValue()), waiters_(nullptr)
|
||||
{
|
||||
}
|
||||
|
||||
Mutex(const Mutex &) = delete;
|
||||
Mutex(Mutex &&) = delete;
|
||||
Mutex &operator=(const Mutex &) = delete;
|
||||
Mutex &operator=(Mutex &&) = delete;
|
||||
|
||||
~Mutex()
|
||||
{
|
||||
[[maybe_unused]] auto state = state_.load(std::memory_order_relaxed);
|
||||
assert(state == unlockedValue() || state == nullptr);
|
||||
assert(waiters_ == nullptr);
|
||||
}
|
||||
|
||||
bool try_lock() noexcept
|
||||
{
|
||||
void *oldValue = unlockedValue();
|
||||
return state_.compare_exchange_strong(oldValue,
|
||||
nullptr,
|
||||
std::memory_order_acquire,
|
||||
std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
[[nodiscard]] ScopedCoroMutexAwaiter scoped_lock(
|
||||
trantor::EventLoop *loop =
|
||||
trantor::EventLoop::getEventLoopOfCurrentThread()) noexcept
|
||||
{
|
||||
return ScopedCoroMutexAwaiter(*this, loop);
|
||||
}
|
||||
|
||||
[[nodiscard]] CoroMutexAwaiter lock(
|
||||
trantor::EventLoop *loop =
|
||||
trantor::EventLoop::getEventLoopOfCurrentThread()) noexcept
|
||||
{
|
||||
return CoroMutexAwaiter(*this, loop);
|
||||
}
|
||||
|
||||
void unlock() noexcept
|
||||
{
|
||||
assert(state_.load(std::memory_order_relaxed) != unlockedValue());
|
||||
auto *waitersHead = waiters_;
|
||||
if (waitersHead == nullptr)
|
||||
{
|
||||
void *currentState = state_.load(std::memory_order_relaxed);
|
||||
if (currentState == nullptr)
|
||||
{
|
||||
const bool releasedLock =
|
||||
state_.compare_exchange_strong(currentState,
|
||||
unlockedValue(),
|
||||
std::memory_order_release,
|
||||
std::memory_order_relaxed);
|
||||
if (releasedLock)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
currentState = state_.exchange(nullptr, std::memory_order_acquire);
|
||||
assert(currentState != unlockedValue());
|
||||
assert(currentState != nullptr);
|
||||
auto *waiter = static_cast<CoroMutexAwaiter *>(currentState);
|
||||
do
|
||||
{
|
||||
auto *temp = waiter->next_;
|
||||
waiter->next_ = waitersHead;
|
||||
waitersHead = waiter;
|
||||
waiter = temp;
|
||||
} while (waiter != nullptr);
|
||||
}
|
||||
assert(waitersHead != nullptr);
|
||||
waiters_ = waitersHead->next_;
|
||||
if (waitersHead->loop_)
|
||||
{
|
||||
auto handle = waitersHead->handle_;
|
||||
waitersHead->loop_->runInLoop([handle] { handle.resume(); });
|
||||
}
|
||||
else
|
||||
{
|
||||
waitersHead->handle_.resume();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
class CoroMutexAwaiter
|
||||
{
|
||||
public:
|
||||
CoroMutexAwaiter(Mutex &mutex, trantor::EventLoop *loop) noexcept
|
||||
: mutex_(mutex), loop_(loop)
|
||||
{
|
||||
}
|
||||
|
||||
bool await_ready() noexcept
|
||||
{
|
||||
return mutex_.try_lock();
|
||||
}
|
||||
|
||||
bool await_suspend(std::coroutine_handle<> handle) noexcept
|
||||
{
|
||||
handle_ = handle;
|
||||
return mutex_.asynclockImpl(this);
|
||||
}
|
||||
|
||||
void await_resume() noexcept
|
||||
{
|
||||
}
|
||||
|
||||
private:
|
||||
friend class Mutex;
|
||||
|
||||
Mutex &mutex_;
|
||||
trantor::EventLoop *loop_;
|
||||
std::coroutine_handle<> handle_;
|
||||
CoroMutexAwaiter *next_;
|
||||
};
|
||||
|
||||
class ScopedCoroMutexAwaiter : public CoroMutexAwaiter
|
||||
{
|
||||
public:
|
||||
ScopedCoroMutexAwaiter(Mutex &mutex, trantor::EventLoop *loop)
|
||||
: CoroMutexAwaiter(mutex, loop)
|
||||
{
|
||||
}
|
||||
|
||||
[[nodiscard]] auto await_resume() noexcept
|
||||
{
|
||||
return std::unique_lock<Mutex>{mutex_, std::adopt_lock};
|
||||
}
|
||||
};
|
||||
|
||||
bool asynclockImpl(CoroMutexAwaiter *awaiter)
|
||||
{
|
||||
void *oldValue = state_.load(std::memory_order_relaxed);
|
||||
while (true)
|
||||
{
|
||||
if (oldValue == unlockedValue())
|
||||
{
|
||||
void *newValue = nullptr;
|
||||
if (state_.compare_exchange_weak(oldValue,
|
||||
newValue,
|
||||
std::memory_order_acquire,
|
||||
std::memory_order_relaxed))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
void *newValue = awaiter;
|
||||
awaiter->next_ = static_cast<CoroMutexAwaiter *>(oldValue);
|
||||
if (state_.compare_exchange_weak(oldValue,
|
||||
newValue,
|
||||
std::memory_order_release,
|
||||
std::memory_order_relaxed))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void *unlockedValue() noexcept
|
||||
{
|
||||
return this;
|
||||
}
|
||||
|
||||
std::atomic<void *> state_;
|
||||
CoroMutexAwaiter *waiters_;
|
||||
};
|
||||
|
||||
} // namespace drogon
|
||||
|
|
|
@ -2,6 +2,10 @@
|
|||
#include <drogon/utils/coroutine.h>
|
||||
#include <drogon/HttpAppFramework.h>
|
||||
#include <trantor/net/EventLoopThread.h>
|
||||
#include <trantor/net/EventLoopThreadPool.h>
|
||||
#include <chrono>
|
||||
#include <cstdint>
|
||||
#include <future>
|
||||
#include <type_traits>
|
||||
|
||||
using namespace drogon;
|
||||
|
@ -212,3 +216,32 @@ DROGON_TEST(SwitchThread)
|
|||
sync_wait(switch_thread());
|
||||
thread.wait();
|
||||
}
|
||||
|
||||
DROGON_TEST(Mutex)
|
||||
{
|
||||
trantor::EventLoopThreadPool pool{3};
|
||||
pool.start();
|
||||
Mutex mutex;
|
||||
async_run([&]() -> Task<> {
|
||||
co_await switchThreadCoro(pool.getLoop(0));
|
||||
auto guard = co_await mutex.scoped_lock();
|
||||
co_await sleepCoro(pool.getLoop(1), std::chrono::seconds(2));
|
||||
co_return;
|
||||
});
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
std::promise<void> done;
|
||||
async_run([&]() -> Task<> {
|
||||
co_await switchThreadCoro(pool.getLoop(2));
|
||||
auto id = std::this_thread::get_id();
|
||||
co_await mutex.lock();
|
||||
CHECK(id == std::this_thread::get_id());
|
||||
mutex.unlock();
|
||||
CHECK(id == std::this_thread::get_id());
|
||||
done.set_value();
|
||||
co_return;
|
||||
});
|
||||
done.get_future().wait();
|
||||
for (int16_t i = 0; i < 3; i++)
|
||||
pool.getLoop(i)->quit();
|
||||
pool.wait();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue