From 3424d3f2c444e3a408d4d09f44c64bb28bc75bfc Mon Sep 17 00:00:00 2001 From: An Tao Date: Sat, 20 Jun 2020 20:21:14 +0800 Subject: [PATCH] Add a way to set the character set when creating DbClient objects (#486) --- config.example.json | 14 +++- drogon_ctl/create_model.cc | 12 +++ drogon_ctl/templates/config.csp | 5 +- drogon_ctl/templates/model_json.csp | 7 +- lib/inc/drogon/HttpAppFramework.h | 3 +- lib/src/ConfigLoader.cc | 8 +- lib/src/DbClientManager.h | 3 +- lib/src/DbClientManagerSkipped.cc | 3 +- lib/src/HttpAppFrameworkImpl.cc | 6 +- lib/src/HttpAppFrameworkImpl.h | 3 +- orm_lib/inc/drogon/orm/DbClient.h | 1 + orm_lib/src/DbClientManager.cc | 8 +- orm_lib/src/DbConnection.h | 1 + orm_lib/src/mysql_impl/MysqlConnection.cc | 96 ++++++++++++++++++++--- orm_lib/src/mysql_impl/MysqlConnection.h | 5 +- orm_lib/tests/db_test.cc | 8 +- 16 files changed, 154 insertions(+), 29 deletions(-) diff --git a/config.example.json b/config.example.json index 12e166ca..620af1d5 100644 --- a/config.example.json +++ b/config.example.json @@ -47,6 +47,9 @@ //is_fast: false by default, if it is true, the client is faster but user can't call //any synchronous interface of it. "is_fast": false, + //client_encoding: The character set used by the client. it is empty string by default which + //means use the default character set. + //"client_encoding": "", //connection_number: 1 by default, if the 'is_fast' is true, the number is the number of //connections per IO thread, otherwise it is the total number of all connections. "connection_number": 1 @@ -114,7 +117,7 @@ //is_recursive: true by default. If it is set to false, files in sub directories can't be accessed. "is_recursive": true, //filters: string array, the filters applied to the location. - "filters":[] + "filters": [] }], //max_connections: maximum connections number,100000 by default "max_connections": 100000, @@ -227,8 +230,11 @@ }], //custom_config: custom configuration for users. This object can be get by the app().getCustomConfig() method. "custom_config": { - "realm" : "drogonRealm", - "opaque" : "drogonOpaque", - "credentials" : [ {"user" : "drogon", "password": "dr0g0n"} ] + "realm": "drogonRealm", + "opaque": "drogonOpaque", + "credentials": [{ + "user": "drogon", + "password": "dr0g0n" + }] } } \ No newline at end of file diff --git a/drogon_ctl/create_model.cc b/drogon_ctl/create_model.cc index 9753dd9e..5433a6a1 100644 --- a/drogon_ctl/create_model.cc +++ b/drogon_ctl/create_model.cc @@ -722,6 +722,12 @@ void create_model::createModel(const std::string &path, connStr += " password="; connStr += password; } + auto characterSet = config.get("client_encoding", "").asString(); + if (!characterSet.empty()) + { + connStr += " client_encoding="; + connStr += characterSet; + } auto schema = config.get("schema", "public").asString(); DbClientPtr client = drogon::orm::DbClient::newPgClient(connStr, 1); @@ -822,6 +828,12 @@ void create_model::createModel(const std::string &path, connStr += " password="; connStr += password; } + auto characterSet = config.get("client_encoding", "").asString(); + if (!characterSet.empty()) + { + connStr += " client_encoding="; + connStr += characterSet; + } DbClientPtr client = drogon::orm::DbClient::newMysqlClient(connStr, 1); std::cout << "Connect to server..." << std::endl; if (forceOverwrite_) diff --git a/drogon_ctl/templates/config.csp b/drogon_ctl/templates/config.csp index 9e3c8669..0f6649ca 100644 --- a/drogon_ctl/templates/config.csp +++ b/drogon_ctl/templates/config.csp @@ -47,6 +47,9 @@ //is_fast: false by default, if it is true, the client is faster but user can't call //any synchronous interface of it. "is_fast": false, + //client_encoding: The character set used by the client. it is empty string by default which + //means use the default character set. + //"client_encoding": "", //connection_number: 1 by default, if the 'is_fast' is true, the number is the number of //connections per IO thread, otherwise it is the total number of all connections. "connection_number": 1 @@ -114,7 +117,7 @@ //is_recursive: true by default. If it is set to false, files in sub directories can't be accessed. "is_recursive": true, //filters: string array, the filters applied to the location. - "filters":[] + "filters": [] }], //max_connections: maximum connections number,100000 by default "max_connections": 100000, diff --git a/drogon_ctl/templates/model_json.csp b/drogon_ctl/templates/model_json.csp index 3f4191da..354d5db7 100644 --- a/drogon_ctl/templates/model_json.csp +++ b/drogon_ctl/templates/model_json.csp @@ -13,8 +13,11 @@ "schema": "public", //user: User name "user": "", - //passwd: Password - "passwd": "", + //password or passwd: Password + "password": "", + //client_encoding: The character set used by drogon_ctl. it is empty string by default which + //means use the default character set. + //"client_encoding": "", //table: An array of tables to be modelized. if the array is empty, all revealed tables are modelized. "tables": [], "relationships": { diff --git a/lib/inc/drogon/HttpAppFramework.h b/lib/inc/drogon/HttpAppFramework.h index b9126712..066a8660 100644 --- a/lib/inc/drogon/HttpAppFramework.h +++ b/lib/inc/drogon/HttpAppFramework.h @@ -1092,7 +1092,8 @@ class HttpAppFramework : public trantor::NonCopyable const size_t connectionNum = 1, const std::string &filename = "", const std::string &name = "default", - const bool isFast = false) = 0; + const bool isFast = false, + const std::string &characterSet = "") = 0; /// Get the DNS resolver /** diff --git a/lib/src/ConfigLoader.cc b/lib/src/ConfigLoader.cc index e47dc094..fdbf48b4 100644 --- a/lib/src/ConfigLoader.cc +++ b/lib/src/ConfigLoader.cc @@ -470,6 +470,11 @@ static void loadDbClients(const Json::Value &dbClients) auto name = client.get("name", "default").asString(); auto filename = client.get("filename", "").asString(); auto isFast = client.get("is_fast", false).asBool(); + auto characterSet = client.get("characterSet", "").asString(); + if (characterSet.empty()) + { + characterSet = client.get("client_encoding", "").asString(); + } drogon::app().createDbClient(type, host, (unsigned short)port, @@ -479,7 +484,8 @@ static void loadDbClients(const Json::Value &dbClients) connNum, filename, name, - isFast); + isFast, + characterSet); } } static void loadListeners(const Json::Value &listeners) diff --git a/lib/src/DbClientManager.h b/lib/src/DbClientManager.h index dc564775..1385aa28 100644 --- a/lib/src/DbClientManager.h +++ b/lib/src/DbClientManager.h @@ -51,7 +51,8 @@ class DbClientManager : public trantor::NonCopyable const size_t connectionNum, const std::string &filename, const std::string &name, - const bool isFast); + const bool isFast, + const std::string &characterSet); bool areAllDbClientsAvailable() const noexcept; private: diff --git a/lib/src/DbClientManagerSkipped.cc b/lib/src/DbClientManagerSkipped.cc index 033e9663..321a0b88 100644 --- a/lib/src/DbClientManagerSkipped.cc +++ b/lib/src/DbClientManagerSkipped.cc @@ -36,7 +36,8 @@ void DbClientManager::createDbClient(const std::string &dbType, const size_t connectionNum, const std::string &filename, const std::string &name, - const bool isFast) + const bool isFast, + const std::string &characterSet) { LOG_FATAL << "No database is supported by drogon, please install the " "database development library first."; diff --git a/lib/src/HttpAppFrameworkImpl.cc b/lib/src/HttpAppFrameworkImpl.cc index 9cd8e346..a6f27993 100644 --- a/lib/src/HttpAppFrameworkImpl.cc +++ b/lib/src/HttpAppFrameworkImpl.cc @@ -886,7 +886,8 @@ HttpAppFramework &HttpAppFrameworkImpl::createDbClient( const size_t connectionNum, const std::string &filename, const std::string &name, - const bool isFast) + const bool isFast, + const std::string &characterSet) { assert(!running_); dbClientManagerPtr_->createDbClient(dbType, @@ -898,7 +899,8 @@ HttpAppFramework &HttpAppFrameworkImpl::createDbClient( connectionNum, filename, name, - isFast); + isFast, + characterSet); return *this; } diff --git a/lib/src/HttpAppFrameworkImpl.h b/lib/src/HttpAppFrameworkImpl.h index 069b0c73..7eb44976 100644 --- a/lib/src/HttpAppFrameworkImpl.h +++ b/lib/src/HttpAppFrameworkImpl.h @@ -425,7 +425,8 @@ class HttpAppFrameworkImpl : public HttpAppFramework const size_t connectionNum = 1, const std::string &filename = "", const std::string &name = "default", - const bool isFast = false) override; + const bool isFast = false, + const std::string &characterSet = "") override; inline static HttpAppFrameworkImpl &instance() { diff --git a/orm_lib/inc/drogon/orm/DbClient.h b/orm_lib/inc/drogon/orm/DbClient.h index c6ccd226..e77fc9fe 100644 --- a/orm_lib/inc/drogon/orm/DbClient.h +++ b/orm_lib/inc/drogon/orm/DbClient.h @@ -59,6 +59,7 @@ class DbClient : public trantor::NonCopyable * as the operating system name of the user running the application. * - password: Password to be used if the server demands password * authentication. + * - client_encoding: The character set to be used on database connections. * * For other key words on PostgreSQL, see the PostgreSQL documentation. * Only a pair of key values ​​is valid for Sqlite3, and its keyword is diff --git a/orm_lib/src/DbClientManager.cc b/orm_lib/src/DbClientManager.cc index 10f79a2b..9cb5b722 100644 --- a/orm_lib/src/DbClientManager.cc +++ b/orm_lib/src/DbClientManager.cc @@ -94,7 +94,8 @@ void DbClientManager::createDbClient(const std::string &dbType, const size_t connectionNum, const std::string &filename, const std::string &name, - const bool isFast) + const bool isFast, + const std::string &characterSet) { auto connStr = utils::formattedString("host=%s port=%u dbname=%s user=%s", host.c_str(), @@ -108,6 +109,11 @@ void DbClientManager::createDbClient(const std::string &dbType, } std::string type = dbType; std::transform(type.begin(), type.end(), type.begin(), tolower); + if (!characterSet.empty()) + { + connStr += " client_encoding="; + connStr += characterSet; + } DbInfo info; info.connectionInfo_ = connStr; info.connectionNumber_ = connectionNum; diff --git a/orm_lib/src/DbConnection.h b/orm_lib/src/DbConnection.h index d30cf598..b8c97839 100644 --- a/orm_lib/src/DbConnection.h +++ b/orm_lib/src/DbConnection.h @@ -40,6 +40,7 @@ enum class ConnectStatus { None = 0, Connecting, + SettingCharacterSet, Ok, Bad }; diff --git a/orm_lib/src/mysql_impl/MysqlConnection.cc b/orm_lib/src/mysql_impl/MysqlConnection.cc index dab1a62a..ddeba4f9 100644 --- a/orm_lib/src/mysql_impl/MysqlConnection.cc +++ b/orm_lib/src/mysql_impl/MysqlConnection.cc @@ -94,6 +94,10 @@ MysqlConnection::MysqlConnection(trantor::EventLoop *loop, { passwd_ = value; } + else if (key == "client_encoding") + { + characterSet_ = value; + } } loop_->queueInLoop([this]() { MYSQL *ret; @@ -200,15 +204,27 @@ void MysqlConnection::handleTimeout() return; } // I don't think the programe can run to here. - status_ = ConnectStatus::Ok; - if (okCallback_) + if (characterSet_.empty()) { - auto thisPtr = shared_from_this(); - okCallback_(thisPtr); + status_ = ConnectStatus::Ok; + if (okCallback_) + { + auto thisPtr = shared_from_this(); + okCallback_(thisPtr); + } + } + else + { + startSetCharacterSet(); + return; } } setChannel(); } + else if (status_ == ConnectStatus::SettingCharacterSet) + { + continueSetCharacterSet(status); + } else if (status_ == ConnectStatus::Ok) { } @@ -241,11 +257,19 @@ void MysqlConnection::handleEvent() handleClosed(); return; } - status_ = ConnectStatus::Ok; - if (okCallback_) + if (characterSet_.empty()) { - auto thisPtr = shared_from_this(); - okCallback_(thisPtr); + status_ = ConnectStatus::Ok; + if (okCallback_) + { + auto thisPtr = shared_from_this(); + okCallback_(thisPtr); + } + } + else + { + startSetCharacterSet(); + return; } } setChannel(); @@ -320,8 +344,62 @@ void MysqlConnection::handleEvent() return; } } + else if (status_ == ConnectStatus::SettingCharacterSet) + { + continueSetCharacterSet(status); + } +} +void MysqlConnection::continueSetCharacterSet(int status) +{ + int err; + waitStatus_ = mysql_set_character_set_cont(&err, mysqlPtr_.get(), status); + if (waitStatus_ == 0) + { + if (err) + { + LOG_ERROR << "Error(" << err << ") \"" + << mysql_error(mysqlPtr_.get()) << "\""; + LOG_ERROR << "Failed to mysql_real_connect()"; + handleClosed(); + return; + } + status_ = ConnectStatus::Ok; + if (okCallback_) + { + auto thisPtr = shared_from_this(); + okCallback_(thisPtr); + } + } + setChannel(); +} +void MysqlConnection::startSetCharacterSet() +{ + int err; + waitStatus_ = mysql_set_character_set_start(&err, + mysqlPtr_.get(), + characterSet_.data()); + if (waitStatus_ == 0) + { + if (err) + { + LOG_ERROR << "error"; + loop_->queueInLoop( + [thisPtr = shared_from_this()] { thisPtr->outputError(); }); + return; + } + status_ = ConnectStatus::Ok; + if (okCallback_) + { + auto thisPtr = shared_from_this(); + okCallback_(thisPtr); + } + } + else + { + status_ = ConnectStatus::SettingCharacterSet; + } + setChannel(); } - void MysqlConnection::execSqlInLoop( string_view &&sql, size_t paraNum, diff --git a/orm_lib/src/mysql_impl/MysqlConnection.h b/orm_lib/src/mysql_impl/MysqlConnection.h index 4d0d9094..a870bf29 100644 --- a/orm_lib/src/mysql_impl/MysqlConnection.h +++ b/orm_lib/src/mysql_impl/MysqlConnection.h @@ -96,10 +96,11 @@ class MysqlConnection : public DbConnection, std::vector &&format, ResultCallback &&rcb, std::function &&exceptCallback); - + void startSetCharacterSet(); + void continueSetCharacterSet(int status); std::unique_ptr channelPtr_; std::shared_ptr mysqlPtr_; - + std::string characterSet_; void handleTimeout(); void handleClosed(); diff --git a/orm_lib/tests/db_test.cc b/orm_lib/tests/db_test.cc index 4e3da205..d140dc0c 100644 --- a/orm_lib/tests/db_test.cc +++ b/orm_lib/tests/db_test.cc @@ -1967,15 +1967,17 @@ int main(int argc, char *argv[]) trantor::Logger::setLogLevel(trantor::Logger::kDebug); #if USE_POSTGRESQL auto postgre_client = DbClient::newPgClient( - "host=127.0.0.1 port=5432 dbname=postgres user=postgres", 1); + "host=127.0.0.1 port=5432 dbname=postgres user=postgres " + "client_encoding=utf8", + 1); while (!postgre_client->hasAvailableConnections()) { std::this_thread::sleep_for(1s); } #endif #if USE_MYSQL - auto mysql_client = - DbClient::newMysqlClient("host=localhost port=3306 user=root", 1); + auto mysql_client = DbClient::newMysqlClient( + "host=localhost port=3306 user=root client_encoding=utf8mb4", 1); while (!mysql_client->hasAvailableConnections()) { std::this_thread::sleep_for(1s);