diff --git a/drogon_ctl/create_model.cc b/drogon_ctl/create_model.cc index 97a17adf..6de42e4b 100644 --- a/drogon_ctl/create_model.cc +++ b/drogon_ctl/create_model.cc @@ -70,6 +70,7 @@ void create_model::createModelClassFromPG(const std::string &path, const DbClien data["hasPrimaryKey"] = (int)0; data["primaryKeyName"] = ""; data["dbName"] = _dbname; + data["rdbms"] = std::string("postgresql"); std::vector cols; *client << "SELECT * \ FROM information_schema.columns \ @@ -306,6 +307,7 @@ void create_model::createModelClassFromMysql(const std::string &path, const DbCl data["hasPrimaryKey"] = (int)0; data["primaryKeyName"] = ""; data["dbName"] = _dbname; + data["rdbms"] = std::string("mysql"); std::vector cols; int i = 0; *client << "desc " + tableName << Mode::Blocking >> @@ -342,7 +344,12 @@ void create_model::createModelClassFromMysql(const std::string &path, const DbCl info._colType = "int64_t"; info._colLength = 8; } - else if (type.find("float") == 0 || type.find("double") == 0) + else if (type.find("float") == 0) + { + info._colType = "float"; + info._colLength = sizeof(float); + } + else if (type.find("double") == 0) { info._colType = "double"; info._colLength = sizeof(double); diff --git a/drogon_ctl/templates/model_cc.csp b/drogon_ctl/templates/model_cc.csp index 338f345f..ebc8978b 100644 --- a/drogon_ctl/templates/model_cc.csp +++ b/drogon_ctl/templates/model_cc.csp @@ -192,6 +192,26 @@ const std::string &{{className}}::getColumnName(size_t index) noexcept(false) } $$<<"\n"; } + if(@@.get("rdbms")=="postgresql") + { + $$<<"void "<("rdbms")=="mysql") + { + $$<<"void "<(id);\n"; + break; + } + } + $$<<"}\n"; + } if(@@.get("hasPrimaryKey")>1) { $$<<"typename "<>("columns"); void outputArgs(drogon::orm::internal::SqlBinder &binder) const; const std::vector updateColumns() const; void updateArgs(drogon::orm::internal::SqlBinder &binder) const; + ///For mysql only + void updateId(const unsigned long long id); <%c++ for(auto col:cols) { diff --git a/orm_lib/inc/drogon/orm/Mapper.h b/orm_lib/inc/drogon/orm/Mapper.h index 2e43c3c5..d652b3ab 100644 --- a/orm_lib/inc/drogon/orm/Mapper.h +++ b/orm_lib/inc/drogon/orm/Mapper.h @@ -646,13 +646,16 @@ inline void Mapper::insert(T &obj) noexcept(false) sql += ","; } sql[sql.length() - 1] = ')'; //Replace the last ',' - sql += "values ("; + sql += " values ("; for (int i = 0; i < T::insertColumns().size(); i++) { sql += "$?,"; } sql[sql.length() - 1] = ')'; //Replace the last ',' - sql += " returning *"; + if (_client->type() == ClientType::PostgreSQL) + { + sql += " returning *"; + } sql = replaceSqlPlaceHolder(sql, "$?"); Result r(nullptr); { @@ -664,8 +667,16 @@ inline void Mapper::insert(T &obj) noexcept(false) }; binder.exec(); //Maybe throw exception; } - assert(r.size() == 1); - obj = T(r[0]); + if (_client->type() == ClientType::PostgreSQL) + { + assert(r.size() == 1); + obj = T(r[0]); + } + else if (_client->type() == ClientType::Mysql) + { + auto id = r.insertId(); + obj.updateId(id); + } } template inline void Mapper::insert(const T &obj, @@ -682,19 +693,32 @@ inline void Mapper::insert(const T &obj, sql += ","; } sql[sql.length() - 1] = ')'; //Replace the last ',' - sql += "values ("; + sql += " values ("; for (int i = 0; i < T::insertColumns().size(); i++) { sql += "$?,"; } sql[sql.length() - 1] = ')'; //Replace the last ',' - sql += " returning *"; + if (_client->type() == ClientType::PostgreSQL) + { + sql += " returning *"; + } sql = replaceSqlPlaceHolder(sql, "$?"); auto binder = *_client << sql; obj.outputArgs(binder); binder >> [=](const Result &r) { - assert(r.size() == 1); - rcb(T(r[0])); + if (_client->type() == ClientType::PostgreSQL) + { + assert(r.size() == 1); + rcb(T(r[0])); + } + else if (_client->type() == ClientType::Mysql) + { + auto id = r.insertId(); + auto newObj = obj; + newObj.updateId(id); + rcb(newObj); + } }; binder >> ecb; } @@ -711,21 +735,34 @@ inline std::future Mapper::insertFuture(const T &obj) noexcept sql += ","; } sql[sql.length() - 1] = ')'; //Replace the last ',' - sql += "values ("; + sql += " values ("; for (int i = 0; i < T::insertColumns().size(); i++) { sql += "$?,"; } sql[sql.length() - 1] = ')'; //Replace the last ',' - sql += " returning *"; + if (_client->type() == ClientType::PostgreSQL) + { + sql += " returning *"; + } sql = replaceSqlPlaceHolder(sql, "$?"); auto binder = *_client << sql; obj.outputArgs(binder); std::shared_ptr> prom = std::make_shared>(); binder >> [=](const Result &r) { - assert(r.size() == 1); - prom->set_value(T(r[0])); + if (_client->type() == ClientType::PostgreSQL) + { + assert(r.size() == 1); + prom->set_value(T(r[0])); + } + else if (_client->type() == ClientType::Mysql) + { + auto id = r.insertId(); + auto newObj = obj; + newObj.updateId(id); + prom->set_value(newObj); + } }; binder >> [=](const std::exception_ptr &e) { prom->set_exception(e); diff --git a/orm_lib/inc/drogon/orm/Result.h b/orm_lib/inc/drogon/orm/Result.h index 7a669842..56a731dd 100644 --- a/orm_lib/inc/drogon/orm/Result.h +++ b/orm_lib/inc/drogon/orm/Result.h @@ -80,6 +80,9 @@ class Result size_type affectedRows() const noexcept; + /// For Mysql database only + unsigned long long insertId() const noexcept; + private: ResultImplPtr _resultPtr; std::string _query; diff --git a/orm_lib/inc/drogon/orm/SqlBinder.h b/orm_lib/inc/drogon/orm/SqlBinder.h index ad051cc4..bc96d0c6 100644 --- a/orm_lib/inc/drogon/orm/SqlBinder.h +++ b/orm_lib/inc/drogon/orm/SqlBinder.h @@ -11,7 +11,7 @@ */ #pragma once - +#include #include #include #include diff --git a/orm_lib/src/Result.cc b/orm_lib/src/Result.cc index 1a007ded..f5fc34ea 100644 --- a/orm_lib/src/Result.cc +++ b/orm_lib/src/Result.cc @@ -143,3 +143,7 @@ Result::field_size_type Result::getLength(Result::size_type row, Result::row_siz { return _resultPtr->getLength(row,column); } +unsigned long long Result::insertId() const noexcept +{ + return _resultPtr->insertId(); +} \ No newline at end of file diff --git a/orm_lib/src/ResultImpl.h b/orm_lib/src/ResultImpl.h index 43bd76e2..94375db4 100644 --- a/orm_lib/src/ResultImpl.h +++ b/orm_lib/src/ResultImpl.h @@ -29,6 +29,7 @@ public: virtual const char *getValue(size_type row, row_size_type column) const = 0; virtual bool isNull(size_type row, row_size_type column) const = 0; virtual field_size_type getLength(size_type row, row_size_type column) const = 0; + virtual unsigned long long insertId() const noexcept { return 0; } virtual ~ResultImpl() {} }; diff --git a/orm_lib/src/mysql_impl/MysqlConnection.cc b/orm_lib/src/mysql_impl/MysqlConnection.cc index 56293554..a3b60878 100644 --- a/orm_lib/src/mysql_impl/MysqlConnection.cc +++ b/orm_lib/src/mysql_impl/MysqlConnection.cc @@ -26,9 +26,10 @@ namespace orm Result makeResult(const std::shared_ptr &r = std::shared_ptr(nullptr), const std::string &query = "", - Result::size_type affectedRows = 0) + Result::size_type affectedRows = 0, + unsigned long long insertId = 0) { - return Result(std::shared_ptr(new MysqlResultImpl(r, query, affectedRows))); + return Result(std::shared_ptr(new MysqlResultImpl(r, query, affectedRows, insertId))); } } // namespace orm @@ -275,6 +276,7 @@ void MysqlConnection::execSql(const std::string &sql, const std::function &exceptCallback, const std::function &idleCb) { + LOG_TRACE << sql; assert(paraNum == parameters.size()); assert(paraNum == length.size()); assert(paraNum == format.size()); @@ -298,6 +300,7 @@ void MysqlConnection::execSql(const std::string &sql, if (seekPos == std::string::npos) { _sql.append(sql.substr(pos)); + pos = seekPos; break; } else @@ -336,6 +339,10 @@ void MysqlConnection::execSql(const std::string &sql, } } } + if (pos < sql.length()) + { + _sql.append(sql.substr(pos)); + } } else { @@ -412,7 +419,7 @@ void MysqlConnection::getResult(MYSQL_RES *res) auto resultPtr = std::shared_ptr(res, [](MYSQL_RES *r) { mysql_free_result(r); }); - auto Result = makeResult(resultPtr, _sql, mysql_affected_rows(_mysqlPtr.get())); + auto Result = makeResult(resultPtr, _sql, mysql_affected_rows(_mysqlPtr.get()), mysql_insert_id(_mysqlPtr.get())); if (_isWorking) { _cb(Result); diff --git a/orm_lib/src/mysql_impl/MysqlResultImpl.cc b/orm_lib/src/mysql_impl/MysqlResultImpl.cc index ab369a88..6075c456 100644 --- a/orm_lib/src/mysql_impl/MysqlResultImpl.cc +++ b/orm_lib/src/mysql_impl/MysqlResultImpl.cc @@ -65,3 +65,7 @@ Result::field_size_type MysqlResultImpl::getLength(size_type row, row_size_type assert(column < _fieldNum); return (*_rowsPtr)[row].second[column]; } +unsigned long long MysqlResultImpl::insertId() const noexcept +{ + return _insertId; +} diff --git a/orm_lib/src/mysql_impl/MysqlResultImpl.h b/orm_lib/src/mysql_impl/MysqlResultImpl.h index 91f9fe9e..f7299b15 100644 --- a/orm_lib/src/mysql_impl/MysqlResultImpl.h +++ b/orm_lib/src/mysql_impl/MysqlResultImpl.h @@ -28,13 +28,17 @@ namespace orm class MysqlResultImpl : public ResultImpl { public: - MysqlResultImpl(const std::shared_ptr &r, const std::string &query, size_type affectedRows) noexcept + MysqlResultImpl(const std::shared_ptr &r, + const std::string &query, + size_type affectedRows, + unsigned long long insertId) noexcept : _result(r), _query(query), _rowsNum(_result ? mysql_num_rows(_result.get()) : 0), _fieldArray(r ? mysql_fetch_fields(r.get()) : nullptr), _fieldNum(r ? mysql_num_fields(r.get()) : 0), - _affectedRows(affectedRows) + _affectedRows(affectedRows), + _insertId(insertId) { if (_fieldNum > 0) { @@ -68,6 +72,7 @@ class MysqlResultImpl : public ResultImpl virtual const char *getValue(size_type row, row_size_type column) const override; virtual bool isNull(size_type row, row_size_type column) const override; virtual field_size_type getLength(size_type row, row_size_type column) const override; + virtual unsigned long long insertId() const noexcept override; private: const std::shared_ptr _result; @@ -76,6 +81,7 @@ class MysqlResultImpl : public ResultImpl const MYSQL_FIELD *_fieldArray; const Result::row_size_type _fieldNum; const size_type _affectedRows; + const unsigned long long _insertId; std::shared_ptr> _fieldMapPtr; std::shared_ptr>>> _rowsPtr; };