diff --git a/drogon_ctl/create.cc b/drogon_ctl/create.cc index 8d05c096..3bd75b0c 100755 --- a/drogon_ctl/create.cc +++ b/drogon_ctl/create.cc @@ -27,7 +27,9 @@ std::string create::detail() "drogon_ctl create controller [-s] [-n ] //" "create HttpSimpleController source files\n" "drogon_ctl create controller -a <[namespace::]class_name> //" - "create HttpApiController source files\n"; + "create HttpApiController source files\n" + "drogon_ctl create controller -w [-n ] //" + "create WebSocketController source files\n"; } void create::handleCommand(std::vector ¶meters) diff --git a/drogon_ctl/create_controller.cc b/drogon_ctl/create_controller.cc index 9b2726f8..694b49d8 100755 --- a/drogon_ctl/create_controller.cc +++ b/drogon_ctl/create_controller.cc @@ -39,6 +39,12 @@ void create_controller::handleCommand(std::vector ¶meters) parameters.erase(iter); break; } + else if(*iter=="-w"||*iter=="--websocket") + { + type=WebSocket; + parameters.erase(iter); + break; + } else if(*iter=="-n"||*iter=="--namespace") { if(type==Simple) @@ -87,6 +93,37 @@ void create_controller::handleCommand(std::vector ¶meters) } createSimpleController(parameters,namespaceName); } + else if(type==WebSocket) + { + std::string namespaceName; + for(auto iter=parameters.begin();iter!=parameters.end();iter++) + { + if((*iter)[0]=='-') + { + if(*iter=="-n"||*iter=="--namespace") + { + iter=parameters.erase(iter); + if(iter!=parameters.end()) + { + namespaceName=*iter; + iter=parameters.erase(iter); + break; + } + else + { + std::cout<<"please enter namespace"< &ctlNames,const std::string &namespaceName) +{ + for(auto iter=ctlNames.begin();iter!=ctlNames.end();iter++) + { + if ((*iter)[0] == '-') + { + std::cout<\n"; file<<"using namespace drogon;\n"; std::string indent=""; - if(namespaceName!="") { - file << "namespace " << namespaceName << "{\n"; - indent=" "; + auto namespace_name=namespaceName; + if(namespace_name!="") { + auto pos=namespace_name.find("::"); + while(pos!=std::string::npos) + { + auto namespaceI=namespace_name.substr(0,pos); + namespace_name=namespace_name.substr(pos+2); + file<\n"; file<\n"; + file<<"using namespace drogon;\n"; + std::string indent=""; + auto namespace_name=namespaceName; + if(namespace_name!="") { + auto pos=namespace_name.find("::"); + while(pos!=std::string::npos) + { + auto namespaceI=namespace_name.substr(0,pos); + namespace_name=namespace_name.substr(pos+2); + file<\n"; + file< &apiClasses) { for(auto iter=apiClasses.begin();iter!=apiClasses.end();iter++) diff --git a/drogon_ctl/create_controller.h b/drogon_ctl/create_controller.h index 3373594a..f925f873 100755 --- a/drogon_ctl/create_controller.h +++ b/drogon_ctl/create_controller.h @@ -27,14 +27,21 @@ namespace drogon_ctl protected: enum ControllerType{ Simple=0, - API + API, + WebSocket }; void createSimpleController(std::vector &ctlNames,const std::string &namespaceName=""); void createSimpleController(const std::string &ctlName,const std::string &namespaceName=""); + void createWebsockController(std::vector &ctlNames,const std::string &namespaceName=""); + void createWebsockController(const std::string &ctlName,const std::string &namespaceName=""); + void createApiController(std::vector &apiClasses); void createApiController(const std::string &className); void newSimpleControllerHeaderFile(std::ofstream &file,const std::string &ctlName,const std::string &namespaceName=""); void newSimpleControllerSourceFile(std::ofstream &file,const std::string &ctlName,const std::string &namespaceName=""); + void newWebsockControllerHeaderFile(std::ofstream &file,const std::string &ctlName,const std::string &namespaceName=""); + void newWebsockControllerSourceFile(std::ofstream &file,const std::string &ctlName,const std::string &namespaceName=""); + void newApiControllerHeaderFile(std::ofstream &file,const std::string &className); void newApiControllerSourceFile(std::ofstream &file,const std::string &className,const std::string &filename); diff --git a/lib/inc/drogon/WebSocketConnection.h b/lib/inc/drogon/WebSocketConnection.h index 28f78721..7be27ff9 100644 --- a/lib/inc/drogon/WebSocketConnection.h +++ b/lib/inc/drogon/WebSocketConnection.h @@ -20,7 +20,7 @@ #include #include namespace drogon{ - class WebSocketConnection:public trantor::NonCopyable + class WebSocketConnection { public: WebSocketConnection()= default; diff --git a/lib/inc/drogon/WebSocketController.h b/lib/inc/drogon/WebSocketController.h index 5b973885..a48d688c 100644 --- a/lib/inc/drogon/WebSocketController.h +++ b/lib/inc/drogon/WebSocketController.h @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -38,8 +39,11 @@ namespace drogon class WebSocketControllerBase:public virtual DrObjectBase { public: - virtual void handleNewMessage(const TcpConnectionPtr&, - MsgBuffer*)=0; + //on new data received + virtual void handleNewMessage(const WebSocketConnectionPtr&, + trantor::MsgBuffer*)=0; + //on new connection or after disconnect + virtual void handleConnection(const WebSocketConnectionPtr&)=0; virtual ~WebSocketControllerBase(){} }; diff --git a/lib/src/HttpAppFramework.cc b/lib/src/HttpAppFramework.cc index 8f9cff6a..0b5decff 100755 --- a/lib/src/HttpAppFramework.cc +++ b/lib/src/HttpAppFramework.cc @@ -15,8 +15,8 @@ #include "HttpRequestImpl.h" #include "HttpResponseImpl.h" #include "HttpClientImpl.h" +#include "WebSockectConnectionImpl.h" #include -#include #include #include #include @@ -69,7 +69,11 @@ namespace drogon private: std::vector> _listeners; void onAsyncRequest(const HttpRequestPtr& req,const std::function & callback); - void onNewWebsockRequest(const HttpRequestPtr& req,const std::function & callback); + void onNewWebsockRequest(const HttpRequestPtr& req, + const std::function & callback, + const WebSocketConnectionPtr &wsConnPtr); + void onWebsockMessage(const WebSocketConnectionPtr &wsConnPtr,trantor::MsgBuffer *buffer); + void onWebsockDisconnect(const WebSocketConnectionPtr &wsConnPtr); void readSendFile(const std::string& filePath,const HttpRequestPtr& req, HttpResponse* resp); void addApiPath(const std::string &path, const HttpApiBinderBasePtr &binder, @@ -349,8 +353,9 @@ void HttpAppFrameworkImpl::run() } serverPtr->setIoLoopNum(_threadNum); serverPtr->setHttpAsyncCallback(std::bind(&HttpAppFrameworkImpl::onAsyncRequest,this,_1,_2)); - serverPtr->setNewWebsocketCallback(std::bind(&HttpAppFrameworkImpl::onNewWebsockRequest,this,_1,_2)); - + serverPtr->setNewWebsocketCallback(std::bind(&HttpAppFrameworkImpl::onNewWebsockRequest,this,_1,_2,_3)); + serverPtr->setWebsocketMessageCallback(std::bind(&HttpAppFrameworkImpl::onWebsockMessage,this,_1,_2)); + serverPtr->setDisconnectWebsocketCallback(std::bind(&HttpAppFrameworkImpl::onWebsockDisconnect,this,_1)); serverPtr->start(); servers.push_back(serverPtr); #endif @@ -416,32 +421,69 @@ void HttpAppFrameworkImpl::doFilters(const std::vector &filters, } doFilterChain(filterPtrs,req,callback,needSetJsessionid,session_id,missCallback); } -void HttpAppFrameworkImpl::onNewWebsockRequest(const HttpRequestPtr& req,const std::function & callback) +void HttpAppFrameworkImpl::onWebsockDisconnect(const WebSocketConnectionPtr &wsConnPtr) +{ + auto wsConnImplPtr=std::dynamic_pointer_cast(wsConnPtr); + assert(wsConnImplPtr); + auto ctrl=wsConnImplPtr->controller(); + if(ctrl) + { + ctrl->handleConnection(wsConnPtr); + wsConnImplPtr->setController(WebSocketControllerBasePtr()); + } + +} +void HttpAppFrameworkImpl::onWebsockMessage(const WebSocketConnectionPtr &wsConnPtr, + trantor::MsgBuffer *buffer) +{ + auto wsConnImplPtr=std::dynamic_pointer_cast(wsConnPtr); + assert(wsConnImplPtr); + auto ctrl=wsConnImplPtr->controller(); + if(ctrl) + ctrl->handleNewMessage(wsConnPtr,buffer); +} +void HttpAppFrameworkImpl::onNewWebsockRequest(const HttpRequestPtr& req, + const std::function & callback, + const WebSocketConnectionPtr &wsConnPtr) { -// magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' -// sha1 = hashlib.sha1() -// sha1.update(ws_key + magic) -// return base64.b64encode(sha1.digest()) std::string wsKey=req->getHeader("Sec-WebSocket-Key"); if(!wsKey.empty()) { // magic="258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; - wsKey.append("258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); - unsigned char accKey[SHA_DIGEST_LENGTH]; - SHA1(reinterpret_cast(wsKey.c_str()), wsKey.length(), accKey); - auto base64Key=base64_encode(accKey,SHA_DIGEST_LENGTH); - auto resp=HttpResponse::newHttpResponse(); - resp->setStatusCode(HttpResponse::k101,"Switching Protocols"); - resp->addHeader("Upgrade","websocket"); - resp->addHeader("Connection","Upgrade"); - resp->addHeader("Sec-WebSocket-Accept",base64Key); - callback(*resp); - } else{ - HttpResponseImpl resp; - resp.setStatusCode(HttpResponse::k404NotFound); - resp.setCloseConnection(true); - callback(resp); + WebSocketControllerBasePtr ctrlPtr; + { + std::string pathLower(req->path()); + std::transform(pathLower.begin(),pathLower.end(),pathLower.begin(),tolower); + std::lock_guard guard(_websockCtrlMutex); + if(_websockCtrlMap.find(pathLower)!=_websockCtrlMap.end()) + { + ctrlPtr=_websockCtrlMap[pathLower]; + } + } + if(ctrlPtr) + { + wsKey.append("258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); + unsigned char accKey[SHA_DIGEST_LENGTH]; + SHA1(reinterpret_cast(wsKey.c_str()), wsKey.length(), accKey); + auto base64Key=base64_encode(accKey,SHA_DIGEST_LENGTH); + auto resp=HttpResponse::newHttpResponse(); + resp->setStatusCode(HttpResponse::k101,"Switching Protocols"); + resp->addHeader("Upgrade","websocket"); + resp->addHeader("Connection","Upgrade"); + resp->addHeader("Sec-WebSocket-Accept",base64Key); + callback(*resp); + auto wsConnImplPtr=std::dynamic_pointer_cast(wsConnPtr); + assert(wsConnImplPtr); + wsConnImplPtr->setController(ctrlPtr); + ctrlPtr->handleConnection(wsConnPtr); + return; + } } + HttpResponseImpl resp; + resp.setStatusCode(HttpResponse::k404NotFound); + resp.setCloseConnection(true); + callback(resp); + } void HttpAppFrameworkImpl::onAsyncRequest(const HttpRequestPtr& req,const std::function & callback) { diff --git a/lib/src/HttpContext.h b/lib/src/HttpContext.h index d851c001..8cd72293 100755 --- a/lib/src/HttpContext.h +++ b/lib/src/HttpContext.h @@ -29,6 +29,7 @@ #include "HttpRequestImpl.h" #include "HttpResponseImpl.h" #include +#include using namespace trantor; namespace drogon @@ -132,13 +133,13 @@ namespace drogon } return false; } - bool isWebsock() + WebSocketConnectionPtr webSocketConn() { - return _isWebsock; + return _websockConnPtr; } - void setIsWebsock(bool val) + void setWebsockConnection(const WebSocketConnectionPtr &conn) { - _isWebsock=val; + _websockConnPtr=conn; } private: bool processRequestLine(const char *begin, const char *end); @@ -150,7 +151,7 @@ namespace drogon HttpResponseParseState res_state_; HttpResponseImpl response_; bool _firstRequest=true; - bool _isWebsock=false; + WebSocketConnectionPtr _websockConnPtr; }; } diff --git a/lib/src/HttpResponseImpl.cc b/lib/src/HttpResponseImpl.cc index 4d8c7eec..10cf7a85 100755 --- a/lib/src/HttpResponseImpl.cc +++ b/lib/src/HttpResponseImpl.cc @@ -193,13 +193,17 @@ void HttpResponseImpl::appendToBuffer(MsgBuffer* output) const output->append("\r\n"); snprintf(buf, sizeof buf, "Content-Length: %lu\r\n", _body.size()); output->append(buf); - if (_closeConnection) { - output->append("Connection: close\r\n"); - } else { + if(_headers.find("Connection")==_headers.end()) + { + if (_closeConnection) { + output->append("Connection: close\r\n"); + } else { - output->append("Connection: Keep-Alive\r\n"); + output->append("Connection: Keep-Alive\r\n"); + } } + for (auto it = _headers.begin(); it != _headers.end(); ++it) { diff --git a/lib/src/HttpServer.cc b/lib/src/HttpServer.cc index 1ae1553c..0248affd 100755 --- a/lib/src/HttpServer.cc +++ b/lib/src/HttpServer.cc @@ -43,6 +43,16 @@ static void defaultHttpAsyncCallback(const HttpRequestPtr&,const std::function & callback, + const WebSocketConnectionPtr& wsConnPtr) +{ + HttpResponseImpl resp; + resp.setStatusCode(HttpResponse::k404NotFound); + resp.setCloseConnection(true); + callback(resp); +} + @@ -51,7 +61,7 @@ HttpServer::HttpServer(EventLoop* loop, const std::string& name) : server_(loop, listenAddr, name.c_str()), httpAsyncCallback_(defaultHttpAsyncCallback), - newWebsocketCallback_(defaultHttpAsyncCallback) + newWebsocketCallback_(defaultWebSockAsyncCallback) { server_.setConnectionCallback( std::bind(&HttpServer::onConnection, this, _1)); @@ -81,9 +91,9 @@ void HttpServer::onConnection(const TcpConnectionPtr& conn) HttpContext* context = any_cast(conn->getMutableContext()); // LOG_INFO << "###:" << string(buf->peek(), buf->readableBytes()); - if(context->isWebsock()) + if(context->webSocketConn()) { - //TODO websock disconnect ! + disconnectWebsocketCallback_(context->webSocketConn()); } conn->setContext(std::string("None")); } @@ -95,10 +105,10 @@ void HttpServer::onMessage(const TcpConnectionPtr& conn, HttpContext* context = any_cast(conn->getMutableContext()); // LOG_INFO << "###:" << string(buf->peek(), buf->readableBytes()); - if(context->isWebsock()) + if(context->webSocketConn()) { //websocket payload,we shouldn't parse it - //TODO websock message callback; + webSocketMessageCallback_(context->webSocketConn(),buf); return; } if (!context->parseRequest(buf)) { @@ -113,16 +123,17 @@ void HttpServer::onMessage(const TcpConnectionPtr& conn, context->requestImpl()->setReceiveDate(trantor::Date::date()); if(context->firstReq()&&isWebSocket(conn,context->request())) { + auto wsConn=std::make_shared(conn); newWebsocketCallback_(context->request(),[=](HttpResponse &resp) mutable { if(resp.statusCode()==HttpResponse::k101) { - context->setIsWebsock(true); + context->setWebsockConnection(wsConn); } MsgBuffer buffer; ((HttpResponseImpl &)resp).appendToBuffer(&buffer); conn->send(buffer.peek(),buffer.readableBytes()); - }); + },wsConn); } else onRequest(conn, context->request()); diff --git a/lib/src/HttpServer.h b/lib/src/HttpServer.h index 0b90df71..1de7f21b 100755 --- a/lib/src/HttpServer.h +++ b/lib/src/HttpServer.h @@ -15,6 +15,7 @@ #pragma once +#include "WebSockectConnectionImpl.h" #include #include #include @@ -34,6 +35,15 @@ namespace drogon public: typedef std::function< void (const HttpRequestPtr&,const std::function&)> HttpAsyncCallback; + typedef std::function< void (const HttpRequestPtr&, + const std::function&, + const WebSocketConnectionPtr &)> + WebSocketNewAsyncCallback; + typedef std::function< void (const WebSocketConnectionPtr &)> + WebSocketDisconnetCallback; + typedef std::function< void (const WebSocketConnectionPtr &,trantor::MsgBuffer *)> + WebSocketMessageCallback; + HttpServer(EventLoop* loop, const InetAddress& listenAddr, const std::string& name); @@ -46,10 +56,18 @@ namespace drogon { httpAsyncCallback_= cb; } - void setNewWebsocketCallback(const HttpAsyncCallback& cb) + void setNewWebsocketCallback(const WebSocketNewAsyncCallback& cb) { newWebsocketCallback_=cb; } + void setDisconnectWebsocketCallback(const WebSocketDisconnetCallback& cb) + { + disconnectWebsocketCallback_=cb; + } + void setWebsocketMessageCallback(const WebSocketMessageCallback& cb) + { + webSocketMessageCallback_=cb; + } void setIoLoopNum(int numThreads) { server_.setIoLoopNum(numThreads); @@ -71,8 +89,9 @@ namespace drogon bool isWebSocket(const TcpConnectionPtr& conn, const HttpRequestPtr& req); trantor::TcpServer server_; HttpAsyncCallback httpAsyncCallback_; - HttpAsyncCallback newWebsocketCallback_; - + WebSocketNewAsyncCallback newWebsocketCallback_; + WebSocketDisconnetCallback disconnectWebsocketCallback_; + WebSocketMessageCallback webSocketMessageCallback_; }; diff --git a/lib/src/WebSockectConnectionImpl.cc b/lib/src/WebSockectConnectionImpl.cc index f6091086..f1bfffc3 100644 --- a/lib/src/WebSockectConnectionImpl.cc +++ b/lib/src/WebSockectConnectionImpl.cc @@ -1,34 +1,6 @@ -#include +#include "WebSockectConnectionImpl.h" #include -namespace drogon{ - class WebSocketConnectionImpl:public WebSocketConnection - { - public: - WebSocketConnectionImpl(const trantor::TcpConnectionPtr &conn); - virtual void send(const char *msg,uint64_t len) override; - virtual void send(const std::string &msg) override; - - virtual const trantor::InetAddress& localAddr() const override; - virtual const trantor::InetAddress& peerAddr() const override; - - virtual bool connected() const override; - virtual bool disconnected() const override; - - virtual void shutdown() override;//close write - virtual void forceClose() override;//close - - virtual void setContext(const any& context) override; - virtual const any& getContext() const override; - virtual any* getMutableContext() override; - - private: - std::weak_ptr _tcpConn; - trantor::InetAddress _localAddr; - trantor::InetAddress _peerAddr; - any _context; - }; -} using namespace drogon; WebSocketConnectionImpl::WebSocketConnectionImpl(const trantor::TcpConnectionPtr &conn): _tcpConn(conn), diff --git a/lib/src/WebSockectConnectionImpl.h b/lib/src/WebSockectConnectionImpl.h new file mode 100644 index 00000000..60ea83d6 --- /dev/null +++ b/lib/src/WebSockectConnectionImpl.h @@ -0,0 +1,42 @@ +#pragma once + +#include +#include +namespace drogon{ + class WebSocketConnectionImpl:public WebSocketConnection + { + public: + WebSocketConnectionImpl(const trantor::TcpConnectionPtr &conn); + + virtual void send(const char *msg,uint64_t len) override; + virtual void send(const std::string &msg) override; + + virtual const trantor::InetAddress& localAddr() const override; + virtual const trantor::InetAddress& peerAddr() const override; + + virtual bool connected() const override; + virtual bool disconnected() const override; + + virtual void shutdown() override;//close write + virtual void forceClose() override;//close + + virtual void setContext(const any& context) override; + virtual const any& getContext() const override; + virtual any* getMutableContext() override; + + void setController(const WebSocketControllerBasePtr &ctrl) + { + _ctrlPtr=ctrl; + } + WebSocketControllerBasePtr controller() + { + return _ctrlPtr; + } + private: + std::weak_ptr _tcpConn; + trantor::InetAddress _localAddr; + trantor::InetAddress _peerAddr; + WebSocketControllerBasePtr _ctrlPtr; + any _context; + }; +} \ No newline at end of file