diff --git a/orm_lib/src/postgresql_impl/PgConnection.cc b/orm_lib/src/postgresql_impl/PgConnection.cc index 0716ba3b..03212ea1 100644 --- a/orm_lib/src/postgresql_impl/PgConnection.cc +++ b/orm_lib/src/postgresql_impl/PgConnection.cc @@ -16,6 +16,7 @@ #include "PostgreSQLResultImpl.h" #include #include +#include #include #include @@ -148,16 +149,17 @@ void PgConnection::pgPoll() } } -void PgConnection::execSql(std::string &&sql, - size_t paraNum, - std::vector &¶meters, - std::vector &&length, - std::vector &&format, - ResultCallback &&rcb, - std::function &&exceptCallback, - std::function &&idleCb) +void PgConnection::execSqlInLoop(std::string &&sql, + size_t paraNum, + std::vector &¶meters, + std::vector &&length, + std::vector &&format, + ResultCallback &&rcb, + std::function &&exceptCallback, + std::function &&idleCb) { LOG_TRACE << sql; + _loop->assertInLoopThread(); assert(paraNum == parameters.size()); assert(paraNum == length.size()); assert(paraNum == format.size()); @@ -170,60 +172,19 @@ void PgConnection::execSql(std::string &&sql, _idleCbPtr = std::make_shared>(std::move(idleCb)); _isWorking = true; _exceptCb = std::move(exceptCallback); - auto thisPtr = shared_from_this(); - if (!_loop->isInLoopThread()) + auto iter = _preparedStatementMap.find(_sql); + if (iter != _preparedStatementMap.end()) { - _loop->queueInLoop([thisPtr, paraNum = std::move(paraNum), parameters = std::move(parameters), length = std::move(length), format = std::move(format)]() { - if (PQsendQueryParams( - thisPtr->_connPtr.get(), - thisPtr->_sql.c_str(), - paraNum, - NULL, - parameters.data(), - length.data(), - format.data(), - 0) == 0) - { - LOG_ERROR << "send query error: " << PQerrorMessage(thisPtr->_connPtr.get()); - if (thisPtr->_isWorking) - { - thisPtr->_isWorking = false; - try - { - throw Failure(PQerrorMessage(thisPtr->_connPtr.get())); - } - catch (...) - { - auto exceptPtr = std::current_exception(); - thisPtr->_exceptCb(exceptPtr); - thisPtr->_exceptCb = decltype(_exceptCb)(); - } - thisPtr->_cb = decltype(_cb)(); - if (thisPtr->_idleCbPtr) - { - auto idle = std::move(thisPtr->_idleCbPtr); - thisPtr->_idleCbPtr.reset(); - (*idle)(); - } - } - return; - } - thisPtr->pgPoll(); - }); - } - else - { - if (PQsendQueryParams( + if (PQsendQueryPrepared( _connPtr.get(), - _sql.c_str(), + iter->second.c_str(), paraNum, - NULL, parameters.data(), length.data(), format.data(), 0) == 0) { - LOG_ERROR << "send query error: " << PQerrorMessage(thisPtr->_connPtr.get()); + LOG_ERROR << "send query error: " << PQerrorMessage(_connPtr.get()); if (_isWorking) { _isWorking = false; @@ -247,8 +208,80 @@ void PgConnection::execSql(std::string &&sql, } return; } - thisPtr->pgPoll(); } + else + { + _isRreparingStatement = true; + auto statementName = getuuid(); + if (PQsendPrepare(_connPtr.get(), statementName.c_str(), _sql.c_str(), paraNum, NULL) == 0) + { + LOG_ERROR << "send query error: " << PQerrorMessage(_connPtr.get()); + if (_isWorking) + { + _isWorking = false; + try + { + throw Failure(PQerrorMessage(_connPtr.get())); + } + catch (...) + { + auto exceptPtr = std::current_exception(); + _exceptCb(exceptPtr); + _exceptCb = decltype(_exceptCb)(); + } + _cb = decltype(_cb)(); + if (_idleCbPtr) + { + auto idle = std::move(_idleCbPtr); + _idleCbPtr.reset(); + (*idle)(); + } + } + return; + } + std::weak_ptr weakPtr = shared_from_this(); + _preparingCallback = [weakPtr, statementName, paraNum, parameters = std::move(parameters), length = std::move(length), format = std::move(format)]() { + auto thisPtr = weakPtr.lock(); + if (!thisPtr) + return; + thisPtr->_isRreparingStatement = false; + thisPtr->_preparedStatementMap[thisPtr->_sql] = statementName; + if (PQsendQueryPrepared( + thisPtr->_connPtr.get(), + statementName.c_str(), + paraNum, + parameters.data(), + length.data(), + format.data(), + 0) == 0) + { + LOG_ERROR << "send query error: " << PQerrorMessage(thisPtr->_connPtr.get()); + if (thisPtr->_isWorking) + { + thisPtr->_isWorking = false; + try + { + throw Failure(PQerrorMessage(thisPtr->_connPtr.get())); + } + catch (...) + { + auto exceptPtr = std::current_exception(); + thisPtr->_exceptCb(exceptPtr); + thisPtr->_exceptCb = decltype(thisPtr->_exceptCb)(); + } + thisPtr->_cb = decltype(thisPtr->_cb)(); + if (thisPtr->_idleCbPtr) + { + auto idle = std::move(thisPtr->_idleCbPtr); + thisPtr->_idleCbPtr.reset(); + (*idle)(); + } + } + return; + } + }; + } + pgPoll(); } void PgConnection::handleRead() @@ -290,6 +323,7 @@ void PgConnection::handleRead() } if (_channel.isWriting()) _channel.disableWriting(); + bool isPreparing = false; while ((res = std::shared_ptr(PQgetResult(_connPtr.get()), [](PGresult *p) { PQclear(p); }))) @@ -317,21 +351,36 @@ void PgConnection::handleRead() { if (_isWorking) { - auto r = makeResult(res, _sql); - _cb(r); - _cb = decltype(_cb)(); - _exceptCb = decltype(_exceptCb)(); + if (_isRreparingStatement) + { + isPreparing = true; + } + else + { + auto r = makeResult(res, _sql); + _cb(r); + _cb = decltype(_cb)(); + _exceptCb = decltype(_exceptCb)(); + } } } } if (_isWorking) { - _isWorking = false; - if (_idleCbPtr) + if(isPreparing) { - auto idle = std::move(_idleCbPtr); - _idleCbPtr.reset(); - (*idle)(); + _preparingCallback(); + _preparingCallback = std::function(); } + else + { + _isWorking = false; + if (_idleCbPtr) + { + auto idle = std::move(_idleCbPtr); + _idleCbPtr.reset(); + (*idle)(); + } + } } } diff --git a/orm_lib/src/postgresql_impl/PgConnection.h b/orm_lib/src/postgresql_impl/PgConnection.h index 489a6140..1897a604 100644 --- a/orm_lib/src/postgresql_impl/PgConnection.h +++ b/orm_lib/src/postgresql_impl/PgConnection.h @@ -24,6 +24,7 @@ #include #include #include +#include namespace drogon { @@ -44,15 +45,40 @@ class PgConnection : public DbConnection, public std::enable_shared_from_this &&format, ResultCallback &&rcb, std::function &&exceptCallback, - std::function &&idleCb) override; + std::function &&idleCb) override + { + if (_loop->isInLoopThread()) + { + execSqlInLoop(std::move(sql), paraNum, std::move(parameters), std::move(length), std::move(format), std::move(rcb), std::move(exceptCallback), std::move(idleCb)); + } + else + { + auto thisPtr = shared_from_this(); + _loop->queueInLoop([thisPtr, sql = std::move(sql), paraNum, parameters = std::move(parameters), length = std::move(length), format = std::move(format), rcb = std::move(rcb), exceptCallback = std::move(exceptCallback), idleCb = std::move(idleCb)]() mutable { + thisPtr->execSqlInLoop(std::move(sql), paraNum, std::move(parameters), std::move(length), std::move(format), std::move(rcb), std::move(exceptCallback), std::move(idleCb)); + }); + } + } virtual void disconnect() override; private: std::shared_ptr _connPtr; trantor::Channel _channel; + std::unordered_map _preparedStatementMap; + bool _isRreparingStatement = false; void handleRead(); void pgPoll(); void handleClosed(); + + void execSqlInLoop(std::string &&sql, + size_t paraNum, + std::vector &¶meters, + std::vector &&length, + std::vector &&format, + ResultCallback &&rcb, + std::function &&exceptCallback, + std::function &&idleCb); + std::function _preparingCallback; }; } // namespace orm