From aa539f85cd2a27702a2405c16dc7bad0a149f14a Mon Sep 17 00:00:00 2001 From: antao Date: Tue, 12 Mar 2019 13:52:10 +0800 Subject: [PATCH] Modify some functions --- lib/src/HttpAppFrameworkImpl.cc | 81 +------------------- lib/src/HttpAppFrameworkImpl.h | 2 +- lib/src/HttpRequestImpl.h | 2 +- lib/src/HttpRequestParser.cc | 2 +- lib/src/HttpServer.cc | 128 +++++++++++++++++++++++++++----- lib/src/HttpServer.h | 3 +- 6 files changed, 118 insertions(+), 100 deletions(-) diff --git a/lib/src/HttpAppFrameworkImpl.cc b/lib/src/HttpAppFrameworkImpl.cc index 97a8e308..2efb91c7 100755 --- a/lib/src/HttpAppFrameworkImpl.cc +++ b/lib/src/HttpAppFrameworkImpl.cc @@ -496,93 +496,18 @@ void HttpAppFrameworkImpl::onConnection(const TcpConnectionPtr &conn) } } } -std::string parseWebsockFrame(trantor::MsgBuffer *buffer) -{ - if (buffer->readableBytes() >= 2) - { - auto secondByte = (*buffer)[1]; - size_t length = secondByte & 127; - int isMasked = (secondByte & 0x80); - if (isMasked != 0) - { - LOG_TRACE << "data encoded!"; - } - else - LOG_TRACE << "plain data"; - size_t indexFirstMask = 2; - if (length == 126) - { - indexFirstMask = 4; - } - else if (length == 127) - { - indexFirstMask = 10; - } - if (indexFirstMask > 2 && buffer->readableBytes() >= indexFirstMask) - { - if (indexFirstMask == 4) - { - length = (unsigned char)(*buffer)[2]; - length = (length << 8) + (unsigned char)(*buffer)[3]; - LOG_TRACE << "bytes[2]=" << (unsigned char)(*buffer)[2]; - LOG_TRACE << "bytes[3]=" << (unsigned char)(*buffer)[3]; - } - else if (indexFirstMask == 10) - { - length = (unsigned char)(*buffer)[2]; - length = (length << 8) + (unsigned char)(*buffer)[3]; - length = (length << 8) + (unsigned char)(*buffer)[4]; - length = (length << 8) + (unsigned char)(*buffer)[5]; - length = (length << 8) + (unsigned char)(*buffer)[6]; - length = (length << 8) + (unsigned char)(*buffer)[7]; - length = (length << 8) + (unsigned char)(*buffer)[8]; - length = (length << 8) + (unsigned char)(*buffer)[9]; - // length=*((uint64_t *)(buffer->peek()+2)); - // length=ntohll(length); - } - else - { - assert(0); - } - } - LOG_TRACE << "websocket message len=" << length; - if (buffer->readableBytes() >= (indexFirstMask + 4 + length)) - { - auto masks = buffer->peek() + indexFirstMask; - int indexFirstDataByte = indexFirstMask + 4; - auto rawData = buffer->peek() + indexFirstDataByte; - std::string message; - message.resize(length); - LOG_TRACE << "rawData[0]=" << (unsigned char)rawData[0]; - LOG_TRACE << "masks[0]=" << (unsigned char)masks[0]; - for (size_t i = 0; i < length; i++) - { - message[i] = (rawData[i] ^ masks[i % 4]); - } - buffer->retrieve(indexFirstMask + 4 + length); - LOG_TRACE << "got message len=" << message.length(); - return message; - } - } - return std::string(); -} -void HttpAppFrameworkImpl::onWebsockMessage(const WebSocketConnectionPtr &wsConnPtr, - trantor::MsgBuffer *buffer) +void HttpAppFrameworkImpl::onWebsockMessage(const WebSocketConnectionPtr &wsConnPtr, std::string &&message) { auto wsConnImplPtr = std::dynamic_pointer_cast(wsConnPtr); assert(wsConnImplPtr); auto ctrl = wsConnImplPtr->controller(); if (ctrl) { - std::string message; - while (!(message = parseWebsockFrame(buffer)).empty()) - { - LOG_TRACE << "Got websock message:" << message; - ctrl->handleNewMessage(wsConnPtr, std::move(message)); - } + ctrl->handleNewMessage(wsConnPtr, std::move(message)); } } + void HttpAppFrameworkImpl::setUploadPath(const std::string &uploadPath) { assert(!uploadPath.empty()); diff --git a/lib/src/HttpAppFrameworkImpl.h b/lib/src/HttpAppFrameworkImpl.h index 5101c880..7e7d163f 100644 --- a/lib/src/HttpAppFrameworkImpl.h +++ b/lib/src/HttpAppFrameworkImpl.h @@ -151,7 +151,7 @@ class HttpAppFrameworkImpl : public HttpAppFramework void onNewWebsockRequest(const HttpRequestImplPtr &req, std::function &&callback, const WebSocketConnectionPtr &wsConnPtr); - void onWebsockMessage(const WebSocketConnectionPtr &wsConnPtr, trantor::MsgBuffer *buffer); + void onWebsockMessage(const WebSocketConnectionPtr &wsConnPtr, std::string &&message); void onWebsockDisconnect(const WebSocketConnectionPtr &wsConnPtr); void onConnection(const TcpConnectionPtr &conn); void addHttpPath(const std::string &path, diff --git a/lib/src/HttpRequestImpl.h b/lib/src/HttpRequestImpl.h index 49ccdc92..87a55831 100755 --- a/lib/src/HttpRequestImpl.h +++ b/lib/src/HttpRequestImpl.h @@ -240,7 +240,7 @@ class HttpRequestImpl : public HttpRequest return _date; } - void setReceiveDate(const trantor::Date &date) + void setCreationDate(const trantor::Date &date) { _date = date; } diff --git a/lib/src/HttpRequestParser.cc b/lib/src/HttpRequestParser.cc index 55d80e0e..8948dde3 100755 --- a/lib/src/HttpRequestParser.cc +++ b/lib/src/HttpRequestParser.cc @@ -73,7 +73,7 @@ bool HttpRequestParser::processRequestLine(const char *begin, const char *end) return succeed; } -// return false if any error +// Return false if any error bool HttpRequestParser::parseRequest(MsgBuffer *buf) { bool ok = true; diff --git a/lib/src/HttpServer.cc b/lib/src/HttpServer.cc index fdfb0e78..8733de10 100755 --- a/lib/src/HttpServer.cc +++ b/lib/src/HttpServer.cc @@ -26,6 +26,92 @@ using namespace std::placeholders; using namespace drogon; using namespace trantor; +// Return false if any error +static bool parseWebsockMessage(MsgBuffer *buffer, std::string &message) +{ + assert(message.empty()); + if (buffer->readableBytes() >= 2) + { + auto secondByte = (*buffer)[1]; + size_t length = secondByte & 127; + int isMasked = (secondByte & 0x80); + if (isMasked != 0) + { + LOG_TRACE << "data encoded!"; + } + else + LOG_TRACE << "plain data"; + size_t indexFirstMask = 2; + + if (length == 126) + { + indexFirstMask = 4; + } + else if (length == 127) + { + indexFirstMask = 10; + } + if (indexFirstMask > 2 && buffer->readableBytes() >= indexFirstMask) + { + if (indexFirstMask == 4) + { + length = (unsigned char)(*buffer)[2]; + length = (length << 8) + (unsigned char)(*buffer)[3]; + LOG_TRACE << "bytes[2]=" << (unsigned char)(*buffer)[2]; + LOG_TRACE << "bytes[3]=" << (unsigned char)(*buffer)[3]; + } + else if (indexFirstMask == 10) + { + length = (unsigned char)(*buffer)[2]; + length = (length << 8) + (unsigned char)(*buffer)[3]; + length = (length << 8) + (unsigned char)(*buffer)[4]; + length = (length << 8) + (unsigned char)(*buffer)[5]; + length = (length << 8) + (unsigned char)(*buffer)[6]; + length = (length << 8) + (unsigned char)(*buffer)[7]; + length = (length << 8) + (unsigned char)(*buffer)[8]; + length = (length << 8) + (unsigned char)(*buffer)[9]; + // length=*((uint64_t *)(buffer->peek()+2)); + // length=ntohll(length); + } + else + { + LOG_ERROR << "Websock parsing failed!"; + return false; + } + } + LOG_TRACE << "websocket message len=" << length; + if (buffer->readableBytes() >= (indexFirstMask + 4 + length)) + { + auto masks = buffer->peek() + indexFirstMask; + int indexFirstDataByte = indexFirstMask + 4; + auto rawData = buffer->peek() + indexFirstDataByte; + message.resize(length); + LOG_TRACE << "rawData[0]=" << (unsigned char)rawData[0]; + LOG_TRACE << "masks[0]=" << (unsigned char)masks[0]; + for (size_t i = 0; i < length; i++) + { + message[i] = (rawData[i] ^ masks[i % 4]); + } + buffer->retrieve(indexFirstMask + 4 + length); + LOG_TRACE << "got message len=" << message.length(); + return true; + } + } + return true; +} + +static bool isWebSocket(const HttpRequestImplPtr &req) +{ + if (req->getHeaderBy("connection") == "Upgrade" && + req->getHeaderBy("upgrade") == "websocket") + { + LOG_TRACE << "new websocket request"; + + return true; + } + return false; +} + static void defaultHttpAsyncCallback(const HttpRequestPtr &, std::function &&callback) { auto resp = HttpResponse::newNotFoundResponse(); @@ -105,13 +191,32 @@ void HttpServer::onMessage(const TcpConnectionPtr &conn, HttpRequestParser *requestParser = any_cast(conn->getMutableContext()); int counter = 0; // With the pipelining feature or web socket, it is possible to receice multiple messages at once, so - // the while loop is necessary + // the while loop is necessary while (buf->readableBytes() > 0) { if (requestParser->webSocketConn()) { - //Websocket payload,we shouldn't parse it - _webSocketMessageCallback(requestParser->webSocketConn(), buf); + //Websocket payload + while (1) + { + std::string message; + auto success = parseWebsockMessage(buf, message); + if (success) + { + if (message.empty()) + break; + else + { + _webSocketMessageCallback(requestParser->webSocketConn(), std::move(message)); + } + } + else + { + //Websock error! + conn->shutdown(); + return; + } + } return; } if (!requestParser->parseRequest(buf)) @@ -125,8 +230,8 @@ void HttpServer::onMessage(const TcpConnectionPtr &conn, { requestParser->requestImpl()->setPeerAddr(conn->peerAddr()); requestParser->requestImpl()->setLocalAddr(conn->localAddr()); - requestParser->requestImpl()->setReceiveDate(trantor::Date::date()); - if (requestParser->firstReq() && isWebSocket(conn, requestParser->requestImpl())) + requestParser->requestImpl()->setCreationDate(trantor::Date::date()); + if (requestParser->firstReq() && isWebSocket(requestParser->requestImpl())) { auto wsConn = std::make_shared(conn); _newWebsocketCallback(requestParser->requestImpl(), @@ -154,18 +259,6 @@ void HttpServer::onMessage(const TcpConnectionPtr &conn, } } -bool HttpServer::isWebSocket(const TcpConnectionPtr &conn, const HttpRequestImplPtr &req) -{ - if (req->getHeaderBy("connection") == "Upgrade" && - req->getHeaderBy("upgrade") == "websocket") - { - LOG_TRACE << "new websocket request"; - - return true; - } - return false; -} - void HttpServer::onRequest(const TcpConnectionPtr &conn, const HttpRequestImplPtr &req) { const std::string &connection = req->getHeaderBy("connection"); @@ -327,3 +420,4 @@ void HttpServer::sendResponse(const TcpConnectionPtr &conn, conn->shutdown(); } } + diff --git a/lib/src/HttpServer.h b/lib/src/HttpServer.h index 73f3279d..eca573c2 100755 --- a/lib/src/HttpServer.h +++ b/lib/src/HttpServer.h @@ -39,7 +39,7 @@ class HttpServer : trantor::NonCopyable WebSocketNewAsyncCallback; typedef std::function WebSocketDisconnetCallback; - typedef std::function + typedef std::function WebSocketMessageCallback; HttpServer(EventLoop *loop, @@ -100,7 +100,6 @@ class HttpServer : trantor::NonCopyable void onMessage(const TcpConnectionPtr &, MsgBuffer *); void onRequest(const TcpConnectionPtr &, const HttpRequestImplPtr &); - bool isWebSocket(const TcpConnectionPtr &conn, const HttpRequestImplPtr &req); void sendResponse(const TcpConnectionPtr &, const HttpResponsePtr &, bool isHeadMethod); trantor::TcpServer _server; HttpAsyncCallback _httpAsyncCallback;