Modify some functions

This commit is contained in:
antao 2019-03-12 13:52:10 +08:00
parent d4710d3da7
commit aa539f85cd
6 changed files with 118 additions and 100 deletions

View File

@ -496,93 +496,18 @@ void HttpAppFrameworkImpl::onConnection(const TcpConnectionPtr &conn)
} }
} }
} }
std::string parseWebsockFrame(trantor::MsgBuffer *buffer)
{
if (buffer->readableBytes() >= 2)
{
auto secondByte = (*buffer)[1];
size_t length = secondByte & 127;
int isMasked = (secondByte & 0x80);
if (isMasked != 0)
{
LOG_TRACE << "data encoded!";
}
else
LOG_TRACE << "plain data";
size_t indexFirstMask = 2;
if (length == 126) void HttpAppFrameworkImpl::onWebsockMessage(const WebSocketConnectionPtr &wsConnPtr, std::string &&message)
{
indexFirstMask = 4;
}
else if (length == 127)
{
indexFirstMask = 10;
}
if (indexFirstMask > 2 && buffer->readableBytes() >= indexFirstMask)
{
if (indexFirstMask == 4)
{
length = (unsigned char)(*buffer)[2];
length = (length << 8) + (unsigned char)(*buffer)[3];
LOG_TRACE << "bytes[2]=" << (unsigned char)(*buffer)[2];
LOG_TRACE << "bytes[3]=" << (unsigned char)(*buffer)[3];
}
else if (indexFirstMask == 10)
{
length = (unsigned char)(*buffer)[2];
length = (length << 8) + (unsigned char)(*buffer)[3];
length = (length << 8) + (unsigned char)(*buffer)[4];
length = (length << 8) + (unsigned char)(*buffer)[5];
length = (length << 8) + (unsigned char)(*buffer)[6];
length = (length << 8) + (unsigned char)(*buffer)[7];
length = (length << 8) + (unsigned char)(*buffer)[8];
length = (length << 8) + (unsigned char)(*buffer)[9];
// length=*((uint64_t *)(buffer->peek()+2));
// length=ntohll(length);
}
else
{
assert(0);
}
}
LOG_TRACE << "websocket message len=" << length;
if (buffer->readableBytes() >= (indexFirstMask + 4 + length))
{
auto masks = buffer->peek() + indexFirstMask;
int indexFirstDataByte = indexFirstMask + 4;
auto rawData = buffer->peek() + indexFirstDataByte;
std::string message;
message.resize(length);
LOG_TRACE << "rawData[0]=" << (unsigned char)rawData[0];
LOG_TRACE << "masks[0]=" << (unsigned char)masks[0];
for (size_t i = 0; i < length; i++)
{
message[i] = (rawData[i] ^ masks[i % 4]);
}
buffer->retrieve(indexFirstMask + 4 + length);
LOG_TRACE << "got message len=" << message.length();
return message;
}
}
return std::string();
}
void HttpAppFrameworkImpl::onWebsockMessage(const WebSocketConnectionPtr &wsConnPtr,
trantor::MsgBuffer *buffer)
{ {
auto wsConnImplPtr = std::dynamic_pointer_cast<WebSocketConnectionImpl>(wsConnPtr); auto wsConnImplPtr = std::dynamic_pointer_cast<WebSocketConnectionImpl>(wsConnPtr);
assert(wsConnImplPtr); assert(wsConnImplPtr);
auto ctrl = wsConnImplPtr->controller(); auto ctrl = wsConnImplPtr->controller();
if (ctrl) if (ctrl)
{ {
std::string message; ctrl->handleNewMessage(wsConnPtr, std::move(message));
while (!(message = parseWebsockFrame(buffer)).empty())
{
LOG_TRACE << "Got websock message:" << message;
ctrl->handleNewMessage(wsConnPtr, std::move(message));
}
} }
} }
void HttpAppFrameworkImpl::setUploadPath(const std::string &uploadPath) void HttpAppFrameworkImpl::setUploadPath(const std::string &uploadPath)
{ {
assert(!uploadPath.empty()); assert(!uploadPath.empty());

View File

@ -151,7 +151,7 @@ class HttpAppFrameworkImpl : public HttpAppFramework
void onNewWebsockRequest(const HttpRequestImplPtr &req, void onNewWebsockRequest(const HttpRequestImplPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback, std::function<void(const HttpResponsePtr &)> &&callback,
const WebSocketConnectionPtr &wsConnPtr); const WebSocketConnectionPtr &wsConnPtr);
void onWebsockMessage(const WebSocketConnectionPtr &wsConnPtr, trantor::MsgBuffer *buffer); void onWebsockMessage(const WebSocketConnectionPtr &wsConnPtr, std::string &&message);
void onWebsockDisconnect(const WebSocketConnectionPtr &wsConnPtr); void onWebsockDisconnect(const WebSocketConnectionPtr &wsConnPtr);
void onConnection(const TcpConnectionPtr &conn); void onConnection(const TcpConnectionPtr &conn);
void addHttpPath(const std::string &path, void addHttpPath(const std::string &path,

View File

@ -240,7 +240,7 @@ class HttpRequestImpl : public HttpRequest
return _date; return _date;
} }
void setReceiveDate(const trantor::Date &date) void setCreationDate(const trantor::Date &date)
{ {
_date = date; _date = date;
} }

View File

@ -73,7 +73,7 @@ bool HttpRequestParser::processRequestLine(const char *begin, const char *end)
return succeed; return succeed;
} }
// return false if any error // Return false if any error
bool HttpRequestParser::parseRequest(MsgBuffer *buf) bool HttpRequestParser::parseRequest(MsgBuffer *buf)
{ {
bool ok = true; bool ok = true;

View File

@ -26,6 +26,92 @@ using namespace std::placeholders;
using namespace drogon; using namespace drogon;
using namespace trantor; using namespace trantor;
// Return false if any error
static bool parseWebsockMessage(MsgBuffer *buffer, std::string &message)
{
assert(message.empty());
if (buffer->readableBytes() >= 2)
{
auto secondByte = (*buffer)[1];
size_t length = secondByte & 127;
int isMasked = (secondByte & 0x80);
if (isMasked != 0)
{
LOG_TRACE << "data encoded!";
}
else
LOG_TRACE << "plain data";
size_t indexFirstMask = 2;
if (length == 126)
{
indexFirstMask = 4;
}
else if (length == 127)
{
indexFirstMask = 10;
}
if (indexFirstMask > 2 && buffer->readableBytes() >= indexFirstMask)
{
if (indexFirstMask == 4)
{
length = (unsigned char)(*buffer)[2];
length = (length << 8) + (unsigned char)(*buffer)[3];
LOG_TRACE << "bytes[2]=" << (unsigned char)(*buffer)[2];
LOG_TRACE << "bytes[3]=" << (unsigned char)(*buffer)[3];
}
else if (indexFirstMask == 10)
{
length = (unsigned char)(*buffer)[2];
length = (length << 8) + (unsigned char)(*buffer)[3];
length = (length << 8) + (unsigned char)(*buffer)[4];
length = (length << 8) + (unsigned char)(*buffer)[5];
length = (length << 8) + (unsigned char)(*buffer)[6];
length = (length << 8) + (unsigned char)(*buffer)[7];
length = (length << 8) + (unsigned char)(*buffer)[8];
length = (length << 8) + (unsigned char)(*buffer)[9];
// length=*((uint64_t *)(buffer->peek()+2));
// length=ntohll(length);
}
else
{
LOG_ERROR << "Websock parsing failed!";
return false;
}
}
LOG_TRACE << "websocket message len=" << length;
if (buffer->readableBytes() >= (indexFirstMask + 4 + length))
{
auto masks = buffer->peek() + indexFirstMask;
int indexFirstDataByte = indexFirstMask + 4;
auto rawData = buffer->peek() + indexFirstDataByte;
message.resize(length);
LOG_TRACE << "rawData[0]=" << (unsigned char)rawData[0];
LOG_TRACE << "masks[0]=" << (unsigned char)masks[0];
for (size_t i = 0; i < length; i++)
{
message[i] = (rawData[i] ^ masks[i % 4]);
}
buffer->retrieve(indexFirstMask + 4 + length);
LOG_TRACE << "got message len=" << message.length();
return true;
}
}
return true;
}
static bool isWebSocket(const HttpRequestImplPtr &req)
{
if (req->getHeaderBy("connection") == "Upgrade" &&
req->getHeaderBy("upgrade") == "websocket")
{
LOG_TRACE << "new websocket request";
return true;
}
return false;
}
static void defaultHttpAsyncCallback(const HttpRequestPtr &, std::function<void(const HttpResponsePtr &resp)> &&callback) static void defaultHttpAsyncCallback(const HttpRequestPtr &, std::function<void(const HttpResponsePtr &resp)> &&callback)
{ {
auto resp = HttpResponse::newNotFoundResponse(); auto resp = HttpResponse::newNotFoundResponse();
@ -105,13 +191,32 @@ void HttpServer::onMessage(const TcpConnectionPtr &conn,
HttpRequestParser *requestParser = any_cast<HttpRequestParser>(conn->getMutableContext()); HttpRequestParser *requestParser = any_cast<HttpRequestParser>(conn->getMutableContext());
int counter = 0; int counter = 0;
// With the pipelining feature or web socket, it is possible to receice multiple messages at once, so // With the pipelining feature or web socket, it is possible to receice multiple messages at once, so
// the while loop is necessary // the while loop is necessary
while (buf->readableBytes() > 0) while (buf->readableBytes() > 0)
{ {
if (requestParser->webSocketConn()) if (requestParser->webSocketConn())
{ {
//Websocket payload,we shouldn't parse it //Websocket payload
_webSocketMessageCallback(requestParser->webSocketConn(), buf); while (1)
{
std::string message;
auto success = parseWebsockMessage(buf, message);
if (success)
{
if (message.empty())
break;
else
{
_webSocketMessageCallback(requestParser->webSocketConn(), std::move(message));
}
}
else
{
//Websock error!
conn->shutdown();
return;
}
}
return; return;
} }
if (!requestParser->parseRequest(buf)) if (!requestParser->parseRequest(buf))
@ -125,8 +230,8 @@ void HttpServer::onMessage(const TcpConnectionPtr &conn,
{ {
requestParser->requestImpl()->setPeerAddr(conn->peerAddr()); requestParser->requestImpl()->setPeerAddr(conn->peerAddr());
requestParser->requestImpl()->setLocalAddr(conn->localAddr()); requestParser->requestImpl()->setLocalAddr(conn->localAddr());
requestParser->requestImpl()->setReceiveDate(trantor::Date::date()); requestParser->requestImpl()->setCreationDate(trantor::Date::date());
if (requestParser->firstReq() && isWebSocket(conn, requestParser->requestImpl())) if (requestParser->firstReq() && isWebSocket(requestParser->requestImpl()))
{ {
auto wsConn = std::make_shared<WebSocketConnectionImpl>(conn); auto wsConn = std::make_shared<WebSocketConnectionImpl>(conn);
_newWebsocketCallback(requestParser->requestImpl(), _newWebsocketCallback(requestParser->requestImpl(),
@ -154,18 +259,6 @@ void HttpServer::onMessage(const TcpConnectionPtr &conn,
} }
} }
bool HttpServer::isWebSocket(const TcpConnectionPtr &conn, const HttpRequestImplPtr &req)
{
if (req->getHeaderBy("connection") == "Upgrade" &&
req->getHeaderBy("upgrade") == "websocket")
{
LOG_TRACE << "new websocket request";
return true;
}
return false;
}
void HttpServer::onRequest(const TcpConnectionPtr &conn, const HttpRequestImplPtr &req) void HttpServer::onRequest(const TcpConnectionPtr &conn, const HttpRequestImplPtr &req)
{ {
const std::string &connection = req->getHeaderBy("connection"); const std::string &connection = req->getHeaderBy("connection");
@ -327,3 +420,4 @@ void HttpServer::sendResponse(const TcpConnectionPtr &conn,
conn->shutdown(); conn->shutdown();
} }
} }

View File

@ -39,7 +39,7 @@ class HttpServer : trantor::NonCopyable
WebSocketNewAsyncCallback; WebSocketNewAsyncCallback;
typedef std::function<void(const WebSocketConnectionPtr &)> typedef std::function<void(const WebSocketConnectionPtr &)>
WebSocketDisconnetCallback; WebSocketDisconnetCallback;
typedef std::function<void(const WebSocketConnectionPtr &, trantor::MsgBuffer *)> typedef std::function<void(const WebSocketConnectionPtr &, std::string &&message)>
WebSocketMessageCallback; WebSocketMessageCallback;
HttpServer(EventLoop *loop, HttpServer(EventLoop *loop,
@ -100,7 +100,6 @@ class HttpServer : trantor::NonCopyable
void onMessage(const TcpConnectionPtr &, void onMessage(const TcpConnectionPtr &,
MsgBuffer *); MsgBuffer *);
void onRequest(const TcpConnectionPtr &, const HttpRequestImplPtr &); void onRequest(const TcpConnectionPtr &, const HttpRequestImplPtr &);
bool isWebSocket(const TcpConnectionPtr &conn, const HttpRequestImplPtr &req);
void sendResponse(const TcpConnectionPtr &, const HttpResponsePtr &, bool isHeadMethod); void sendResponse(const TcpConnectionPtr &, const HttpResponsePtr &, bool isHeadMethod);
trantor::TcpServer _server; trantor::TcpServer _server;
HttpAsyncCallback _httpAsyncCallback; HttpAsyncCallback _httpAsyncCallback;