Add coroutine mutex (#2095)

This commit is contained in:
fantasy-peak 2024-08-08 15:17:06 +08:00 committed by GitHub
parent 0546032edc
commit c46f149c2c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 209 additions and 0 deletions

View File

@ -811,4 +811,180 @@ inline internal::EventLoopAwaiter<T> queueInLoopCoro(trantor::EventLoop *loop,
return internal::EventLoopAwaiter<T>(std::move(task), loop);
}
class Mutex final
{
class ScopedCoroMutexAwaiter;
class CoroMutexAwaiter;
public:
Mutex() noexcept : state_(unlockedValue()), waiters_(nullptr)
{
}
Mutex(const Mutex &) = delete;
Mutex(Mutex &&) = delete;
Mutex &operator=(const Mutex &) = delete;
Mutex &operator=(Mutex &&) = delete;
~Mutex()
{
[[maybe_unused]] auto state = state_.load(std::memory_order_relaxed);
assert(state == unlockedValue() || state == nullptr);
assert(waiters_ == nullptr);
}
bool try_lock() noexcept
{
void *oldValue = unlockedValue();
return state_.compare_exchange_strong(oldValue,
nullptr,
std::memory_order_acquire,
std::memory_order_relaxed);
}
[[nodiscard]] ScopedCoroMutexAwaiter scoped_lock(
trantor::EventLoop *loop =
trantor::EventLoop::getEventLoopOfCurrentThread()) noexcept
{
return ScopedCoroMutexAwaiter(*this, loop);
}
[[nodiscard]] CoroMutexAwaiter lock(
trantor::EventLoop *loop =
trantor::EventLoop::getEventLoopOfCurrentThread()) noexcept
{
return CoroMutexAwaiter(*this, loop);
}
void unlock() noexcept
{
assert(state_.load(std::memory_order_relaxed) != unlockedValue());
auto *waitersHead = waiters_;
if (waitersHead == nullptr)
{
void *currentState = state_.load(std::memory_order_relaxed);
if (currentState == nullptr)
{
const bool releasedLock =
state_.compare_exchange_strong(currentState,
unlockedValue(),
std::memory_order_release,
std::memory_order_relaxed);
if (releasedLock)
{
return;
}
}
currentState = state_.exchange(nullptr, std::memory_order_acquire);
assert(currentState != unlockedValue());
assert(currentState != nullptr);
auto *waiter = static_cast<CoroMutexAwaiter *>(currentState);
do
{
auto *temp = waiter->next_;
waiter->next_ = waitersHead;
waitersHead = waiter;
waiter = temp;
} while (waiter != nullptr);
}
assert(waitersHead != nullptr);
waiters_ = waitersHead->next_;
if (waitersHead->loop_)
{
auto handle = waitersHead->handle_;
waitersHead->loop_->runInLoop([handle] { handle.resume(); });
}
else
{
waitersHead->handle_.resume();
}
}
private:
class CoroMutexAwaiter
{
public:
CoroMutexAwaiter(Mutex &mutex, trantor::EventLoop *loop) noexcept
: mutex_(mutex), loop_(loop)
{
}
bool await_ready() noexcept
{
return mutex_.try_lock();
}
bool await_suspend(std::coroutine_handle<> handle) noexcept
{
handle_ = handle;
return mutex_.asynclockImpl(this);
}
void await_resume() noexcept
{
}
private:
friend class Mutex;
Mutex &mutex_;
trantor::EventLoop *loop_;
std::coroutine_handle<> handle_;
CoroMutexAwaiter *next_;
};
class ScopedCoroMutexAwaiter : public CoroMutexAwaiter
{
public:
ScopedCoroMutexAwaiter(Mutex &mutex, trantor::EventLoop *loop)
: CoroMutexAwaiter(mutex, loop)
{
}
[[nodiscard]] auto await_resume() noexcept
{
return std::unique_lock<Mutex>{mutex_, std::adopt_lock};
}
};
bool asynclockImpl(CoroMutexAwaiter *awaiter)
{
void *oldValue = state_.load(std::memory_order_relaxed);
while (true)
{
if (oldValue == unlockedValue())
{
void *newValue = nullptr;
if (state_.compare_exchange_weak(oldValue,
newValue,
std::memory_order_acquire,
std::memory_order_relaxed))
{
return false;
}
}
else
{
void *newValue = awaiter;
awaiter->next_ = static_cast<CoroMutexAwaiter *>(oldValue);
if (state_.compare_exchange_weak(oldValue,
newValue,
std::memory_order_release,
std::memory_order_relaxed))
{
return true;
}
}
}
}
void *unlockedValue() noexcept
{
return this;
}
std::atomic<void *> state_;
CoroMutexAwaiter *waiters_;
};
} // namespace drogon

View File

@ -2,6 +2,10 @@
#include <drogon/utils/coroutine.h>
#include <drogon/HttpAppFramework.h>
#include <trantor/net/EventLoopThread.h>
#include <trantor/net/EventLoopThreadPool.h>
#include <chrono>
#include <cstdint>
#include <future>
#include <type_traits>
using namespace drogon;
@ -212,3 +216,32 @@ DROGON_TEST(SwitchThread)
sync_wait(switch_thread());
thread.wait();
}
DROGON_TEST(Mutex)
{
trantor::EventLoopThreadPool pool{3};
pool.start();
Mutex mutex;
async_run([&]() -> Task<> {
co_await switchThreadCoro(pool.getLoop(0));
auto guard = co_await mutex.scoped_lock();
co_await sleepCoro(pool.getLoop(1), std::chrono::seconds(2));
co_return;
});
std::this_thread::sleep_for(std::chrono::milliseconds(100));
std::promise<void> done;
async_run([&]() -> Task<> {
co_await switchThreadCoro(pool.getLoop(2));
auto id = std::this_thread::get_id();
co_await mutex.lock();
CHECK(id == std::this_thread::get_id());
mutex.unlock();
CHECK(id == std::this_thread::get_id());
done.set_value();
co_return;
});
done.get_future().wait();
for (int16_t i = 0; i < 3; i++)
pool.getLoop(i)->quit();
pool.wait();
}