Mysql works with 'stmt'

This commit is contained in:
antao 2018-11-29 19:07:36 +08:00
parent cfe55d5495
commit 4bfe8250e9
7 changed files with 317 additions and 58 deletions

View File

@ -19,6 +19,9 @@
#include <drogon/orm/FunctionTraits.h>
#include <drogon/orm/Exception.h>
#include <trantor/utils/Logger.h>
#ifdef USE_MYSQL
#include <mysql.h>
#endif
#include <string.h>
#include <string>
#include <iostream>
@ -32,6 +35,11 @@ namespace drogon
{
namespace orm
{
enum class ClientType
{
PostgreSQL = 0,
Mysql
};
class DbClient;
typedef std::function<void(const Result &)> QueryCallback;
@ -201,7 +209,7 @@ class SqlBinder
public:
friend class Dbclient;
SqlBinder(const std::string &sql, DbClient &client) : _sql(sql), _client(client)
SqlBinder(const std::string &sql, DbClient &client, ClientType type) : _sql(sql), _client(client), _type(type)
{
}
~SqlBinder();
@ -240,26 +248,53 @@ class SqlBinder
typedef typename std::remove_cv<typename std::remove_reference<T>::type>::type ParaType;
std::shared_ptr<void> obj =
std::make_shared<ParaType>(parameter);
switch (sizeof(T))
if (_type == ClientType::PostgreSQL)
{
case 2:
*std::static_pointer_cast<short>(obj) = ntohs(parameter);
break;
case 4:
*std::static_pointer_cast<int>(obj) = ntohl(parameter);
break;
case 8:
*std::static_pointer_cast<long>(obj) = ntohll(parameter);
break;
case 1:
default:
switch (sizeof(T))
{
case 2:
*std::static_pointer_cast<short>(obj) = ntohs(parameter);
break;
case 4:
*std::static_pointer_cast<int>(obj) = ntohl(parameter);
break;
case 8:
*std::static_pointer_cast<long>(obj) = ntohll(parameter);
break;
case 1:
default:
break;
break;
}
_objs.push_back(obj);
_parameters.push_back((char *)obj.get());
_length.push_back(sizeof(T));
_format.push_back(1);
}
else if (_type == ClientType::Mysql)
{
#ifdef USE_MYSQL
_objs.push_back(obj);
_parameters.push_back((char *)obj.get());
_length.push_back(0);
switch (sizeof(T))
{
case 1:
_format.push_back(MYSQL_TYPE_TINY);
break;
case 2:
_format.push_back(MYSQL_TYPE_SHORT);
break;
case 4:
_format.push_back(MYSQL_TYPE_LONG);
break;
case 8:
_format.push_back(MYSQL_TYPE_LONGLONG);
default:
break;
}
#endif
}
_objs.push_back(obj);
_parameters.push_back((char *)obj.get());
_length.push_back(sizeof(T));
_format.push_back(1);
//LOG_TRACE << "Bind parameter:" << parameter;
return *this;
}
@ -279,7 +314,16 @@ class SqlBinder
_paraNum++;
_parameters.push_back((char *)obj->c_str());
_length.push_back(obj->length());
_format.push_back(0);
if (_type == ClientType::PostgreSQL)
{
_format.push_back(0);
}
else if (_type == ClientType::Mysql)
{
#ifdef USE_MYSQL
_format.push_back(MYSQL_TYPE_VAR_STRING);
#endif
}
return *this;
}
self &operator<<(std::string &str)
@ -293,7 +337,16 @@ class SqlBinder
_paraNum++;
_parameters.push_back((char *)obj->c_str());
_length.push_back(obj->length());
_format.push_back(0);
if (_type == ClientType::PostgreSQL)
{
_format.push_back(0);
}
else if (_type == ClientType::Mysql)
{
#ifdef USE_MYSQL
_format.push_back(MYSQL_TYPE_VAR_STRING);
#endif
}
return *this;
}
self &operator<<(trantor::Date &&date)
@ -350,7 +403,7 @@ class SqlBinder
_mode = mode;
return *this;
}
void exec() noexcept(false);
private:
@ -368,6 +421,7 @@ class SqlBinder
bool _execed = false;
bool _destructed = false;
bool _isExceptPtr = false;
ClientType _type;
};
} // namespace internal

View File

@ -20,7 +20,7 @@ using namespace drogon;
internal::SqlBinder DbClient::operator<<(const std::string &sql)
{
return internal::SqlBinder(sql, *this);
return internal::SqlBinder(sql, *this, ((DbClientImpl *)this)->type());
}
#if USE_POSTGRESQL

View File

@ -14,13 +14,7 @@ namespace drogon
{
namespace orm
{
// extern Result makeResult(SqlStatus status, const std::shared_ptr<PGresult> &r = std::shared_ptr<PGresult>(nullptr),
// const std::string &query = "");
enum class ClientType
{
PostgreSQL = 0,
Mysql
};
class DbClientImpl : public DbClient, public std::enable_shared_from_this<DbClientImpl>
{
public:
@ -35,6 +29,7 @@ class DbClientImpl : public DbClient, public std::enable_shared_from_this<DbClie
const std::function<void(const std::exception_ptr &)> &exceptCallback) override;
virtual std::string replaceSqlPlaceHolder(const std::string &sqlStr, const std::string &holderStr) const override;
virtual std::shared_ptr<Transaction> newTransaction() override;
ClientType type() { return _type; }
private:
void ioLoop();

View File

@ -13,6 +13,7 @@
#include "MysqlConnection.h"
#include "MysqlResultImpl.h"
#include "MysqlStmtResultImpl.h"
#include <drogon/utils/Utilities.h>
#include <regex>
#include <algorithm>
@ -31,6 +32,13 @@ Result makeResult(const std::shared_ptr<MYSQL_RES> &r = std::shared_ptr<MYSQL_RE
return Result(std::shared_ptr<MysqlResultImpl>(new MysqlResultImpl(r, query, affectedRows)));
}
Result makeResult(const std::shared_ptr<MYSQL_STMT> &r = std::shared_ptr<MYSQL_STMT>(nullptr),
const std::string &query = "")
{
return Result(std::shared_ptr<MysqlStmtResultImpl>(new MysqlStmtResultImpl(r, query)));
}
} // namespace orm
} // namespace drogon
@ -227,6 +235,7 @@ void MysqlConnection::handleEvent()
return;
}
_waitStatus = mysql_stmt_execute_start(&err, _stmtPtr.get());
LOG_TRACE << "mysql_stmt_execute_start:" << _waitStatus;
_execStatus = ExecStatus_StmtExec;
if (_waitStatus == 0)
{
@ -236,6 +245,7 @@ void MysqlConnection::handleEvent()
outputStmtError();
return;
}
_waitStatus = mysql_stmt_store_result_start(&err, _stmtPtr.get());
_execStatus = ExecStatus_StmtStoreResult;
if (_waitStatus == 0)
@ -310,6 +320,7 @@ void MysqlConnection::handleEvent()
{
int err;
_waitStatus = mysql_stmt_execute_cont(&err, _stmtPtr.get(), status);
LOG_TRACE << "mysql_stmt_execute_cont:" << _waitStatus;
if (_waitStatus == 0)
{
_execStatus = ExecStatus_None;
@ -319,6 +330,7 @@ void MysqlConnection::handleEvent()
return;
}
_waitStatus = mysql_stmt_store_result_start(&err, _stmtPtr.get());
LOG_TRACE << "mysql_stmt_store_result_start:" << _waitStatus;
_execStatus = ExecStatus_StmtStoreResult;
if (_waitStatus == 0)
{
@ -338,6 +350,7 @@ void MysqlConnection::handleEvent()
{
int err;
_waitStatus = mysql_stmt_store_result_cont(&err, _stmtPtr.get(), status);
LOG_TRACE << "mysql_stmt_store_result_cont:" << _waitStatus;
if (_waitStatus == 0)
{
_execStatus = ExecStatus_None;
@ -425,17 +438,21 @@ void MysqlConnection::execSql(const std::string &sql,
outputError();
return;
}
my_bool flag = 1;
mysql_stmt_attr_set(_stmtPtr.get(), STMT_ATTR_UPDATE_MAX_LENGTH, &flag);
_binds.resize(paraNum);
_lengths.resize(paraNum);
_isNulls.resize(paraNum);
memset(_binds.data(), 0, sizeof(MYSQL_BIND) * paraNum);
for (size_t i = 0; i < paraNum; i++)
{
_binds[i].buffer = (void *)parameters[i];
_binds[i].buffer_type = (enum_field_types)format[i];
_binds[i].buffer_length = length[i];
_binds[i].length = (length[i] == 0 ? 0 : (_lengths[i] = length[i], &_lengths[i]));
_binds[i].is_null = (parameters[i] == NULL ? 0 : (_isNulls[i] = true, &_isNulls[i]));
_binds[i].is_null = (parameters[i] != NULL ? 0 : (_isNulls[i] = true, &_isNulls[i]));
}
_loop->runInLoop([=]() {
int err;
@ -452,24 +469,6 @@ void MysqlConnection::execSql(const std::string &sql,
_execStatus = ExecStatus_StmtPrepare;
setChannel();
});
// /* Get the parameter count from the statement */
// auto param_count = mysql_stmt_param_count(stmt);
// if (param_count != paraNum)
// {
// //FIXME,exception callback
// return;
// }
// MYSQL_BIND *bind = new MYSQL_BIND[param_count];
// for (int i = 0; i < param_count; i++)
// {
// bind[i].buffer_type = (enum_field_types)format[i];
// bind[i].buffer = (char *)parameters[i];
// bind[i].buffer_length = length[i];
// bind[i].is_null = parameters[i] == NULL ? 1 : 0;
// bind[i].length = &length[i];
// }
//delete[] bind;
}
void MysqlConnection::outputError()
@ -552,10 +551,8 @@ void MysqlConnection::getResult(MYSQL_RES *res)
void MysqlConnection::getStmtResult()
{
auto resultPtr = std::shared_ptr<MYSQL_RES>(_stmtPtr->default_rset_handler(_stmtPtr.get()), [](MYSQL_RES *r) {
mysql_free_result(r);
});
auto Result = makeResult(resultPtr, _sql, mysql_affected_rows(_mysqlPtr.get()));
LOG_TRACE << "Got " << mysql_stmt_num_rows(_stmtPtr.get()) << " rows";
auto Result = makeResult(_stmtPtr, _sql);
if (_isWorking)
{
_cb(Result);

View File

@ -0,0 +1,66 @@
/**
*
* MysqlStmtResultImpl.cc
* An Tao
*
* Copyright 2018, An Tao. All rights reserved.
* Use of this source code is governed by a MIT license
* that can be found in the License file.
*
*
*/
#include "MysqlStmtResultImpl.h"
#include <assert.h>
#include <algorithm>
using namespace drogon::orm;
Result::size_type MysqlStmtResultImpl::size() const noexcept
{
return _rowsNum;
}
Result::row_size_type MysqlStmtResultImpl::columns() const noexcept
{
return _fieldNum;
}
const char *MysqlStmtResultImpl::columnName(row_size_type number) const
{
assert(number < _fieldNum);
if (_fieldArray)
return _fieldArray[number].name;
return "";
}
Result::size_type MysqlStmtResultImpl::affectedRows() const noexcept
{
return _affectedRows;
}
Result::row_size_type MysqlStmtResultImpl::columnNumber(const char colName[]) const
{
std::string col(colName);
std::transform(col.begin(), col.end(), col.begin(), tolower);
auto iter = _fieldMap.find(col);
if (iter != _fieldMap.end())
return iter->second;
return -1;
}
const char *MysqlStmtResultImpl::getValue(size_type row, row_size_type column) const
{
if (_rowsNum == 0 || _fieldNum == 0)
return NULL;
assert(row < _rowsNum);
assert(column < _fieldNum);
return _rows[row].first[column];
}
bool MysqlStmtResultImpl::isNull(size_type row, row_size_type column) const
{
return getValue(row, column) == NULL;
}
Result::field_size_type MysqlStmtResultImpl::getLength(size_type row, row_size_type column) const
{
if (_rowsNum == 0 || _fieldNum == 0)
return 0;
assert(row < _rowsNum);
assert(column < _fieldNum);
return _rows[row].second[column];
}

View File

@ -0,0 +1,133 @@
/**
*
* MysqlStmtResultImpl.h
* An Tao
*
* Copyright 2018, An Tao. All rights reserved.
* Use of this source code is governed by a MIT license
* that can be found in the License file.
*
*
*/
#pragma once
#include "../ResultImpl.h"
#include <mysql.h>
#include <trantor/utils/Logger.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <algorithm>
#include <vector>
#include <string.h>
namespace drogon
{
namespace orm
{
class MysqlStmtResultImpl : public ResultImpl
{
public:
MysqlStmtResultImpl(const std::shared_ptr<MYSQL_STMT> &r, const std::string &query) noexcept
: _result(r),
_metaData(mysql_stmt_result_metadata(r.get()), [](MYSQL_RES *p) {
if (p)
mysql_free_result(p); }),
_query(query),
_rowsNum(r ? mysql_stmt_num_rows(r.get()) : 0),
_fieldArray(_metaData ? mysql_fetch_fields(_metaData.get()) : nullptr),
_fieldNum(_metaData ? mysql_num_fields(_metaData.get()) : 0),
_affectedRows(r ? mysql_stmt_affected_rows(r.get()) : 0)
{
MYSQL_BIND binds[_fieldNum];
memset(binds, 0, sizeof(MYSQL_BIND) * _fieldNum);
unsigned long lengths[_fieldNum];
my_bool isNulls[_fieldNum];
std::shared_ptr<char> buffers[_fieldNum];
char fakeBuf;
if (_fieldNum > 0)
{
for (row_size_type i = 0; i < _fieldNum; i++)
{
std::string fieldName = _fieldArray[i].name;
std::transform(fieldName.begin(), fieldName.end(), fieldName.begin(), tolower);
_fieldMap[fieldName] = i;
LOG_TRACE << "row[" << fieldName << "].max_length=" << _fieldArray[i].max_length;
if (_rowsNum > 0)
{
if (_fieldArray[i].max_length > 0)
{
buffers[i] = std::shared_ptr<char>(new char[_fieldArray[i].max_length + 1], [](char *p) { delete[] p; });
binds[i].buffer = buffers[i].get();
binds[i].buffer_length = _fieldArray[i].max_length + 1;
}
else
{
binds[i].buffer = &fakeBuf;
binds[i].buffer_length = 1;
}
binds[i].length = &lengths[i];
binds[i].is_null = &isNulls[i];
binds[i].buffer_type = _fieldArray[i].type;
}
}
}
if (size() > 0)
{
if (mysql_stmt_bind_result(r.get(), binds))
{
fprintf(stderr, " mysql_stmt_bind_result() failed\n");
fprintf(stderr, " %s\n", mysql_stmt_error(r.get()));
exit(-1);
}
while (!mysql_stmt_fetch(r.get()))
{
std::vector<char *> row;
std::vector<field_size_type> lengths;
for (row_size_type i = 0; i < _fieldNum; i++)
{
if (*(binds[i].is_null))
{
row.push_back(NULL);
lengths.push_back(0);
}
else
{
auto data = std::shared_ptr<char>(new char[*(binds[i].length)], [](char *p) { delete[] p; });
memcpy(data.get(), binds[i].buffer, *binds[i].length);
_resultData.push_back(data);
row.push_back(data.get());
lengths.push_back(*(binds[i].length));
}
}
_rows.push_back(std::make_pair(row, lengths));
}
}
}
virtual size_type size() const noexcept override;
virtual row_size_type columns() const noexcept override;
virtual const char *columnName(row_size_type Number) const override;
virtual size_type affectedRows() const noexcept override;
virtual row_size_type columnNumber(const char colName[]) const override;
virtual const char *getValue(size_type row, row_size_type column) const override;
virtual bool isNull(size_type row, row_size_type column) const override;
virtual field_size_type getLength(size_type row, row_size_type column) const override;
private:
const std::shared_ptr<MYSQL_STMT> _result;
const std::shared_ptr<MYSQL_RES> _metaData;
const std::string _query;
const Result::size_type _rowsNum;
const MYSQL_FIELD *_fieldArray;
const Result::row_size_type _fieldNum;
const size_type _affectedRows;
std::unordered_map<std::string, row_size_type> _fieldMap;
std::vector<std::pair<std::vector<char *>, std::vector<field_size_type>>> _rows;
std::vector<std::shared_ptr<char>> _resultData;
};
} // namespace orm
} // namespace drogon

View File

@ -11,7 +11,7 @@ int main()
trantor::Logger::setLogLevel(trantor::Logger::TRACE);
auto clientPtr = DbClient::newMysqlClient("host= 127.0.0.1 port =3306 dbname= test user = root ", 1);
sleep(1);
for (int i = 0; i < 10;i++)
for (int i = 0; i < 10; i++)
{
std::string str = formattedString("insert into users (user_id,user_name,org_name) values('%d','antao','default')", i);
*clientPtr << str >> [](const Result &r) {
@ -20,14 +20,14 @@ int main()
std::cerr << e.base().what() << std::endl;
};
}
*clientPtr << "select * from users" >> [](const Result &r) {
std::cout << "rows:" << r.size() << std::endl;
std::cout << "column num:" << r.columns() << std::endl;
for(auto row:r)
{
std::cout << "user_id=" << row["user_id"].as<std::string>() << std::endl;
}
// for (auto row : r)
// {
// std::cout << "user_id=" << row["user_id"].as<std::string>() << std::endl;
// }
// for (auto row : r)
// {
// for (auto f : row)
@ -38,5 +38,19 @@ int main()
} >> [](const DrogonDbException &e) {
std::cerr << e.base().what() << std::endl;
};
*clientPtr << "select * from users where id!=? order by id"
<< 139 >>
[](const Result &r) {
std::cout << "rows:" << r.size() << std::endl;
std::cout << "column num:" << r.columns() << std::endl;
// for (auto row : r)
// {
// std::cout << "user_id=" << row["user_id"].as<std::string>() << " id=" << row["id"].as<int>() << std::endl;
// }
} >>
[](const DrogonDbException &e) {
std::cerr << e.base().what() << std::endl;
};
getchar();
}