From e87c9a75f65a813b3e681dd2efdfebbe4c380376 Mon Sep 17 00:00:00 2001 From: antao Date: Mon, 8 Apr 2019 16:37:24 +0800 Subject: [PATCH] Modify the implementation of WebSocket --- examples/simple_example_test/WebSocketTest.cc | 2 +- lib/inc/drogon/WebSocketClient.h | 2 +- lib/src/HttpAppFrameworkImpl.cc | 28 +- lib/src/HttpAppFrameworkImpl.h | 6 +- lib/src/HttpRequestParser.h | 2 +- lib/src/HttpServer.cc | 33 +- lib/src/HttpServer.h | 19 +- lib/src/HttpUtils.cc | 106 ------- lib/src/HttpUtils.h | 1 - lib/src/WebSockectConnectionImpl.cc | 178 ----------- lib/src/WebSockectConnectionImpl.h | 63 ---- lib/src/WebSocketClientImpl.cc | 39 +-- lib/src/WebSocketClientImpl.h | 11 +- lib/src/WebSocketConnectionImpl.cc | 292 ++++++++++++++++++ lib/src/WebSocketConnectionImpl.h | 145 +++++++++ lib/src/WebsocketControllersRouter.cc | 15 +- lib/src/WebsocketControllersRouter.h | 5 +- test.sh | 5 +- 18 files changed, 483 insertions(+), 469 deletions(-) delete mode 100644 lib/src/WebSockectConnectionImpl.cc delete mode 100644 lib/src/WebSockectConnectionImpl.h create mode 100644 lib/src/WebSocketConnectionImpl.cc create mode 100644 lib/src/WebSocketConnectionImpl.h diff --git a/examples/simple_example_test/WebSocketTest.cc b/examples/simple_example_test/WebSocketTest.cc index f5a2d74d..3de3a722 100644 --- a/examples/simple_example_test/WebSocketTest.cc +++ b/examples/simple_example_test/WebSocketTest.cc @@ -35,7 +35,7 @@ int main(int argc, char *argv[]) if (r == ReqResult::Ok) { std::cout << "ws connected!" << std::endl; - wsPtr->getConnection()->send("hello"); + wsPtr->getConnection()->send("hello!"); } else { diff --git a/lib/inc/drogon/WebSocketClient.h b/lib/inc/drogon/WebSocketClient.h index d2d193c1..2972216b 100644 --- a/lib/inc/drogon/WebSocketClient.h +++ b/lib/inc/drogon/WebSocketClient.h @@ -34,7 +34,7 @@ class WebSocketClient { public: /// Get the WebSocket connection that is typically used to send messages. - virtual const WebSocketConnectionPtr &getConnection() = 0; + virtual WebSocketConnectionPtr getConnection() = 0; /// Set messages handler. When a message is recieved from the server, the @param callback is called. virtual void setMessageHandler(const std::function &callback) = 0; diff --git a/lib/src/HttpAppFrameworkImpl.cc b/lib/src/HttpAppFrameworkImpl.cc index 0764ac44..4e1ad782 100755 --- a/lib/src/HttpAppFrameworkImpl.cc +++ b/lib/src/HttpAppFrameworkImpl.cc @@ -317,8 +317,6 @@ void HttpAppFrameworkImpl::run() } serverPtr->setHttpAsyncCallback(std::bind(&HttpAppFrameworkImpl::onAsyncRequest, this, _1, _2)); serverPtr->setNewWebsocketCallback(std::bind(&HttpAppFrameworkImpl::onNewWebsockRequest, this, _1, _2, _3)); - serverPtr->setWebsocketMessageCallback(std::bind(&HttpAppFrameworkImpl::onWebsockMessage, this, _1, _2, _3)); - serverPtr->setDisconnectWebsocketCallback(std::bind(&HttpAppFrameworkImpl::onWebsockDisconnect, this, _1)); serverPtr->setConnectionCallback(std::bind(&HttpAppFrameworkImpl::onConnection, this, _1)); serverPtr->kickoffIdleConnections(_idleConnectionTimeout); serverPtr->start(); @@ -356,8 +354,6 @@ void HttpAppFrameworkImpl::run() serverPtr->setIoLoopNum(_threadNum); serverPtr->setHttpAsyncCallback(std::bind(&HttpAppFrameworkImpl::onAsyncRequest, this, _1, _2)); serverPtr->setNewWebsocketCallback(std::bind(&HttpAppFrameworkImpl::onNewWebsockRequest, this, _1, _2, _3)); - serverPtr->setWebsocketMessageCallback(std::bind(&HttpAppFrameworkImpl::onWebsockMessage, this, _1, _2, _3)); - serverPtr->setDisconnectWebsocketCallback(std::bind(&HttpAppFrameworkImpl::onWebsockDisconnect, this, _1)); serverPtr->setConnectionCallback(std::bind(&HttpAppFrameworkImpl::onConnection, this, _1)); serverPtr->kickoffIdleConnections(_idleConnectionTimeout); serverPtr->start(); @@ -472,17 +468,7 @@ void HttpAppFrameworkImpl::createDbClients(const std::vector(wsConnPtr); - assert(wsConnImplPtr); - auto ctrl = wsConnImplPtr->controller(); - if (ctrl) - { - ctrl->handleConnectionClosed(wsConnPtr); - wsConnImplPtr->setController(WebSocketControllerBasePtr()); - } -} + void HttpAppFrameworkImpl::onConnection(const TcpConnectionPtr &conn) { static std::mutex mtx; @@ -540,16 +526,6 @@ void HttpAppFrameworkImpl::onConnection(const TcpConnectionPtr &conn) } } -void HttpAppFrameworkImpl::onWebsockMessage(const WebSocketConnectionPtr &wsConnPtr, std::string &&message, const WebSocketMessageType &type) -{ - auto wsConnImplPtr = std::dynamic_pointer_cast(wsConnPtr); - assert(wsConnImplPtr); - auto ctrl = wsConnImplPtr->controller(); - if (ctrl) - { - ctrl->handleNewMessage(wsConnPtr, std::move(message), type); - } -} void HttpAppFrameworkImpl::setUploadPath(const std::string &uploadPath) { @@ -571,7 +547,7 @@ void HttpAppFrameworkImpl::setUploadPath(const std::string &uploadPath) } void HttpAppFrameworkImpl::onNewWebsockRequest(const HttpRequestImplPtr &req, std::function &&callback, - const WebSocketConnectionPtr &wsConnPtr) + const WebSocketConnectionImplPtr &wsConnPtr) { _websockCtrlsRouter.route(req, std::move(callback), wsConnPtr); } diff --git a/lib/src/HttpAppFrameworkImpl.h b/lib/src/HttpAppFrameworkImpl.h index 2dec8fae..45b1220f 100644 --- a/lib/src/HttpAppFrameworkImpl.h +++ b/lib/src/HttpAppFrameworkImpl.h @@ -18,7 +18,7 @@ #include "HttpResponseImpl.h" #include "HttpClientImpl.h" #include "SharedLibManager.h" -#include "WebSockectConnectionImpl.h" +#include "WebSocketConnectionImpl.h" #include "HttpControllersRouter.h" #include "HttpSimpleControllersRouter.h" #include "WebsocketControllersRouter.h" @@ -177,9 +177,7 @@ class HttpAppFrameworkImpl : public HttpAppFramework void onAsyncRequest(const HttpRequestImplPtr &req, std::function &&callback); void onNewWebsockRequest(const HttpRequestImplPtr &req, std::function &&callback, - const WebSocketConnectionPtr &wsConnPtr); - void onWebsockMessage(const WebSocketConnectionPtr &wsConnPtr, std::string &&message, const WebSocketMessageType &type); - void onWebsockDisconnect(const WebSocketConnectionPtr &wsConnPtr); + const WebSocketConnectionImplPtr &wsConnPtr); void onConnection(const TcpConnectionPtr &conn); void addHttpPath(const std::string &path, const internal::HttpBinderBasePtr &binder, diff --git a/lib/src/HttpRequestParser.h b/lib/src/HttpRequestParser.h index 4b942aaf..b7ed634b 100755 --- a/lib/src/HttpRequestParser.h +++ b/lib/src/HttpRequestParser.h @@ -15,7 +15,7 @@ #pragma once #include "HttpRequestImpl.h" -#include "WebSockectConnectionImpl.h" +#include "WebSocketConnectionImpl.h" #include #include #include diff --git a/lib/src/HttpServer.cc b/lib/src/HttpServer.cc index c8e07030..1de0f98e 100755 --- a/lib/src/HttpServer.cc +++ b/lib/src/HttpServer.cc @@ -51,7 +51,7 @@ static void defaultHttpAsyncCallback(const HttpRequestPtr &, std::function &&callback, - const WebSocketConnectionPtr &wsConnPtr) + const WebSocketConnectionImplPtr &wsConnPtr) { auto resp = HttpResponse::newNotFoundResponse(); resp->setCloseConnection(true); @@ -104,7 +104,7 @@ void HttpServer::onConnection(const TcpConnectionPtr &conn) { if (requestParser->webSocketConn()) { - _disconnectWebsocketCallback(requestParser->webSocketConn()); + requestParser->webSocketConn()->onClose(); } #if (CXX_STD > 14) conn->getMutableContext()->reset(); //reset(): since c++17 @@ -125,33 +125,7 @@ void HttpServer::onMessage(const TcpConnectionPtr &conn, if (requestParser->webSocketConn()) { //Websocket payload - while (buf->readableBytes() > 0) - { - std::string message; - WebSocketMessageType type; - auto success = parseWebsockMessage(buf, message, type); - if (success) - { - if (type == WebSocketMessageType::Ping) - { - //ping - requestParser->webSocketConn()->send(message, WebSocketMessageType::Pong); - } - else if (type == WebSocketMessageType::Close) - { - //close - conn->shutdown(); - } - _webSocketMessageCallback(requestParser->webSocketConn(), std::move(message), type); - } - else - { - //Websock error! - conn->shutdown(); - return; - } - } - return; + requestParser->webSocketConn()->onNewMessage(conn, buf); } else { @@ -182,6 +156,7 @@ void HttpServer::onMessage(const TcpConnectionPtr &conn, if (resp->statusCode() == k101SwitchingProtocols) { requestParser->setWebsockConnection(wsConn); + } auto httpString = std::dynamic_pointer_cast(resp)->renderToString(); conn->send(httpString); diff --git a/lib/src/HttpServer.h b/lib/src/HttpServer.h index e9eb4e8c..1b8290af 100755 --- a/lib/src/HttpServer.h +++ b/lib/src/HttpServer.h @@ -14,9 +14,10 @@ #pragma once -#include "WebSockectConnectionImpl.h" +#include "WebSocketConnectionImpl.h" #include "HttpRequestImpl.h" #include +#include #include #include #include @@ -35,12 +36,8 @@ class HttpServer : trantor::NonCopyable typedef std::function &&)> HttpAsyncCallback; typedef std::function &&, - const WebSocketConnectionPtr &)> + const WebSocketConnectionImplPtr &)> WebSocketNewAsyncCallback; - typedef std::function - WebSocketDisconnetCallback; - typedef std::function - WebSocketMessageCallback; HttpServer(EventLoop *loop, const InetAddress &listenAddr, @@ -58,14 +55,6 @@ class HttpServer : trantor::NonCopyable { _newWebsocketCallback = cb; } - void setDisconnectWebsocketCallback(const WebSocketDisconnetCallback &cb) - { - _disconnectWebsocketCallback = cb; - } - void setWebsocketMessageCallback(const WebSocketMessageCallback &cb) - { - _webSocketMessageCallback = cb; - } void setConnectionCallback(const ConnectionCallback &cb) { _connectionCallback = cb; @@ -104,8 +93,6 @@ class HttpServer : trantor::NonCopyable trantor::TcpServer _server; HttpAsyncCallback _httpAsyncCallback; WebSocketNewAsyncCallback _newWebsocketCallback; - WebSocketDisconnetCallback _disconnectWebsocketCallback; - WebSocketMessageCallback _webSocketMessageCallback; trantor::ConnectionCallback _connectionCallback; }; diff --git a/lib/src/HttpUtils.cc b/lib/src/HttpUtils.cc index 6b0793f5..916fe090 100644 --- a/lib/src/HttpUtils.cc +++ b/lib/src/HttpUtils.cc @@ -375,110 +375,4 @@ const string_view &statusCodeToString(int code) } } -// Return false if any error -bool parseWebsockMessage(trantor::MsgBuffer *buffer, std::string &message, WebSocketMessageType &type) -{ - assert(message.empty()); - if (buffer->readableBytes() >= 2) - { - - unsigned char opcode = (*buffer)[0] & 0x0f; - switch (opcode) - { - case 1: - type = WebSocketMessageType::Text; - break; - case 2: - type = WebSocketMessageType::Binary; - break; - case 8: - type = WebSocketMessageType::Close; - break; - case 9: - type = WebSocketMessageType::Ping; - break; - case 10: - type = WebSocketMessageType::Pong; - break; - default: - type = WebSocketMessageType::Unknown; - break; - } - auto secondByte = (*buffer)[1]; - size_t length = secondByte & 127; - int isMasked = (secondByte & 0x80); - if (isMasked != 0) - { - LOG_TRACE << "data encoded!"; - } - else - LOG_TRACE << "plain data"; - size_t indexFirstMask = 2; - - if (length == 126) - { - indexFirstMask = 4; - } - else if (length == 127) - { - indexFirstMask = 10; - } - if (indexFirstMask > 2 && buffer->readableBytes() >= indexFirstMask) - { - if (indexFirstMask == 4) - { - length = (unsigned char)(*buffer)[2]; - length = (length << 8) + (unsigned char)(*buffer)[3]; - } - else if (indexFirstMask == 10) - { - length = (unsigned char)(*buffer)[2]; - length = (length << 8) + (unsigned char)(*buffer)[3]; - length = (length << 8) + (unsigned char)(*buffer)[4]; - length = (length << 8) + (unsigned char)(*buffer)[5]; - length = (length << 8) + (unsigned char)(*buffer)[6]; - length = (length << 8) + (unsigned char)(*buffer)[7]; - length = (length << 8) + (unsigned char)(*buffer)[8]; - length = (length << 8) + (unsigned char)(*buffer)[9]; - // length=*((uint64_t *)(buffer->peek()+2)); - // length=ntohll(length); - } - else - { - LOG_ERROR << "Websock parsing failed!"; - return false; - } - } - if (isMasked != 0) - { - if (buffer->readableBytes() >= (indexFirstMask + 4 + length)) - { - auto masks = buffer->peek() + indexFirstMask; - int indexFirstDataByte = indexFirstMask + 4; - auto rawData = buffer->peek() + indexFirstDataByte; - message.resize(length); - for (size_t i = 0; i < length; i++) - { - message[i] = (rawData[i] ^ masks[i % 4]); - } - buffer->retrieve(indexFirstMask + 4 + length); - LOG_TRACE << "got message len=" << message.length(); - return true; - } - } - else - { - if (buffer->readableBytes() >= (indexFirstMask + length)) - { - auto rawData = buffer->peek() + indexFirstMask; - message.append(rawData, length); - buffer->retrieve(indexFirstMask + length); - LOG_TRACE << "got message len=" << message.length(); - return true; - } - } - } - return true; -} - } // namespace drogon \ No newline at end of file diff --git a/lib/src/HttpUtils.h b/lib/src/HttpUtils.h index 09dd7d41..58478ee4 100644 --- a/lib/src/HttpUtils.h +++ b/lib/src/HttpUtils.h @@ -31,6 +31,5 @@ namespace drogon const string_view &webContentTypeToString(ContentType contenttype); const string_view &statusCodeToString(int code); -bool parseWebsockMessage(trantor::MsgBuffer *buffer, std::string &message, WebSocketMessageType &type); } // namespace drogon diff --git a/lib/src/WebSockectConnectionImpl.cc b/lib/src/WebSockectConnectionImpl.cc deleted file mode 100644 index e4bfbc84..00000000 --- a/lib/src/WebSockectConnectionImpl.cc +++ /dev/null @@ -1,178 +0,0 @@ -/** - * - * WebSocketConnectionImpl.cc - * An Tao - * - * Copyright 2018, An Tao. All rights reserved. - * https://github.com/an-tao/drogon - * Use of this source code is governed by a MIT license - * that can be found in the License file. - * - * Drogon - * - */ - -#include "WebSockectConnectionImpl.h" -#include -#include - -using namespace drogon; -WebSocketConnectionImpl::WebSocketConnectionImpl(const trantor::TcpConnectionPtr &conn, bool isServer) - : _tcpConn(conn), - _localAddr(conn->localAddr()), - _peerAddr(conn->peerAddr()), - _isServer(isServer) -{ -} - -void WebSocketConnectionImpl::send(const char *msg, uint64_t len, const WebSocketMessageType &type) -{ - unsigned char opcode; - if (type == WebSocketMessageType::Text) - opcode = 1; - else if (type == WebSocketMessageType::Binary) - opcode = 2; - else if (type == WebSocketMessageType::Close) - opcode = 8; - else if (type == WebSocketMessageType::Ping) - opcode = 9; - else if (type == WebSocketMessageType::Pong) - opcode = 10; - else - { - opcode = 0; - assert(0); - } - - sendWsData(msg, len, opcode); -} - -void WebSocketConnectionImpl::sendWsData(const char *msg, size_t len, unsigned char opcode) -{ - - LOG_TRACE << "send " << len << " bytes"; - auto conn = _tcpConn.lock(); - if (conn) - { - //Format the frame - std::string bytesFormatted; - bytesFormatted.resize(len + 10); - bytesFormatted[0] = char(0x80 | (opcode & 0x0f)); - - int indexStartRawData = -1; - - if (len <= 125) - { - bytesFormatted[1] = len; - indexStartRawData = 2; - } - else if (len <= 65535) - { - bytesFormatted[1] = 126; - bytesFormatted[2] = ((len >> 8) & 255); - bytesFormatted[3] = ((len)&255); - LOG_TRACE << "bytes[2]=" << (size_t)bytesFormatted[2]; - LOG_TRACE << "bytes[3]=" << (size_t)bytesFormatted[3]; - indexStartRawData = 4; - } - else - { - bytesFormatted[1] = 127; - bytesFormatted[2] = ((len >> 56) & 255); - bytesFormatted[3] = ((len >> 48) & 255); - bytesFormatted[4] = ((len >> 40) & 255); - bytesFormatted[5] = ((len >> 32) & 255); - bytesFormatted[6] = ((len >> 24) & 255); - bytesFormatted[7] = ((len >> 16) & 255); - bytesFormatted[8] = ((len >> 8) & 255); - bytesFormatted[9] = ((len)&255); - - indexStartRawData = 10; - } - if (!_isServer) - { - //Add masking key; - static std::once_flag once; - std::call_once(once, []() { - std::srand(time(nullptr)); - }); - int random = std::rand(); - - bytesFormatted[1] = (bytesFormatted[1] | 0x80); - bytesFormatted.resize(indexStartRawData + 4 + len); - *((int *)&bytesFormatted[indexStartRawData]) = random; - for (size_t i = 0; i < len; i++) - { - bytesFormatted[indexStartRawData + 4 + i] = (msg[i] ^ bytesFormatted[indexStartRawData + (i % 4)]); - } - } - else - { - bytesFormatted.resize(indexStartRawData); - bytesFormatted.append(msg, len); - } - - conn->send(bytesFormatted); - } -} -void WebSocketConnectionImpl::send(const std::string &msg, const WebSocketMessageType &type) -{ - send(msg.data(), msg.length(), type); -} -const trantor::InetAddress &WebSocketConnectionImpl::localAddr() const -{ - return _localAddr; -} -const trantor::InetAddress &WebSocketConnectionImpl::peerAddr() const -{ - return _peerAddr; -} - -bool WebSocketConnectionImpl::connected() const -{ - auto conn = _tcpConn.lock(); - if (conn) - { - return conn->connected(); - } - return false; -} -bool WebSocketConnectionImpl::disconnected() const -{ - auto conn = _tcpConn.lock(); - if (conn) - { - return conn->disconnected(); - } - return true; -} -void WebSocketConnectionImpl::WebSocketConnectionImpl::shutdown() -{ - auto conn = _tcpConn.lock(); - if (conn) - { - conn->shutdown(); - } -} -void WebSocketConnectionImpl::WebSocketConnectionImpl::forceClose() -{ - auto conn = _tcpConn.lock(); - if (conn) - { - conn->forceClose(); - } -} - -void WebSocketConnectionImpl::setContext(const any &context) -{ - _context = context; -} -const any &WebSocketConnectionImpl::WebSocketConnectionImpl::getContext() const -{ - return _context; -} -any *WebSocketConnectionImpl::WebSocketConnectionImpl::getMutableContext() -{ - return &_context; -} - diff --git a/lib/src/WebSockectConnectionImpl.h b/lib/src/WebSockectConnectionImpl.h deleted file mode 100644 index 4fb4824a..00000000 --- a/lib/src/WebSockectConnectionImpl.h +++ /dev/null @@ -1,63 +0,0 @@ -/** - * - * WebSocketConnectionImpl.h - * An Tao - * - * Copyright 2018, An Tao. All rights reserved. - * https://github.com/an-tao/drogon - * Use of this source code is governed by a MIT license - * that can be found in the License file. - * - * Drogon - * - */ - -#pragma once - -#include -#include -namespace drogon -{ -class WebSocketConnectionImpl : public WebSocketConnection -{ - public: - explicit WebSocketConnectionImpl(const trantor::TcpConnectionPtr &conn, bool isServer = true); - - virtual void send(const char *msg, uint64_t len, const WebSocketMessageType &type = WebSocketMessageType::Text) override; - virtual void send(const std::string &msg, const WebSocketMessageType &type = WebSocketMessageType::Text) override; - - virtual const trantor::InetAddress &localAddr() const override; - virtual const trantor::InetAddress &peerAddr() const override; - - virtual bool connected() const override; - virtual bool disconnected() const override; - - virtual void shutdown() override; //close write - virtual void forceClose() override; //close - - virtual void setContext(const any &context) override; - virtual const any &getContext() const override; - virtual any *getMutableContext() override; - - void setController(const WebSocketControllerBasePtr &ctrl) - { - _ctrlPtr = ctrl; - } - WebSocketControllerBasePtr controller() - { - return _ctrlPtr; - } - - private: - std::weak_ptr _tcpConn; - trantor::InetAddress _localAddr; - trantor::InetAddress _peerAddr; - WebSocketControllerBasePtr _ctrlPtr; - any _context; - bool _isServer = true; - - void sendWsData(const char *msg, size_t len, unsigned char opcode); -}; - -typedef std::shared_ptr WebSocketConnectionImplPtr; -} // namespace drogon diff --git a/lib/src/WebSocketClientImpl.cc b/lib/src/WebSocketClientImpl.cc index cb475825..ae9b32e7 100644 --- a/lib/src/WebSocketClientImpl.cc +++ b/lib/src/WebSocketClientImpl.cc @@ -121,6 +121,7 @@ void WebSocketClientImpl::connectToServerInLoop() { LOG_TRACE << "connection disconnect"; thisPtr->_connectionClosedCallback(thisPtr); + thisPtr->_websockConnPtr.reset(); thisPtr->_loop->runAfter(1.0, [thisPtr]() { thisPtr->reconnect(); }); @@ -154,36 +155,8 @@ void WebSocketClientImpl::connectToServerInLoop() void WebSocketClientImpl::onRecvWsMessage(const trantor::TcpConnectionPtr &connPtr, trantor::MsgBuffer *msgBuffer) { - std::string message; - WebSocketMessageType type; - auto success = parseWebsockMessage(msgBuffer, message, type); - if (success) - { - if (type == WebSocketMessageType::Close) - { - //close - connPtr->shutdown(); - } - else if (type == WebSocketMessageType::Ping) - { - //ping - if (_websockConnPtr) - { - _websockConnPtr->send(message, WebSocketMessageType::Pong); - } - } - _messageCallback(std::move(message), shared_from_this(), type); - } - else - { - //Websock error! - connPtr->shutdown(); - auto thisPtr = shared_from_this(); - _loop->runAfter(1.0, [thisPtr]() { - thisPtr->reconnect(); - }); - return; - } + assert(_websockConnPtr); + _websockConnPtr->onNewMessage(connPtr, msgBuffer); } void WebSocketClientImpl::onRecvMessage(const trantor::TcpConnectionPtr &connPtr, trantor::MsgBuffer *msgBuffer) @@ -234,6 +207,12 @@ void WebSocketClientImpl::onRecvMessage(const trantor::TcpConnectionPtr &connPtr _upgraded = true; _websockConnPtr = std::make_shared(connPtr, false); + auto thisPtr = shared_from_this(); + _websockConnPtr->setMessageCallback([thisPtr](std::string &&message, + const WebSocketConnectionImplPtr &connPtr, + const WebSocketMessageType &type) { + thisPtr->_messageCallback(std::move(message), thisPtr, type); + }); _requestCallback(ReqResult::Ok, resp, shared_from_this()); if (msgBuffer->readableBytes() > 0) { diff --git a/lib/src/WebSocketClientImpl.h b/lib/src/WebSocketClientImpl.h index 0fdc6d0b..9744ec3c 100644 --- a/lib/src/WebSocketClientImpl.h +++ b/lib/src/WebSocketClientImpl.h @@ -14,7 +14,7 @@ #pragma once -#include "WebSockectConnectionImpl.h" +#include "WebSocketConnectionImpl.h" #include #include #include @@ -29,7 +29,7 @@ namespace drogon class WebSocketClientImpl : public WebSocketClient, public std::enable_shared_from_this { public: - virtual const WebSocketConnectionPtr &getConnection() override + virtual WebSocketConnectionPtr getConnection() override { return _websockConnPtr; } @@ -50,6 +50,7 @@ class WebSocketClientImpl : public WebSocketClient, public std::enable_shared_fr virtual void connectToServer(const HttpRequestPtr &request, const WebSocketRequestCallback &callback) override { + assert(callback); if (_loop->isInLoopThread()) { _upgradeRequest = request; @@ -87,10 +88,10 @@ class WebSocketClientImpl : public WebSocketClient, public std::enable_shared_fr trantor::TimerId _heartbeatTimerId; HttpRequestPtr _upgradeRequest; - std::function _messageCallback; - std::function _connectionClosedCallback; + std::function _messageCallback = [](std::string &&message, const WebSocketClientPtr &, const WebSocketMessageType &) {}; + std::function _connectionClosedCallback = [](const WebSocketClientPtr &) {}; WebSocketRequestCallback _requestCallback; - WebSocketConnectionPtr _websockConnPtr; + WebSocketConnectionImplPtr _websockConnPtr; void connectToServerInLoop(); void sendReq(const trantor::TcpConnectionPtr &connPtr); diff --git a/lib/src/WebSocketConnectionImpl.cc b/lib/src/WebSocketConnectionImpl.cc new file mode 100644 index 00000000..994ce77b --- /dev/null +++ b/lib/src/WebSocketConnectionImpl.cc @@ -0,0 +1,292 @@ +/** + * + * WebSocketConnectionImpl.cc + * An Tao + * + * Copyright 2018, An Tao. All rights reserved. + * https://github.com/an-tao/drogon + * Use of this source code is governed by a MIT license + * that can be found in the License file. + * + * Drogon + * + */ + +#include "WebSocketConnectionImpl.h" +#include +#include +#include + +using namespace drogon; +WebSocketConnectionImpl::WebSocketConnectionImpl(const trantor::TcpConnectionPtr &conn, bool isServer) + : _tcpConn(conn), + _localAddr(conn->localAddr()), + _peerAddr(conn->peerAddr()), + _isServer(isServer) +{ +} + +void WebSocketConnectionImpl::send(const char *msg, uint64_t len, const WebSocketMessageType &type) +{ + unsigned char opcode; + if (type == WebSocketMessageType::Text) + opcode = 1; + else if (type == WebSocketMessageType::Binary) + opcode = 2; + else if (type == WebSocketMessageType::Close) + { + assert(len <= 125); + opcode = 8; + } + else if (type == WebSocketMessageType::Ping) + { + assert(len <= 125); + opcode = 9; + } + else if (type == WebSocketMessageType::Pong) + { + assert(len <= 125); + opcode = 10; + } + else + { + opcode = 0; + assert(0); + } + sendWsData(msg, len, opcode); +} + +void WebSocketConnectionImpl::sendWsData(const char *msg, size_t len, unsigned char opcode) +{ + + LOG_TRACE << "send " << len << " bytes"; + + //Format the frame + std::string bytesFormatted; + bytesFormatted.resize(len + 10); + bytesFormatted[0] = char(0x80 | (opcode & 0x0f)); + + int indexStartRawData = -1; + + if (len <= 125) + { + bytesFormatted[1] = len; + indexStartRawData = 2; + } + else if (len <= 65535) + { + bytesFormatted[1] = 126; + bytesFormatted[2] = ((len >> 8) & 255); + bytesFormatted[3] = ((len)&255); + LOG_TRACE << "bytes[2]=" << (size_t)bytesFormatted[2]; + LOG_TRACE << "bytes[3]=" << (size_t)bytesFormatted[3]; + indexStartRawData = 4; + } + else + { + bytesFormatted[1] = 127; + bytesFormatted[2] = ((len >> 56) & 255); + bytesFormatted[3] = ((len >> 48) & 255); + bytesFormatted[4] = ((len >> 40) & 255); + bytesFormatted[5] = ((len >> 32) & 255); + bytesFormatted[6] = ((len >> 24) & 255); + bytesFormatted[7] = ((len >> 16) & 255); + bytesFormatted[8] = ((len >> 8) & 255); + bytesFormatted[9] = ((len)&255); + + indexStartRawData = 10; + } + if (!_isServer) + { + //Add masking key; + static std::once_flag once; + std::call_once(once, []() { + std::srand(time(nullptr)); + }); + int random = std::rand(); + + bytesFormatted[1] = (bytesFormatted[1] | 0x80); + bytesFormatted.resize(indexStartRawData + 4 + len); + *((int *)&bytesFormatted[indexStartRawData]) = random; + for (size_t i = 0; i < len; i++) + { + bytesFormatted[indexStartRawData + 4 + i] = (msg[i] ^ bytesFormatted[indexStartRawData + (i % 4)]); + } + } + else + { + bytesFormatted.resize(indexStartRawData); + bytesFormatted.append(msg, len); + } + _tcpConn->send(bytesFormatted); +} +void WebSocketConnectionImpl::send(const std::string &msg, const WebSocketMessageType &type) +{ + send(msg.data(), msg.length(), type); +} +const trantor::InetAddress &WebSocketConnectionImpl::localAddr() const +{ + return _localAddr; +} +const trantor::InetAddress &WebSocketConnectionImpl::peerAddr() const +{ + return _peerAddr; +} + +bool WebSocketConnectionImpl::connected() const +{ + return _tcpConn->connected(); +} +bool WebSocketConnectionImpl::disconnected() const +{ + return _tcpConn->disconnected(); +} +void WebSocketConnectionImpl::WebSocketConnectionImpl::shutdown() +{ + _tcpConn->shutdown(); +} +void WebSocketConnectionImpl::WebSocketConnectionImpl::forceClose() +{ + _tcpConn->forceClose(); +} + +void WebSocketConnectionImpl::setContext(const any &context) +{ + _context = context; +} +const any &WebSocketConnectionImpl::WebSocketConnectionImpl::getContext() const +{ + return _context; +} +any *WebSocketConnectionImpl::WebSocketConnectionImpl::getMutableContext() +{ + return &_context; +} + +bool WebSocketMessageParser::parse(trantor::MsgBuffer *buffer) +{ + //According to the rfc6455 + _gotAll = false; + if (buffer->readableBytes() >= 2) + { + unsigned char opcode = (*buffer)[0] & 0x0f; + bool isControlFrame = false; + switch (opcode) + { + case 0: + //continuation frame + break; + case 1: + _type = WebSocketMessageType::Text; + break; + case 2: + _type = WebSocketMessageType::Binary; + break; + case 8: + _type = WebSocketMessageType::Close; + isControlFrame = true; + break; + case 9: + _type = WebSocketMessageType::Ping; + isControlFrame = true; + break; + case 10: + _type = WebSocketMessageType::Pong; + isControlFrame = true; + break; + default: + LOG_ERROR << "Unknown frame type"; + return false; + break; + } + + bool isFin = (((*buffer)[0] & 0x80) == 0x80); + if (!isFin && isControlFrame) + { + //rfc6455-5.5 + LOG_ERROR << "Bad frame: all control frames MUST NOT be fragmented"; + return false; + } + auto secondByte = (*buffer)[1]; + size_t length = secondByte & 127; + int isMasked = (secondByte & 0x80); + if (isMasked != 0) + { + LOG_TRACE << "data encoded!"; + } + else + LOG_TRACE << "plain data"; + size_t indexFirstMask = 2; + + if (length == 126) + { + indexFirstMask = 4; + } + else if (length == 127) + { + indexFirstMask = 10; + } + if (indexFirstMask > 2 && buffer->readableBytes() >= indexFirstMask) + { + if (isControlFrame) + { + //rfc6455-5.5 + LOG_ERROR << "Bad frame: all control frames MUST have a payload length of 125 bytes or less"; + return false; + } + if (indexFirstMask == 4) + { + length = (unsigned char)(*buffer)[2]; + length = (length << 8) + (unsigned char)(*buffer)[3]; + } + else if (indexFirstMask == 10) + { + length = (unsigned char)(*buffer)[2]; + length = (length << 8) + (unsigned char)(*buffer)[3]; + length = (length << 8) + (unsigned char)(*buffer)[4]; + length = (length << 8) + (unsigned char)(*buffer)[5]; + length = (length << 8) + (unsigned char)(*buffer)[6]; + length = (length << 8) + (unsigned char)(*buffer)[7]; + length = (length << 8) + (unsigned char)(*buffer)[8]; + length = (length << 8) + (unsigned char)(*buffer)[9]; + } + else + { + LOG_ERROR << "Websock parsing failed!"; + return false; + } + } + if (isMasked != 0) + { + if (buffer->readableBytes() >= (indexFirstMask + 4 + length)) + { + auto masks = buffer->peek() + indexFirstMask; + int indexFirstDataByte = indexFirstMask + 4; + auto rawData = buffer->peek() + indexFirstDataByte; + auto oldLen = _message.length(); + _message.resize(oldLen + length); + for (size_t i = 0; i < length; i++) + { + _message[oldLen + i] = (rawData[i] ^ masks[i % 4]); + } + if (isFin) + _gotAll = true; + buffer->retrieve(indexFirstMask + 4 + length); + return true; + } + } + else + { + if (buffer->readableBytes() >= (indexFirstMask + length)) + { + auto rawData = buffer->peek() + indexFirstMask; + _message.append(rawData, length); + if (isFin) + _gotAll = true; + buffer->retrieve(indexFirstMask + length); + return true; + } + } + } + return true; +} \ No newline at end of file diff --git a/lib/src/WebSocketConnectionImpl.h b/lib/src/WebSocketConnectionImpl.h new file mode 100644 index 00000000..e906d47d --- /dev/null +++ b/lib/src/WebSocketConnectionImpl.h @@ -0,0 +1,145 @@ +/** + * + * WebSocketConnectionImpl.h + * An Tao + * + * Copyright 2018, An Tao. All rights reserved. + * https://github.com/an-tao/drogon + * Use of this source code is governed by a MIT license + * that can be found in the License file. + * + * Drogon + * + */ + +#pragma once + +#include +#include +namespace drogon +{ + +class WebSocketConnectionImpl; +typedef std::shared_ptr WebSocketConnectionImplPtr; + +class WebSocketMessageParser +{ + public: + bool parse(trantor::MsgBuffer *buffer); + bool gotAll(std::string &message, WebSocketMessageType &type) + { + assert(message.empty()); + if (!_gotAll) + return false; + message.swap(_message); + type = _type; + return true; + } + + private: + std::string _message; + WebSocketMessageType _type; + bool _gotAll = false; +}; + +class WebSocketConnectionImpl : public WebSocketConnection, public std::enable_shared_from_this +{ + public: + explicit WebSocketConnectionImpl(const trantor::TcpConnectionPtr &conn, bool isServer = true); + + virtual void send(const char *msg, uint64_t len, const WebSocketMessageType &type = WebSocketMessageType::Text) override; + virtual void send(const std::string &msg, const WebSocketMessageType &type = WebSocketMessageType::Text) override; + + virtual const trantor::InetAddress &localAddr() const override; + virtual const trantor::InetAddress &peerAddr() const override; + + virtual bool connected() const override; + virtual bool disconnected() const override; + + virtual void shutdown() override; //close write + virtual void forceClose() override; //close + + virtual void setContext(const any &context) override; + virtual const any &getContext() const override; + virtual any *getMutableContext() override; + + void setMessageCallback(const std::function &callback) + { + _messageCallback = callback; + } + + void setCloseCallback(const std::function &callback) + { + _closeCallback = callback; + } + + void onNewMessage(const trantor::TcpConnectionPtr &connPtr, trantor::MsgBuffer *buffer) + { + while (buffer->readableBytes() > 0) + { + + auto success = _parser.parse(buffer); + if (success) + { + std::string message; + WebSocketMessageType type; + if (_parser.gotAll(message, type)) + { + if (type == WebSocketMessageType::Ping) + { + //ping + send(message, WebSocketMessageType::Pong); + } + else if (type == WebSocketMessageType::Close) + { + //close + connPtr->shutdown(); + } + else if (type == WebSocketMessageType::Unknown) + { + return; + } + _messageCallback(std::move(message), shared_from_this(), type); + } + else + { + return; + } + } + else + { + //Websock error! + connPtr->shutdown(); + return; + } + } + return; + } + + void onClose() + { + _closeCallback(shared_from_this()); + } + + private: + trantor::TcpConnectionPtr _tcpConn; + trantor::InetAddress _localAddr; + trantor::InetAddress _peerAddr; + any _context; + bool _isServer = true; + std::function + _messageCallback = [](std::string &&, + const WebSocketConnectionImplPtr &, + const WebSocketMessageType &) {}; + std::function _closeCallback = [](const WebSocketConnectionImplPtr &) {}; + + void sendWsData(const char *msg, size_t len, unsigned char opcode); + + WebSocketMessageParser _parser; +}; + +} // namespace drogon diff --git a/lib/src/WebsocketControllersRouter.cc b/lib/src/WebsocketControllersRouter.cc index 4d776067..9f7eaeed 100644 --- a/lib/src/WebsocketControllersRouter.cc +++ b/lib/src/WebsocketControllersRouter.cc @@ -41,7 +41,7 @@ void WebsocketControllersRouter::registerWebSocketController(const std::string & void WebsocketControllersRouter::route(const HttpRequestImplPtr &req, std::function &&callback, - const WebSocketConnectionPtr &wsConnPtr) + const WebSocketConnectionImplPtr &wsConnPtr) { std::string wsKey = req->getHeaderBy("sec-websocket-key"); if (!wsKey.empty()) @@ -81,7 +81,7 @@ void WebsocketControllersRouter::doControllerHandler(const WebSocketControllerBa std::string &wsKey, const HttpRequestImplPtr &req, std::function &&callback, - const WebSocketConnectionPtr &wsConnPtr) + const WebSocketConnectionImplPtr &wsConnPtr) { wsKey.append("258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); unsigned char accKey[SHA_DIGEST_LENGTH]; @@ -93,9 +93,14 @@ void WebsocketControllersRouter::doControllerHandler(const WebSocketControllerBa resp->addHeader("Connection", "Upgrade"); resp->addHeader("Sec-WebSocket-Accept", base64Key); callback(resp); - auto wsConnImplPtr = std::dynamic_pointer_cast(wsConnPtr); - assert(wsConnImplPtr); - wsConnImplPtr->setController(ctrlPtr); + wsConnPtr->setMessageCallback([ctrlPtr](std::string &&message, + const WebSocketConnectionImplPtr &connPtr, + const WebSocketMessageType &type) { + ctrlPtr->handleNewMessage(connPtr, std::move(message), type); + }); + wsConnPtr->setCloseCallback([ctrlPtr](const WebSocketConnectionImplPtr &connPtr) { + ctrlPtr->handleConnectionClosed(connPtr); + }); ctrlPtr->handleNewConnection(req, wsConnPtr); return; } diff --git a/lib/src/WebsocketControllersRouter.h b/lib/src/WebsocketControllersRouter.h index 2c0db1a7..cecaa1a3 100644 --- a/lib/src/WebsocketControllersRouter.h +++ b/lib/src/WebsocketControllersRouter.h @@ -15,6 +15,7 @@ #pragma once #include "HttpRequestImpl.h" #include "HttpResponseImpl.h" +#include "WebSocketConnectionImpl.h" #include #include #include @@ -36,7 +37,7 @@ class WebsocketControllersRouter : public trantor::NonCopyable const std::vector &filters); void route(const HttpRequestImplPtr &req, std::function &&callback, - const WebSocketConnectionPtr &wsConnPtr); + const WebSocketConnectionImplPtr &wsConnPtr); void init(); private: @@ -53,6 +54,6 @@ class WebsocketControllersRouter : public trantor::NonCopyable std::string &wsKey, const HttpRequestImplPtr &req, std::function &&callback, - const WebSocketConnectionPtr &wsConnPtr); + const WebSocketConnectionImplPtr &wsConnPtr); }; } // namespace drogon \ No newline at end of file diff --git a/test.sh b/test.sh index f5c5bce5..6b6deed8 100755 --- a/test.sh +++ b/test.sh @@ -19,6 +19,7 @@ killall -9 webapp sleep 4 +echo "Test http requests and responses." ./webapp_test if [ $? -ne 0 ];then @@ -27,6 +28,7 @@ if [ $? -ne 0 ];then fi #Test WebSocket +echo "Test the WebSocket" ./websocket_test -t if [ $? -ne 0 ];then echo "Error in testing" @@ -34,6 +36,7 @@ if [ $? -ne 0 ];then fi #Test pipelining +echo "Test the pipelining" ./pipelining_test if [ $? -ne 0 ];then echo "Error in testing" @@ -43,7 +46,7 @@ fi killall -9 webapp #Test drogon_ctl - +echo "Test the drogon_ctl" rm -rf drogon_test drogon_ctl create project drogon_test