Support postgresql asynchronous notification (LISTEN/NOTIFY). (#1464)

This commit is contained in:
Nitromelon 2023-01-04 23:50:49 +08:00 committed by GitHub
parent 19f08786f0
commit 1618484d74
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 705 additions and 8 deletions

View File

@ -336,10 +336,12 @@ if (BUILD_POSTGRESQL)
target_link_libraries(${PROJECT_NAME} PRIVATE pg_lib)
set(DROGON_SOURCES
${DROGON_SOURCES}
orm_lib/src/postgresql_impl/PostgreSQLResultImpl.cc)
orm_lib/src/postgresql_impl/PostgreSQLResultImpl.cc
orm_lib/src/postgresql_impl/PgListener.cc)
set(private_headers
${private_headers}
orm_lib/src/postgresql_impl/PostgreSQLResultImpl.h)
orm_lib/src/postgresql_impl/PostgreSQLResultImpl.h
orm_lib/src/postgresql_impl/PgListener.h)
if (LIBPQ_BATCH_MODE)
try_compile(libpq_supports_batch ${CMAKE_BINARY_DIR}/cmaketest
${PROJECT_SOURCE_DIR}/cmake/tests/test_libpq_batch_mode.cc
@ -525,6 +527,7 @@ set(DROGON_SOURCES
orm_lib/src/DbClientImpl.cc
orm_lib/src/DbClientLockFree.cc
orm_lib/src/DbConnection.cc
orm_lib/src/DbListener.cc
orm_lib/src/Exception.cc
orm_lib/src/Field.cc
orm_lib/src/Result.cc
@ -669,6 +672,7 @@ set(ORM_HEADERS
orm_lib/inc/drogon/orm/ArrayParser.h
orm_lib/inc/drogon/orm/Criteria.h
orm_lib/inc/drogon/orm/DbClient.h
orm_lib/inc/drogon/orm/DbListener.h
orm_lib/inc/drogon/orm/DbTypes.h
orm_lib/inc/drogon/orm/Exception.h
orm_lib/inc/drogon/orm/Field.h

View File

@ -0,0 +1,80 @@
/**
*
* @file DbListener.h
* @author Nitromelon
*
* Copyright 2022, An Tao. All rights reserved.
* https://github.com/drogonframework/drogon
* Use of this source code is governed by a MIT license
* that can be found in the License file.
*
* Drogon
*
*/
#pragma once
#include <drogon/exports.h>
#include <functional>
#include <string>
#include <memory>
namespace trantor
{
class EventLoop;
}
namespace drogon
{
namespace orm
{
class DbListener;
using DbListenerPtr = std::shared_ptr<DbListener>;
/// Database asynchronous notification listener abstract class
class DROGON_EXPORT DbListener
{
public:
using MessageCallback = std::function<void(std::string, std::string)>;
virtual ~DbListener();
/// Create a new postgresql notification listener
/**
* @param connInfo: Connection string, the same as DbClient::newPgClient()
* @param loop: The eventloop this DbListener runs in. If empty, a new
* thread will be created.
* @return DbListenerPtr
* @return nullptr if postgresql is not supported.
*/
static DbListenerPtr newPgListener(const std::string &connInfo,
trantor::EventLoop *loop = nullptr);
/// Listen to a channel
/**
* @param channel channel name to listen
* @param messageCallback callback when notification arrives on channel
*
* @note `listen()` can be called on the same channel multiple times.
* In this case, each `messageCallback` will be called when message arrives.
* However, a single `unlisten()` call will cancel all the callbacks.
*
* @note If has connection issues, the listener will keep retrying until
* listen success. The listener will also re-listen to all channels after
* re-connection.
* However, if user passes an invalid channel string, the operation will
* fail with an error log without any other actions. (This behavior may
* change in future. A errorCallback may be added as a parameters.)
*/
virtual void listen(const std::string &channel,
MessageCallback messageCallback) noexcept = 0;
/// Stop listening to channel
/**
* @param channel channel to stop listening
*/
virtual void unlisten(const std::string &channel) noexcept = 0;
};
} // namespace orm
} // namespace drogon

42
orm_lib/src/DbListener.cc Normal file
View File

@ -0,0 +1,42 @@
/**
*
* @file DbListener.cc
* @author Nitromelon
*
* Copyright 2022, An Tao. All rights reserved.
* https://github.com/drogonframework/drogon
* Use of this source code is governed by a MIT license
* that can be found in the License file.
*
* Drogon
*
*/
#include <drogon/config.h>
#include <drogon/orm/DbListener.h>
#include <trantor/utils/Logger.h>
#include <mutex>
#if USE_POSTGRESQL
#include "postgresql_impl/PgListener.h"
#endif
using namespace drogon;
using namespace drogon::orm;
DbListener::~DbListener() = default;
std::shared_ptr<DbListener> DbListener::newPgListener(
const std::string& connInfo,
trantor::EventLoop* loop)
{
#if USE_POSTGRESQL
std::shared_ptr<PgListener> pgListener =
std::make_shared<PgListener>(connInfo, loop);
pgListener->init();
return pgListener;
#else
LOG_ERROR << "Postgresql is not supported by current drogon build";
return nullptr;
#endif
}

View File

@ -401,6 +401,18 @@ void PgConnection::handleRead()
while (!PQisBusy(connectionPtr_.get()))
{
// TODO: should optimize order of checking
// Check notification
std::shared_ptr<PGnotify> notify;
while (
(notify =
std::shared_ptr<PGnotify>(PQnotifies(connectionPtr_.get()),
[](PGnotify *p) { PQfreemem(p); })))
{
messageCallback_({notify->relname}, {notify->extra});
}
// Check query result
res = std::shared_ptr<PGresult>(PQgetResult(connectionPtr_.get()),
[](PGresult *p) { PQclear(p); });
if (!res)

View File

@ -343,6 +343,15 @@ void PgConnection::handleRead()
idleCb_();
}
}
// Check notification
std::shared_ptr<PGnotify> notify;
while (
(notify = std::shared_ptr<PGnotify>(PQnotifies(connectionPtr_.get()),
[](PGnotify *p) { PQfreemem(p); })))
{
messageCallback_({notify->relname}, {notify->extra});
}
}
void PgConnection::doAfterPreparing()

View File

@ -38,6 +38,8 @@ class PgConnection : public DbConnection,
public std::enable_shared_from_this<PgConnection>
{
public:
using MessageCallback =
std::function<void(const std::string &, const std::string &)>;
PgConnection(trantor::EventLoop *loop,
const std::string &connInfo,
bool autoBatch);
@ -88,6 +90,15 @@ class PgConnection : public DbConnection,
void disconnect() override;
const std::shared_ptr<PGconn> &pgConn() const
{
return connectionPtr_;
}
void setMessageCallback(MessageCallback cb)
{
messageCallback_ = std::move(cb);
}
private:
std::shared_ptr<PGconn> connectionPtr_;
trantor::Channel channel_;
@ -134,6 +145,8 @@ class PgConnection : public DbConnection,
#else
std::unordered_map<string_view, std::string> preparedStatementsMap_;
#endif
MessageCallback messageCallback_;
};
} // namespace orm

View File

@ -0,0 +1,333 @@
/**
*
* @file PgListener.cc
* @author Nitromelon
*
* Copyright 2022, An Tao. All rights reserved.
* https://github.com/drogonframework/drogon
* Use of this source code is governed by a MIT license
* that can be found in the License file.
*
* Drogon
*
*/
#include "PgListener.h"
#include "PgConnection.h"
using namespace drogon;
using namespace drogon::orm;
#define MAX_UNLISTEN_RETRY 3
#define MAX_LISTEN_RETRY 10
PgListener::PgListener(std::string connInfo, trantor::EventLoop* loop)
: connectionInfo_(std::move(connInfo)), loop_(loop)
{
if (!loop)
{
threadPtr_ = std::make_unique<trantor::EventLoopThread>();
threadPtr_->run();
loop_ = threadPtr_->getLoop();
}
}
PgListener::~PgListener()
{
if (conn_)
{
conn_->disconnect();
conn_ = nullptr;
}
}
void PgListener::init() noexcept
{
// shared_from_this() can not be called in constructor
std::weak_ptr<PgListener> weakThis = shared_from_this();
loop_->queueInLoop([weakThis]() {
auto thisPtr = weakThis.lock();
if (!thisPtr)
{
return;
}
thisPtr->connHolder_ = thisPtr->newConnection();
});
}
void PgListener::listen(
const std::string& channel,
std::function<void(std::string, std::string)> messageCallback) noexcept
{
if (loop_->isInLoopThread())
{
listenChannels_[channel].push_back(std::move(messageCallback));
listenInLoop(channel, true);
}
else
{
std::weak_ptr<PgListener> weakThis = shared_from_this();
loop_->queueInLoop(
[weakThis, channel, cb = std::move(messageCallback)]() mutable {
auto thisPtr = weakThis.lock();
if (!thisPtr)
{
return;
}
thisPtr->listenChannels_[channel].push_back(std::move(cb));
thisPtr->listenInLoop(channel, true);
});
}
}
void PgListener::unlisten(const std::string& channel) noexcept
{
if (loop_->isInLoopThread())
{
listenChannels_.erase(channel);
listenInLoop(channel, false);
}
else
{
std::weak_ptr<PgListener> weakThis = shared_from_this();
loop_->queueInLoop([weakThis, channel]() {
auto thisPtr = weakThis.lock();
if (!thisPtr)
{
return;
}
thisPtr->listenChannels_.erase(channel);
thisPtr->listenInLoop(channel, false);
});
}
}
void PgListener::onMessage(const std::string& channel,
const std::string& message) const noexcept
{
loop_->assertInLoopThread();
auto iter = listenChannels_.find(channel);
if (iter == listenChannels_.end())
{
return;
}
for (auto& cb : iter->second)
{
cb(channel, message);
}
}
void PgListener::listenAll() noexcept
{
loop_->assertInLoopThread();
listenTasks_.clear();
for (auto& item : listenChannels_)
{
listenTasks_.emplace_back(true, item.first);
}
listenNext();
}
void PgListener::listenNext() noexcept
{
loop_->assertInLoopThread();
if (listenTasks_.empty())
{
return;
}
auto [listen, channel] = listenTasks_.front();
listenTasks_.pop_front();
listenInLoop(channel, listen);
}
void PgListener::listenInLoop(const std::string& channel,
bool listen,
std::shared_ptr<unsigned int> retryCnt)
{
loop_->assertInLoopThread();
if (!retryCnt)
retryCnt = std::make_shared<unsigned int>(0);
if (conn_ && listenTasks_.empty())
{
if (!conn_->isWorking())
{
auto pgConn = std::dynamic_pointer_cast<PgConnection>(conn_);
std::string escapedChannel =
escapeIdentifier(pgConn, channel.c_str(), channel.size());
if (escapedChannel.empty())
{
LOG_ERROR << "Failed to escape pg identifier, stop listen";
// TODO: report
return;
}
// Because DbConnection::execSql() takes string_view as parameter,
// sql must be hold until query finish.
auto sql = std::make_shared<std::string>(
(listen ? "LISTEN " : "UNLISTEN ") + escapedChannel);
std::weak_ptr<PgListener> weakThis = shared_from_this();
conn_->execSql(
*sql,
0,
{},
{},
{},
[listen, channel, sql](const Result& r) {
if (listen)
{
LOG_DEBUG << "Listen channel " << channel;
}
else
{
LOG_DEBUG << "Unlisten channel " << channel;
}
},
[listen, channel, weakThis, sql, retryCnt, loop = loop_](
const std::exception_ptr& exception) {
try
{
std::rethrow_exception(exception);
}
catch (const DrogonDbException& ex)
{
++(*retryCnt);
if (listen)
{
LOG_ERROR << "Failed to listen channel " << channel
<< ", error: " << ex.base().what();
if (*retryCnt > MAX_LISTEN_RETRY)
{
LOG_ERROR << "Failed to listen channel "
<< channel
<< " after max attempt. Stop trying.";
// TODO: report
return;
}
}
else
{
LOG_ERROR << "Failed to unlisten channel "
<< channel
<< ", error: " << ex.base().what();
if (*retryCnt > MAX_UNLISTEN_RETRY)
{
LOG_ERROR << "Failed to unlisten channel "
<< channel
<< " after max attempt. Stop trying.";
// TODO: report?
return;
}
}
auto delay = (*retryCnt) < 5 ? (*retryCnt * 2) : 10;
loop->runAfter(delay, [=]() {
auto thisPtr = weakThis.lock();
if (thisPtr)
{
thisPtr->listenInLoop(channel,
listen,
retryCnt);
}
});
}
});
return;
}
}
if (listenTasks_.size() > 20000)
{
LOG_WARN << "Too many queries in listen buffer. Stop listen channel "
<< channel;
// TODO: report
return;
}
listenTasks_.emplace_back(listen, channel);
}
PgConnectionPtr PgListener::newConnection(
std::shared_ptr<unsigned int> retryCnt)
{
PgConnectionPtr connPtr =
std::make_shared<PgConnection>(loop_, connectionInfo_, false);
std::weak_ptr<PgListener> weakPtr = shared_from_this();
if (!retryCnt)
retryCnt = std::make_shared<unsigned int>(0);
connPtr->setCloseCallback(
[weakPtr, retryCnt](const DbConnectionPtr& closeConnPtr) {
auto thisPtr = weakPtr.lock();
if (!thisPtr)
return;
// Erase the connection
if (closeConnPtr == thisPtr->conn_)
{
thisPtr->conn_.reset();
}
if (closeConnPtr == thisPtr->connHolder_)
{
thisPtr->connHolder_.reset();
}
// Reconnect after delay
++(*retryCnt);
unsigned int delay = (*retryCnt) < 5 ? (*retryCnt * 2) : 10;
thisPtr->loop_->runAfter(delay, [weakPtr, closeConnPtr, retryCnt] {
auto thisPtr = weakPtr.lock();
if (!thisPtr)
return;
assert(!thisPtr->connHolder_);
thisPtr->connHolder_ = thisPtr->newConnection(retryCnt);
});
});
connPtr->setOkCallback(
[weakPtr, retryCnt](const DbConnectionPtr& okConnPtr) {
LOG_TRACE << "connected after " << *retryCnt << " tries";
(*retryCnt) = 0;
auto thisPtr = weakPtr.lock();
if (!thisPtr)
return;
assert(!thisPtr->conn_);
assert(thisPtr->connHolder_ == okConnPtr);
thisPtr->conn_ = okConnPtr;
thisPtr->listenAll();
});
connPtr->setIdleCallback([weakPtr]() {
auto thisPtr = weakPtr.lock();
if (!thisPtr)
return;
thisPtr->listenNext();
});
connPtr->setMessageCallback(
[weakPtr](const std::string& channel, const std::string& message) {
auto thisPtr = weakPtr.lock();
if (thisPtr)
{
thisPtr->onMessage(channel, message);
}
});
return connPtr;
}
std::string PgListener::escapeIdentifier(const PgConnectionPtr& conn,
const char* str,
size_t length)
{
auto res = std::unique_ptr<char, std::function<void(char*)>>(
PQescapeIdentifier(conn->pgConn().get(), str, length), [](char* res) {
if (res)
{
PQfreemem(res);
}
});
if (!res)
{
LOG_ERROR << "Error when escaping identifier ["
<< std::string(str, length) << "]. "
<< PQerrorMessage(conn->pgConn().get());
return {};
}
return std::string{res.get()};
}

View File

@ -0,0 +1,90 @@
/**
*
* @file PgListener.h
* @author Nitromelon
*
* Copyright 2022, An Tao. All rights reserved.
* https://github.com/drogonframework/drogon
* Use of this source code is governed by a MIT license
* that can be found in the License file.
*
* Drogon
*
*/
#pragma once
#include <drogon/orm/DbListener.h>
#include <trantor/net/EventLoopThread.h>
#include <deque>
#include <mutex>
#include <string>
#include <unordered_map>
#include "./PgConnection.h"
namespace drogon
{
namespace orm
{
class PgListener : public DbListener,
public std::enable_shared_from_this<PgListener>
{
public:
PgListener(std::string connInfo, trantor::EventLoop* loop);
~PgListener() override;
void init() noexcept;
trantor::EventLoop* loop() const
{
return loop_;
}
void listen(const std::string& channel,
MessageCallback messageCallback) noexcept override;
void unlisten(const std::string& channel) noexcept override;
// methods below should be called in loop
void onMessage(const std::string& channel,
const std::string& message) const noexcept;
void listenAll() noexcept;
void listenNext() noexcept;
private:
/// Escapes a string for use as an SQL identifier, such as a table, column,
/// or function name. This is useful when a user-supplied identifier might
/// contain special characters that would otherwise not be interpreted as
/// part of the identifier by the SQL parser, or when the identifier might
/// contain upper case characters whose case should be preserved.
/**
* @param str: c-style string to escape. A terminating zero byte is not
* required, and should not be counted in length(If a terminating zero byte
* is found before length bytes are processed, PQescapeIdentifier stops at
* the zero; the behavior is thus rather like strncpy).
* @param length: length of the c-style string
* @return: The return string has all special characters replaced so that
* it will be properly processed as an SQL identifier. A terminating zero
* byte is also added. The return string will also be surrounded by double
* quotes.
*/
static std::string escapeIdentifier(const PgConnectionPtr& conn,
const char* str,
size_t length);
void listenInLoop(const std::string& channel,
bool listen,
std::shared_ptr<unsigned int> = nullptr);
PgConnectionPtr newConnection(std::shared_ptr<unsigned int> = nullptr);
std::string connectionInfo_;
std::unique_ptr<trantor::EventLoopThread> threadPtr_;
trantor::EventLoop* loop_;
DbConnectionPtr connHolder_;
DbConnectionPtr conn_;
std::deque<std::pair<bool, std::string>> listenTasks_;
std::unordered_map<std::string, std::vector<MessageCallback>>
listenChannels_;
};
} // namespace orm
} // namespace drogon

View File

@ -14,6 +14,10 @@ add_executable(pipeline_test
pipeline_test.cpp
)
add_executable(db_listener_test
db_listener_test.cc
)
set_property(TARGET db_test PROPERTY CXX_STANDARD ${DROGON_CXX_STANDARD})
set_property(TARGET db_test PROPERTY CXX_STANDARD_REQUIRED ON)
set_property(TARGET db_test PROPERTY CXX_EXTENSIONS OFF)
@ -21,3 +25,7 @@ set_property(TARGET db_test PROPERTY CXX_EXTENSIONS OFF)
set_property(TARGET pipeline_test PROPERTY CXX_STANDARD ${DROGON_CXX_STANDARD})
set_property(TARGET pipeline_test PROPERTY CXX_STANDARD_REQUIRED ON)
set_property(TARGET pipeline_test PROPERTY CXX_EXTENSIONS OFF)
set_property(TARGET db_listener_test PROPERTY CXX_STANDARD ${DROGON_CXX_STANDARD})
set_property(TARGET db_listener_test PROPERTY CXX_STANDARD_REQUIRED ON)
set_property(TARGET db_listener_test PROPERTY CXX_EXTENSIONS OFF)

View File

@ -0,0 +1,98 @@
/**
*
* @file db_listener_test.cc
* @author Nitromelon
*
* Copyright 2022, Nitromelon. All rights reserved.
* Use of this source code is governed by a MIT license
* that can be found in the License file.
*
* Drogon
*
* Drogon database test program
*
*/
#define DROGON_TEST_MAIN
#include <drogon/drogon_test.h>
#include <drogon/HttpAppFramework.h>
#include <drogon/config.h>
#include <drogon/orm/DbListener.h>
#include <chrono>
using namespace drogon;
using namespace drogon::orm;
using namespace trantor;
using namespace std::chrono_literals;
static const std::string LISTEN_CHANNEL = "listen_test";
#if USE_POSTGRESQL
orm::DbClientPtr postgreClient;
DROGON_TEST(ListenNotifyTest)
{
auto clientPtr = postgreClient;
auto dbListener = DbListener::newPgListener(clientPtr->connectionInfo());
MANDATE(dbListener);
static int numNotifications = 0;
LOG_INFO << "Start listen.";
dbListener->listen(LISTEN_CHANNEL,
[TEST_CTX](const std::string &channel,
const std::string &message) {
MANDATE(channel == LISTEN_CHANNEL);
LOG_INFO << "Message from " << LISTEN_CHANNEL << ": "
<< message;
++numNotifications;
});
std::this_thread::sleep_for(1s);
LOG_INFO << "Start sending notifications.";
for (int i = 0; i < 10; ++i)
{
// Can not use placeholders in LISTEN or NOTIFY command!!!
std::string cmd =
"NOTIFY " + LISTEN_CHANNEL + ", '" + std::to_string(i) + "'";
clientPtr->execSqlAsync(
cmd,
[i](const orm::Result &result) { LOG_INFO << "Notified " << i; },
[](const orm::DrogonDbException &ex) {
LOG_ERROR << "Failed to notify " << ex.base().what();
});
}
std::this_thread::sleep_for(5s);
LOG_INFO << "Unlisten.";
dbListener->unlisten("listen_test");
CHECK(numNotifications == 10);
}
#endif
int main(int argc, char **argv)
{
trantor::Logger::setLogLevel(trantor::Logger::LogLevel::kDebug);
std::string dbConnInfo;
const char *dbUrl = std::getenv("DROGON_TEST_DB_CONN_INFO");
if (dbUrl)
{
dbConnInfo = std::string{dbUrl};
}
else
{
dbConnInfo =
"host=127.0.0.1 port=5432 dbname=postgres user=postgres "
"password=12345 "
"client_encoding=utf8";
}
LOG_INFO << "Database conn info: " << dbConnInfo;
#if USE_POSTGRESQL
postgreClient = orm::DbClient::newPgClient(dbConnInfo, 2, true);
#else
LOG_DEBUG << "Drogon is built without Postgresql. No tests executed.";
return 0;
#endif
int testStatus = test::run(argc, argv);
return testStatus;
}

20
test.sh
View File

@ -182,13 +182,21 @@ if [ "$1" = "-t" ]; then
fi
fi
if [ -f "./orm_lib/tests/pipeline_test" ]; then
echo "Test pipeline mode"
./orm_lib/tests/pipeline_test -s
if [ $? -ne 0 ]; then
echo "Error in testing"
exit -1
fi
echo "Test pipeline mode"
./orm_lib/tests/pipeline_test -s
if [ $? -ne 0 ]; then
echo "Error in testing"
exit -1
fi
fi
if [ -f "./orm_lib/tests/db_listener_test" ]; then
echo "Test DbListener"
./orm_lib/tests/db_listener_test -s
if [ $? -ne 0 ]; then
echo "Error in testing"
exit -1
fi
fi
if [ -f "./nosql_lib/redis/tests/redis_test" ]; then
echo "Test redis"
./nosql_lib/redis/tests/redis_test -s