From 5faab6b414de1feb8ad599a5ff22bc8588222ef9 Mon Sep 17 00:00:00 2001 From: An Tao Date: Thu, 4 Jun 2020 19:11:07 +0800 Subject: [PATCH] Modify the WebSocketConnection class (#452) * Add getContextRef method to the WebSocketConnection class * Expose some functions on Windows * Send a close message when closing a web socket connection --- lib/inc/drogon/WebSocketConnection.h | 88 ++++++++++++++++++++++++++-- lib/inc/drogon/utils/Utilities.h | 6 +- lib/src/Utilities.cc | 52 ++++++++-------- lib/src/WebSocketConnectionImpl.cc | 16 ++++- lib/src/WebSocketConnectionImpl.h | 10 ++-- 5 files changed, 134 insertions(+), 38 deletions(-) diff --git a/lib/inc/drogon/WebSocketConnection.h b/lib/inc/drogon/WebSocketConnection.h index 16d39520..eea621ba 100644 --- a/lib/inc/drogon/WebSocketConnection.h +++ b/lib/inc/drogon/WebSocketConnection.h @@ -21,6 +21,66 @@ #include namespace drogon { +enum class CloseCode +{ + /*1000 indicates a normal closure, meaning that the purpose for which the + connection was established has been fulfilled.*/ + kNormalClosure = 1000, + /*1001 indicates that an endpoint is "going away", such as a server going + down or a browser having navigated away from a page.*/ + kEndpointGone = 1001, + /*1002 indicates that an endpoint is terminating the connection due to a + protocol error.*/ + kProtocolError = 1002, + /*1003 indicates that an endpoint is terminating the connection because it + has received a type of data it cannot accept (e.g., an endpoint that + understands only text data MAY send this if it receives a binary + message).*/ + kInvalidMessage = 1003, + /*1005 is a reserved value and MUST NOT be set as a status code in a Close + control frame by an endpoint. It is designated for use in applications + expecting a status code to indicate that no status code was actually + present.*/ + kNone = 1005, + /*1006 is a reserved value and MUST NOT be set as a status code in a Close + control frame by an endpoint. It is designated for use in applications + expecting a status code to indicate that the connection was closed + abnormally, e.g., without sending or receiving a Close control frame. + */ + kAbnormally = 1006, + /*1007 indicates that an endpoint is terminating the connection because it + has received data within a message that was not consistent with the type + of the message (e.g., non-UTF-8 [RFC3629] data within a text message).*/ + kWrongMessageContent = 1007, + /*1008 indicates that an endpoint is terminating the connection because it + has received a message that violates its policy. This is a generic + status code that can be returned when there is no other more suitable + status code (e.g., 1003 or 1009) or if there is a need to hide specific + details about the policy. + */ + kViolation = 1008, + /*1009 indicates that an endpoint is terminating the connection because it + has received a message that is too big for it to process.*/ + kMessageTooBig = 1009, + /*1010 indicates that an endpoint (client) is terminating the connection + because it has expected the server to negotiate one or more extension, + but the server didn't return them in the response message of the + WebSocket handshake. The list of extensions that are needed SHOULD + appear in the /reason/ part of the Close frame. Note that this status + code is not used by the server, because it can fail the WebSocket + handshake instead.*/ + kNeedMoreExtensions = 1010, + /*1011 indicates that a server is terminating the connection because it + encountered an unexpected condition that prevented it from fulfilling the + request.*/ + kUnexpectedCondition = 1011, + /*1015 is a reserved value and MUST NOT be set as a status code in a Close + control frame by an endpoint. It is designated for use in applications + expecting a status code to indicate that the connection was closed due to + a failure to perform a TLS handshake (e.g., the server certificate can't + be verified).*/ + kTLSFailed = 1015 +}; /** * @brief The WebSocket connection abstract class. * @@ -41,7 +101,7 @@ class WebSocketConnection virtual void send( const char *msg, uint64_t len, - const WebSocketMessageType &type = WebSocketMessageType::Text) = 0; + const WebSocketMessageType type = WebSocketMessageType::Text) = 0; /** * @brief Send a message to the peer @@ -51,7 +111,7 @@ class WebSocketConnection */ virtual void send( const std::string &msg, - const WebSocketMessageType &type = WebSocketMessageType::Text) = 0; + const WebSocketMessageType type = WebSocketMessageType::Text) = 0; /// Return the local IP address and port number of the connection virtual const trantor::InetAddress &localAddr() const = 0; @@ -65,9 +125,15 @@ class WebSocketConnection /// Return true if the connection is closed virtual bool disconnected() const = 0; - /// Shut down the write direction, which means that further send operations - /// are disabled. - virtual void shutdown() = 0; + /** + * @brief Shut down the write direction, which means that further send + * operations are disabled. + * + * @param code Please refer to the enum class CloseCode. (RFC6455 7.4.1) + * @param reason The reason for closing the connection. + */ + virtual void shutdown(const CloseCode code = CloseCode::kNormalClosure, + const std::string &reason = "") = 0; /// Close the connection virtual void forceClose() = 0; @@ -104,6 +170,18 @@ class WebSocketConnection return std::static_pointer_cast(contextPtr_); } + /** + * @brief Get the custom data reference from the connection. + * @note Please make sure that the context is available. + * @tparam T The type of the data stored in the context. + * @return T& + */ + template + T &getContextRef() const + { + return *(static_cast(contextPtr_.get())); + } + /// Return true if the context is set by user. bool hasContext() { diff --git a/lib/inc/drogon/utils/Utilities.h b/lib/inc/drogon/utils/Utilities.h index 53f89ddb..608df6ce 100644 --- a/lib/inc/drogon/utils/Utilities.h +++ b/lib/inc/drogon/utils/Utilities.h @@ -21,7 +21,11 @@ #include #include #include - +#ifdef _WIN32 +#include +char *strptime(const char *s, const char *f, struct tm *tm); +time_t timegm(struct tm *tm); +#endif namespace drogon { namespace utils diff --git a/lib/src/Utilities.cc b/lib/src/Utilities.cc index d17b499f..07013825 100644 --- a/lib/src/Utilities.cc +++ b/lib/src/Utilities.cc @@ -44,6 +44,33 @@ #include #include +#ifdef _WIN32 +char *strptime(const char *s, const char *f, struct tm *tm) +{ + // std::get_time is defined such that its + // format parameters are the exact same as strptime. + std::istringstream input(s); + input.imbue(std::locale(setlocale(LC_ALL, nullptr))); + input >> std::get_time(tm, f); + if (input.fail()) + { + return nullptr; + } + return (char *)(s + input.tellg()); +} +time_t timegm(struct tm *tm) +{ + struct tm my_tm; + + memcpy(&my_tm, tm, sizeof(struct tm)); + + /* _mkgmtime() changes the value of the struct tm* you pass in, so + * use a copy + */ + return _mkgmtime(&my_tm); +} +#endif + namespace drogon { namespace utils @@ -243,32 +270,7 @@ std::string hexToBinaryString(const char *ptr, size_t length) } return ret; } -#ifdef _WIN32 -char *strptime(const char *s, const char *f, struct tm *tm) -{ - // std::get_time is defined such that its - // format parameters are the exact same as strptime. - std::istringstream input(s); - input.imbue(std::locale(setlocale(LC_ALL, nullptr))); - input >> std::get_time(tm, f); - if (input.fail()) - { - return nullptr; - } - return (char *)(s + input.tellg()); -} -time_t timegm(struct tm *tm) -{ - struct tm my_tm; - memcpy(&my_tm, tm, sizeof(struct tm)); - - /* _mkgmtime() changes the value of the struct tm* you pass in, so - * use a copy - */ - return _mkgmtime(&my_tm); -} -#endif std::string binaryStringToHex(const unsigned char *ptr, size_t length) { std::string idString; diff --git a/lib/src/WebSocketConnectionImpl.cc b/lib/src/WebSocketConnectionImpl.cc index 42ce3251..90bc0603 100644 --- a/lib/src/WebSocketConnectionImpl.cc +++ b/lib/src/WebSocketConnectionImpl.cc @@ -29,7 +29,7 @@ WebSocketConnectionImpl::WebSocketConnectionImpl( void WebSocketConnectionImpl::send(const char *msg, uint64_t len, - const WebSocketMessageType &type) + const WebSocketMessageType type) { unsigned char opcode; if (type == WebSocketMessageType::Text) @@ -126,7 +126,7 @@ void WebSocketConnectionImpl::sendWsData(const char *msg, tcpConnectionPtr_->send(std::move(bytesFormatted)); } void WebSocketConnectionImpl::send(const std::string &msg, - const WebSocketMessageType &type) + const WebSocketMessageType type) { send(msg.data(), msg.length(), type); } @@ -147,8 +147,18 @@ bool WebSocketConnectionImpl::disconnected() const { return tcpConnectionPtr_->disconnected(); } -void WebSocketConnectionImpl::WebSocketConnectionImpl::shutdown() +void WebSocketConnectionImpl::WebSocketConnectionImpl::shutdown( + const CloseCode code, + const std::string &reason) { + tcpConnectionPtr_->getLoop()->invalidateTimer(pingTimerId_); + std::string message; + message.resize(reason.length() + 2); + auto c = htons(static_cast(code)); + memcpy(&message[0], &c, 2); + if (!reason.empty()) + memcpy(&message[2], reason.data(), reason.length()); + send(message, WebSocketMessageType::Close); tcpConnectionPtr_->shutdown(); } void WebSocketConnectionImpl::WebSocketConnectionImpl::forceClose() diff --git a/lib/src/WebSocketConnectionImpl.h b/lib/src/WebSocketConnectionImpl.h index b57563b4..25e9f500 100644 --- a/lib/src/WebSocketConnectionImpl.h +++ b/lib/src/WebSocketConnectionImpl.h @@ -56,10 +56,10 @@ class WebSocketConnectionImpl virtual void send( const char *msg, uint64_t len, - const WebSocketMessageType &type = WebSocketMessageType::Text) override; + const WebSocketMessageType type = WebSocketMessageType::Text) override; virtual void send( const std::string &msg, - const WebSocketMessageType &type = WebSocketMessageType::Text) override; + const WebSocketMessageType type = WebSocketMessageType::Text) override; virtual const trantor::InetAddress &localAddr() const override; virtual const trantor::InetAddress &peerAddr() const override; @@ -67,8 +67,10 @@ class WebSocketConnectionImpl virtual bool connected() const override; virtual bool disconnected() const override; - virtual void shutdown() override; // close write - virtual void forceClose() override; // close + virtual void shutdown( + const CloseCode code = CloseCode::kNormalClosure, + const std::string &reason = "") override; // close write + virtual void forceClose() override; // close virtual void setPingMessage( const std::string &message,