diff --git a/lib/src/HttpServer.cc b/lib/src/HttpServer.cc index 246e15df..0d86e2c1 100644 --- a/lib/src/HttpServer.cc +++ b/lib/src/HttpServer.cc @@ -182,36 +182,7 @@ void HttpServer::onMessage(const TcpConnectionPtr &conn, MsgBuffer *buf) req->setCreationDate(trantor::Date::date()); req->setSecure(conn->isSSLConnection()); req->setPeerCertificate(conn->peerCertificate()); - if (requestParser->firstReq() && isWebSocket(req)) - { - auto wsConn = std::make_shared(conn); - wsConn->setPingMessage("", std::chrono::seconds{30}); - newWebsocketCallback_( - req, - [conn, wsConn, requestParser, this, req]( - const HttpResponsePtr &resp) mutable { - if (conn->connected()) - { - for (auto &advice : preSendingAdvices_) - { - advice(req, resp); - } - if (resp->statusCode() == k101SwitchingProtocols) - { - requestParser->setWebsockConnection(wsConn); - } - auto httpString = - ((HttpResponseImpl *)resp.get())->renderToBuffer(); - conn->send(httpString); - COZ_PROGRESS - } - }, - wsConn); - } - else - { - requests.push_back(req); - } + requests.push_back(req); requestParser->reset(); } onRequests(conn, requests, requestParser); @@ -248,6 +219,55 @@ void HttpServer::onRequests( { if (requests.empty()) return; + + // will only be checked for the first request + if (requestParser->firstReq() && requests.size() == 1 && + isWebSocket(requests[0])) + { + auto &req = requests[0]; + if (passSyncAdvices(req, + requestParser, + syncAdvices_, + false /* Not pipelined */, + false /* Not HEAD */)) + { + auto wsConn = std::make_shared(conn); + wsConn->setPingMessage("", std::chrono::seconds{30}); + newWebsocketCallback_( + req, + [conn, wsConn, requestParser, this, req]( + const HttpResponsePtr &resp) mutable { + if (conn->connected()) + { + for (auto &advice : preSendingAdvices_) + { + advice(req, resp); + } + if (resp->statusCode() == k101SwitchingProtocols) + { + requestParser->setWebsockConnection(wsConn); + } + auto httpString = + ((HttpResponseImpl *)resp.get())->renderToBuffer(); + conn->send(httpString); + COZ_PROGRESS + } + }, + wsConn); + return; + } + + // flush response for not passing sync advices + if (conn->connected() && !requestParser->getResponseBuffer().empty()) + { + sendResponses(conn, + requestParser->getResponseBuffer(), + requestParser->getBuffer()); + requestParser->getResponseBuffer().clear(); + } + return; + } + if (HttpAppFrameworkImpl::instance().keepaliveRequestsNumber() > 0 && requestParser->numberOfRequestsParsed() >= HttpAppFrameworkImpl::instance().keepaliveRequestsNumber()) @@ -283,8 +303,7 @@ void HttpServer::onRequests( requestParser->pushRequestToPipelining(req, isHeadMethod); reqPipelined = true; } - if (!syncAdvices_.empty() && - !passSyncAdvices( + if (!passSyncAdvices( req, requestParser, syncAdvices_, reqPipelined, isHeadMethod)) { continue; @@ -678,10 +697,14 @@ void HttpServer::sendResponses( static inline bool isWebSocket(const HttpRequestImplPtr &req) { + if (req->method() != Get) + return false; + auto &headers = req->headers(); if (headers.find("upgrade") == headers.end() || headers.find("connection") == headers.end()) return false; + auto connectionField = req->getHeaderBy("connection"); std::transform(connectionField.begin(), connectionField.end(),