diff --git a/CMakeLists.txt b/CMakeLists.txt index 9d45e080..d9286538 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -271,7 +271,12 @@ set(DROGON_SOURCES lib/src/Utilities.cc lib/src/WebSocketClientImpl.cc lib/src/WebSocketConnectionImpl.cc - lib/src/WebsocketControllersRouter.cc) + lib/src/WebsocketControllersRouter.cc + lib/src/RateLimiter.cc + lib/src/FixedWindowRateLimiter.cc + lib/src/SlidingWindowRateLimiter.cc + lib/src/TokenBucketRateLimiter.cc + lib/src/Hodor.cc) set(private_headers lib/src/AOPAdvice.h lib/src/CacheFile.h @@ -300,7 +305,10 @@ set(private_headers lib/src/TaskTimeoutFlag.h lib/src/WebSocketClientImpl.h lib/src/WebSocketConnectionImpl.h - lib/src/WebsocketControllersRouter.h) + lib/src/WebsocketControllersRouter.h + lib/src/FixedWindowRateLimiter.h + lib/src/SlidingWindowRateLimiter.h + lib/src/TokenBucketRateLimiter.h) if (NOT WIN32) set(DROGON_SOURCES @@ -556,6 +564,7 @@ set(DROGON_HEADERS lib/inc/drogon/drogon_callbacks.h lib/inc/drogon/PubSubService.h lib/inc/drogon/drogon_test.h + lib/inc/drogon/RateLimiter.h ${CMAKE_CURRENT_BINARY_DIR}/exports/drogon/exports.h) set(private_headers ${private_headers} @@ -697,7 +706,8 @@ set(DROGON_PLUGIN_HEADERS lib/inc/drogon/plugins/Plugin.h lib/inc/drogon/plugins/SecureSSLRedirector.h lib/inc/drogon/plugins/AccessLogger.h - lib/inc/drogon/plugins/RealIpResolver.h) + lib/inc/drogon/plugins/RealIpResolver.h + lib/inc/drogon/plugins/Hodor.h) install(FILES ${DROGON_PLUGIN_HEADERS} DESTINATION ${INSTALL_INCLUDE_DIR}/drogon/plugins) diff --git a/lib/inc/drogon/RateLimiter.h b/lib/inc/drogon/RateLimiter.h new file mode 100644 index 00000000..3b8706b9 --- /dev/null +++ b/lib/inc/drogon/RateLimiter.h @@ -0,0 +1,67 @@ +#pragma once +#include +#include +#include +#include + +namespace drogon +{ +enum class DROGON_EXPORT RateLimiterType +{ + kFixedWindow, + kSlidingWindow, + kTokenBucket +}; +inline RateLimiterType stringToRateLimiterType(const std::string &type) +{ + if (type == "fixedWindow" || type == "fixed_window") + return RateLimiterType::kFixedWindow; + else if (type == "slidingWindow" || type == "sliding_window") + return RateLimiterType::kSlidingWindow; + return RateLimiterType::kTokenBucket; +} +class DROGON_EXPORT RateLimiter; +using RateLimiterPtr = std::shared_ptr; +/** + * @brief This class is used to limit the number of requests per second + * + * */ +class DROGON_EXPORT RateLimiter +{ + public: + /** + * @brief Create a rate limiter + * @param type The type of the rate limiter + * @param capacity The maximum number of requests in the time unit. + * @param timeUnit The time unit of the rate limiter. + * @return A rate limiter pointer + */ + static RateLimiterPtr newRateLimiter( + RateLimiterType type, + size_t capacity, + std::chrono::duration timeUnit = std::chrono::seconds(60)); + /** + * @brief Check if a request is allowed + * + * @return true The request is allowed + * @return false The request is not allowed + */ + virtual bool isAllowed() = 0; +}; +class DROGON_EXPORT SafeRateLimiter : public RateLimiter +{ + public: + SafeRateLimiter(RateLimiterPtr limiter) : limiter_(limiter) + { + } + bool isAllowed() override + { + std::lock_guard lock(mutex_); + return limiter_->isAllowed(); + } + + private: + RateLimiterPtr limiter_; + std::mutex mutex_; +}; +} // namespace drogon diff --git a/lib/inc/drogon/drogon.h b/lib/inc/drogon/drogon.h index d88c4a5e..e5742fd2 100644 --- a/lib/inc/drogon/drogon.h +++ b/lib/inc/drogon/drogon.h @@ -30,6 +30,7 @@ #include #include #include +#include #include #include #include diff --git a/lib/inc/drogon/plugins/Hodor.h b/lib/inc/drogon/plugins/Hodor.h new file mode 100644 index 00000000..e416f0ca --- /dev/null +++ b/lib/inc/drogon/plugins/Hodor.h @@ -0,0 +1,146 @@ +/** + * @file Hodor.h + * @author An Tao + * + * Copyright 2018, An Tao. All rights reserved. + * https://github.com/an-tao/drogon + * Use of this source code is governed by a MIT license + * that can be found in the License file. + * + * Drogon + * + */ +#pragma once +#include +#include +#include +#include +#include +#include +#include + +namespace drogon +{ +namespace plugin +{ +/** + * @brief The Hodor plugin implements a global rate limiter that limits the + * number of requests in a particular time unit. + * The json configuration is as follows: + * + * @code + { + "name": "drogon::plugin::Hodor", + "dependencies": [], + "config": { + // The algorithm used to limit the number of requests. + // The default value is "token_bucket". other values are "fixed_window" +or "sliding_window". + "algorithm": "token_bucket", + // a regular expression (for matching the path of a request) list for +URLs that have to be limited. if the list is empty, all URLs are limited. + "urls": ["^/api/.*", ...], + // The time unit in seconds. the default value is 60. + "time_unit": 60, + // The maximum number of requests in a time unit. the default value 0 +means no limit. + "capacity": 0, + // The maximum number of requests in a time unit for a single IP. the +default value 0 means no limit. + "ip_capacity": 0, + // The maximum number of requests in a time unit for a single user. +a function must be provided to the plugin to get the user id from the request. +the default value 0 means no limit. + "user_capacity": 0, + // Use the RealIpResolver plugin to get the real IP address of the +request. if this option is true, the RealIpResolver plugin should be added to +the dependencies list. the default value is false. + "use_real_ip_resolver": false, + // Multiple threads mode: the default value is true. if this option is +true, some mutexes are used for thread-safe. + "multi_threads": true, + // The message body of the response when the request is rejected. + "rejection_message": "Too many requests", + // In seconds, the minimum expiration time of the limiters for different +IPs or users. the default value is 600. + "limiter_expire_time": 600, + "sub_limits": [ + { + "urls": ["^/api/1/.*", ...], + "capacity": 0, + "ip_capacity": 0, + "user_capacity": 0 + },... + ] + } + } + @endcode + * + * Enable the plugin by adding the configuration to the list of plugins in the + * configuration file. + * */ +class DROGON_EXPORT Hodor : public drogon::Plugin +{ + public: + Hodor() + { + } + void initAndStart(const Json::Value &config) override; + void shutdown() override; + /** + * @brief the method is used to set a function to get the user id from the + * request. users should call this method after calling the app().run() + * method. etc. use the beginning advice of AOP. + * */ + void setUserIdGetter( + std::function(const HttpRequestPtr &)> func) + { + userIdGetter_ = std::move(func); + } + /** + * @brief the method is used to set a function to create the response when + * the rate limit is exceeded. users should call this method after calling + * the app().run() method. etc. use the beginning advice of AOP. + * */ + void setRejectResponseFactory( + std::function func) + { + rejectResponseFactory_ = std::move(func); + } + + private: + struct LimitStrategy + { + std::regex urlsRegex; + size_t capacity{0}; + size_t ipCapacity{0}; + size_t userCapacity{0}; + bool regexFlag{false}; + RateLimiterPtr globalLimiterPtr; + std::unique_ptr> ipLimiterMapPtr; + std::unique_ptr> + userLimiterMapPtr; + }; + LimitStrategy makeLimitStrategy(const Json::Value &config); + std::vector limitStrategies_; + RateLimiterType algorithm_{RateLimiterType::kTokenBucket}; + std::chrono::duration timeUnit_{1.0}; + bool multiThreads_{true}; + bool useRealIpResolver_{false}; + size_t limiterExpireTime_{600}; + std::function(const drogon::HttpRequestPtr &)> + userIdGetter_; + std::function + rejectResponseFactory_; + + void onHttpRequest(const HttpRequestPtr &, + AdviceCallback &&, + AdviceChainCallback &&); + bool checkLimit(const HttpRequestPtr &req, + const LimitStrategy &strategy, + const std::string &ip, + const optional &userId); + HttpResponsePtr rejectResponse_; +}; +} // namespace plugin +} // namespace drogon diff --git a/lib/inc/drogon/plugins/RealIpResolver.h b/lib/inc/drogon/plugins/RealIpResolver.h index 418fb4b0..c4d68168 100644 --- a/lib/inc/drogon/plugins/RealIpResolver.h +++ b/lib/inc/drogon/plugins/RealIpResolver.h @@ -8,6 +8,7 @@ #include #include +#include #include namespace drogon diff --git a/lib/src/FixedWindowRateLimiter.cc b/lib/src/FixedWindowRateLimiter.cc new file mode 100644 index 00000000..f3be3ce9 --- /dev/null +++ b/lib/src/FixedWindowRateLimiter.cc @@ -0,0 +1,31 @@ +#include "FixedWindowRateLimiter.h" + +using namespace drogon; + +FixedWindowRateLimiter::FixedWindowRateLimiter( + size_t capacity, + std::chrono::duration timeUnit) + : capacity_(capacity), + lastTime_(std::chrono::steady_clock::now()), + timeUnit_(timeUnit) +{ +} +// implementation of the fixed window algorithm + +bool FixedWindowRateLimiter::isAllowed() +{ + auto now = std::chrono::steady_clock::now(); + auto duration = std::chrono::duration_cast>( + now - lastTime_); + if (duration >= timeUnit_) + { + currentRequests_ = 0; + lastTime_ = now; + } + if (currentRequests_ < capacity_) + { + currentRequests_++; + return true; + } + return false; +} \ No newline at end of file diff --git a/lib/src/FixedWindowRateLimiter.h b/lib/src/FixedWindowRateLimiter.h new file mode 100644 index 00000000..5aee5bae --- /dev/null +++ b/lib/src/FixedWindowRateLimiter.h @@ -0,0 +1,20 @@ +#pragma once + +#include +#include +namespace drogon +{ +class FixedWindowRateLimiter : public RateLimiter +{ + public: + FixedWindowRateLimiter(size_t capacity, + std::chrono::duration timeUnit); + bool isAllowed() override; + + private: + size_t capacity_; + size_t currentRequests_{0}; + std::chrono::steady_clock::time_point lastTime_; + std::chrono::duration timeUnit_; +}; +} // namespace drogon diff --git a/lib/src/Hodor.cc b/lib/src/Hodor.cc new file mode 100644 index 00000000..a13eb075 --- /dev/null +++ b/lib/src/Hodor.cc @@ -0,0 +1,236 @@ +#include +#include + +using namespace drogon::plugin; +Hodor::LimitStrategy Hodor::makeLimitStrategy(const Json::Value &config) +{ + LimitStrategy strategy; + strategy.capacity = config.get("capacity", 0).asUInt(); + if (config.isMember("urls") && config["urls"].isArray()) + { + std::string regexString; + for (auto &str : config["urls"]) + { + assert(str.isString()); + regexString.append("(").append(str.asString()).append(")|"); + } + if (!regexString.empty()) + { + regexString.resize(regexString.length() - 1); + strategy.urlsRegex = std::regex(regexString); + strategy.regexFlag = true; + } + } + + if (strategy.capacity > 0) + { + if (multiThreads_) + { + strategy.globalLimiterPtr = std::make_shared( + RateLimiter::newRateLimiter(algorithm_, + strategy.capacity, + timeUnit_)); + } + else + { + strategy.globalLimiterPtr = + RateLimiter::newRateLimiter(algorithm_, + strategy.capacity, + timeUnit_); + } + } + strategy.ipCapacity = config.get("ip_capacity", 0).asUInt(); + if (strategy.ipCapacity > 0) + { + strategy.ipLimiterMapPtr = + std::make_unique>( + drogon::app().getLoop(), + timeUnit_.count() / 60 < 1 ? 1 : timeUnit_.count() / 60, + 2, + 100); + } + + strategy.userCapacity = config.get("user_capacity", 0).asUInt(); + if (strategy.userCapacity > 0) + { + strategy.userLimiterMapPtr = + std::make_unique>( + drogon::app().getLoop(), + timeUnit_.count() / 60 < 1 ? 1 : timeUnit_.count() / 60, + 2, + 100); + } + return strategy; +} +void Hodor::initAndStart(const Json::Value &config) +{ + algorithm_ = stringToRateLimiterType( + config.get("algorithm", "token_bucket").asString()); + timeUnit_ = std::chrono::seconds(config.get("time_unit", 60).asUInt()); + + multiThreads_ = config.get("multi_threads", true).asBool(); + + useRealIpResolver_ = config.get("use_real_ip_resolver", false).asBool(); + rejectResponse_ = HttpResponse::newHttpResponse(); + rejectResponse_->setStatusCode(k429TooManyRequests); + rejectResponse_->setBody( + config.get("rejection_message", "Too many requests").asString()); + rejectResponse_->setCloseConnection(true); + limiterExpireTime_ = + (std::min)(static_cast( + config.get("limiter_expire_time", 600).asUInt()), + static_cast(timeUnit_.count() * 3)); + limitStrategies_.emplace_back(makeLimitStrategy(config)); + if (config.isMember("sub_limits") && config["sub_limits"].isArray()) + { + for (auto &subLimit : config["sub_limits"]) + { + assert(subLimit.isObject()); + if (!subLimit["urls"].isArray() || subLimit["urls"].size() == 0) + { + LOG_ERROR + << "The urls of sub_limits must be an array and not empty!"; + continue; + } + if (subLimit["capacity"].asUInt() == 0 && + subLimit["ip_capacity"].asUInt() == 0 && + subLimit["user_capacity"].asUInt() == 0) + { + LOG_ERROR << "At least one capacity of sub_limits must be " + "greater than 0!"; + continue; + } + limitStrategies_.emplace_back(makeLimitStrategy(subLimit)); + } + } + app().registerPreHandlingAdvice([this](const HttpRequestPtr &req, + AdviceCallback &&acb, + AdviceChainCallback &&accb) { + onHttpRequest(req, std::move(acb), std::move(accb)); + }); +} + +void Hodor::shutdown() +{ + LOG_TRACE << "Hodor plugin is shutdown!"; +} +bool Hodor::checkLimit(const HttpRequestPtr &req, + const LimitStrategy &strategy, + const std::string &ip, + const optional &userId) +{ + if (strategy.regexFlag) + { + if (!std::regex_match(req->path(), strategy.urlsRegex)) + { + return true; + } + } + if (strategy.globalLimiterPtr) + { + if (!strategy.globalLimiterPtr->isAllowed()) + { + return false; + } + } + if (strategy.ipCapacity > 0) + { + RateLimiterPtr limiterPtr; + strategy.ipLimiterMapPtr->modify( + ip, + [this, &limiterPtr, &strategy](RateLimiterPtr &ptr) { + if (!ptr) + { + if (multiThreads_) + { + ptr = std::make_shared( + RateLimiter::newRateLimiter(algorithm_, + strategy.ipCapacity, + timeUnit_)); + } + else + { + ptr = RateLimiter::newRateLimiter(algorithm_, + strategy.ipCapacity, + timeUnit_); + } + } + limiterPtr = ptr; + }, + limiterExpireTime_); + if (!limiterPtr->isAllowed()) + { + return false; + } + } + if (strategy.userCapacity > 0) + { + if (!userId.has_value()) + { + return true; + } + RateLimiterPtr limiterPtr; + strategy.userLimiterMapPtr->modify( + *userId, + [this, &strategy, &limiterPtr](RateLimiterPtr &ptr) { + if (!ptr) + { + if (multiThreads_) + { + ptr = std::make_shared( + RateLimiter::newRateLimiter(algorithm_, + strategy.userCapacity, + timeUnit_)); + } + else + { + ptr = RateLimiter::newRateLimiter(algorithm_, + strategy.userCapacity, + timeUnit_); + } + } + limiterPtr = ptr; + }, + limiterExpireTime_); + if (!limiterPtr->isAllowed()) + { + return false; + } + } + return true; +} +void Hodor::onHttpRequest(const drogon::HttpRequestPtr &req, + drogon::AdviceCallback &&adviceCallback, + drogon::AdviceChainCallback &&chainCallback) +{ + std::string ip; + if (useRealIpResolver_) + { + ip = drogon::plugin::RealIpResolver::GetRealAddr(req).toIp(); + } + else + { + ip = req->peerAddr().toIp(); + } + optional userId; + if (userIdGetter_) + { + userId = userIdGetter_(req); + } + for (auto &strategy : limitStrategies_) + { + if (!checkLimit(req, strategy, ip, userId)) + { + if (rejectResponseFactory_) + { + adviceCallback(rejectResponseFactory_(req)); + } + else + { + adviceCallback(rejectResponse_); + } + return; + } + } + chainCallback(); +} \ No newline at end of file diff --git a/lib/src/HttpAppFrameworkImpl.cc b/lib/src/HttpAppFrameworkImpl.cc index 6d24a4b7..b15da4bc 100644 --- a/lib/src/HttpAppFrameworkImpl.cc +++ b/lib/src/HttpAppFrameworkImpl.cc @@ -92,7 +92,10 @@ HttpAppFrameworkImpl::HttpAppFrameworkImpl() postHandlingAdvices_)), websockCtrlsRouterPtr_( new WebsocketControllersRouter(postRoutingAdvices_, - postRoutingObservers_)), + postRoutingObservers_, + preHandlingAdvices_, + preHandlingObservers_, + postHandlingAdvices_)), listenerManagerPtr_(new ListenerManager), pluginsManagerPtr_(new PluginsManager), dbClientManagerPtr_(new orm::DbClientManager), diff --git a/lib/src/RateLimiter.cc b/lib/src/RateLimiter.cc new file mode 100644 index 00000000..821a7dfd --- /dev/null +++ b/lib/src/RateLimiter.cc @@ -0,0 +1,24 @@ +#include +#include "FixedWindowRateLimiter.h" +#include "SlidingWindowRateLimiter.h" +#include "TokenBucketRateLimiter.h" + +using namespace drogon; + +RateLimiterPtr RateLimiter::newRateLimiter( + RateLimiterType type, + size_t capacity, + std::chrono::duration timeUnit) +{ + switch (type) + { + case RateLimiterType::kFixedWindow: + return std::make_shared(capacity, timeUnit); + case RateLimiterType::kSlidingWindow: + return std::make_shared(capacity, + timeUnit); + case RateLimiterType::kTokenBucket: + return std::make_shared(capacity, timeUnit); + } + return std::make_shared(capacity, timeUnit); +} diff --git a/lib/src/SlidingWindowRateLimiter.cc b/lib/src/SlidingWindowRateLimiter.cc new file mode 100644 index 00000000..5f485327 --- /dev/null +++ b/lib/src/SlidingWindowRateLimiter.cc @@ -0,0 +1,58 @@ +#include "SlidingWindowRateLimiter.h" +#include + +using namespace drogon; + +SlidingWindowRateLimiter::SlidingWindowRateLimiter( + size_t capacity, + std::chrono::duration timeUnit) + : capacity_(capacity), + unitStartTime_(std::chrono::steady_clock::now()), + lastTime_(unitStartTime_), + timeUnit_(timeUnit) +{ +} +// implementation of the sliding window algorithm +bool SlidingWindowRateLimiter::isAllowed() +{ + auto now = std::chrono::steady_clock::now(); + unitStartTime_ = + unitStartTime_ + + std::chrono::duration_cast( + std::chrono::duration( + static_cast((uint64_t)( + std::chrono::duration_cast>( + now - unitStartTime_) + .count() / + timeUnit_.count())) * + timeUnit_.count())); + + if (unitStartTime_ > lastTime_) + { + auto duration = + std::chrono::duration_cast>( + unitStartTime_ - lastTime_); + auto startTime = lastTime_; + if (duration >= timeUnit_) + { + previousRequests_ = 0; + } + else + { + previousRequests_ = currentRequests_; + } + currentRequests_ = 0; + } + auto coef = std::chrono::duration_cast>( + now - unitStartTime_) / + timeUnit_; + assert(coef <= 1.0); + auto count = previousRequests_ * (1.0 - coef) + currentRequests_; + if (count < capacity_) + { + currentRequests_++; + lastTime_ = now; + return true; + } + return false; +} diff --git a/lib/src/SlidingWindowRateLimiter.h b/lib/src/SlidingWindowRateLimiter.h new file mode 100644 index 00000000..96ee916f --- /dev/null +++ b/lib/src/SlidingWindowRateLimiter.h @@ -0,0 +1,22 @@ +#pragma once +#include +#include + +namespace drogon +{ +class SlidingWindowRateLimiter : public RateLimiter +{ + public: + SlidingWindowRateLimiter(size_t capacity, + std::chrono::duration timeUnit); + bool isAllowed() override; + + private: + size_t capacity_; + size_t currentRequests_{0}; + size_t previousRequests_{0}; + std::chrono::steady_clock::time_point unitStartTime_; + std::chrono::steady_clock::time_point lastTime_; + std::chrono::duration timeUnit_; +}; +} // namespace drogon \ No newline at end of file diff --git a/lib/src/TokenBucketRateLimiter.cc b/lib/src/TokenBucketRateLimiter.cc new file mode 100644 index 00000000..d39244d2 --- /dev/null +++ b/lib/src/TokenBucketRateLimiter.cc @@ -0,0 +1,30 @@ +#include "TokenBucketRateLimiter.h" + +using namespace drogon; + +TokenBucketRateLimiter::TokenBucketRateLimiter( + size_t capacity, + std::chrono::duration timeUnit) + : capacity_(capacity), + lastTime_(std::chrono::steady_clock::now()), + timeUnit_(timeUnit) +{ +} + +// implementation of the token bucket algorithm +bool TokenBucketRateLimiter::isAllowed() +{ + auto now = std::chrono::steady_clock::now(); + auto duration = std::chrono::duration_cast>( + now - lastTime_); + tokens_ += capacity_ * (duration / timeUnit_); + if (tokens_ > capacity_) + tokens_ = capacity_; + lastTime_ = now; + if (tokens_ > 1.0) + { + tokens_ -= 1.0; + return true; + } + return false; +} \ No newline at end of file diff --git a/lib/src/TokenBucketRateLimiter.h b/lib/src/TokenBucketRateLimiter.h new file mode 100644 index 00000000..04b1fa2a --- /dev/null +++ b/lib/src/TokenBucketRateLimiter.h @@ -0,0 +1,20 @@ +#pragma once + +#include + +namespace drogon +{ +class TokenBucketRateLimiter : public RateLimiter +{ + public: + TokenBucketRateLimiter(size_t capacity, + std::chrono::duration timeUnit); + bool isAllowed() override; + + private: + size_t capacity_; + std::chrono::steady_clock::time_point lastTime_; + std::chrono::duration timeUnit_; + double tokens_{0}; +}; +} // namespace drogon diff --git a/lib/src/WebsocketControllersRouter.cc b/lib/src/WebsocketControllersRouter.cc index 48fc6ed9..e56fd92f 100644 --- a/lib/src/WebsocketControllersRouter.cc +++ b/lib/src/WebsocketControllersRouter.cc @@ -91,6 +91,95 @@ void WebsocketControllersRouter::registerWebSocketController( } } +void WebsocketControllersRouter::doPreHandlingAdvices( + const WebSocketControllerRouterItem &routerItem, + std::string &wsKey, + const HttpRequestImplPtr &req, + std::function &&callback, + const WebSocketConnectionImplPtr &wsConnPtr) +{ + if (req->method() == Options) + { + auto resp = HttpResponse::newHttpResponse(); + resp->setContentTypeCode(ContentType::CT_TEXT_PLAIN); + std::string methods = "OPTIONS,"; + if (routerItem.binders_[Get] && routerItem.binders_[Get]->isCORS_) + { + methods.append("GET,HEAD,"); + } + if (routerItem.binders_[Post] && routerItem.binders_[Post]->isCORS_) + { + methods.append("POST,"); + } + if (routerItem.binders_[Put] && routerItem.binders_[Put]->isCORS_) + { + methods.append("PUT,"); + } + if (routerItem.binders_[Delete] && routerItem.binders_[Delete]->isCORS_) + { + methods.append("DELETE,"); + } + if (routerItem.binders_[Patch] && routerItem.binders_[Patch]->isCORS_) + { + methods.append("PATCH,"); + } + methods.resize(methods.length() - 1); + resp->addHeader("ALLOW", methods); + auto &origin = req->getHeader("Origin"); + if (origin.empty()) + { + resp->addHeader("Access-Control-Allow-Origin", "*"); + } + else + { + resp->addHeader("Access-Control-Allow-Origin", origin); + } + resp->addHeader("Access-Control-Allow-Methods", methods); + auto &headers = req->getHeaderBy("access-control-request-headers"); + if (!headers.empty()) + { + resp->addHeader("Access-Control-Allow-Headers", headers); + } + callback(resp); + return; + } + if (!preHandlingObservers_.empty()) + { + for (auto &observer : preHandlingObservers_) + { + observer(req); + } + } + if (preHandlingAdvices_.empty()) + { + doControllerHandler( + routerItem, wsKey, req, std::move(callback), wsConnPtr); + } + else + { + auto callbackPtr = + std::make_shared>( + std::move(callback)); + doAdvicesChain( + preHandlingAdvices_, + 0, + req, + std::make_shared>( + [callbackPtr](const HttpResponsePtr &resp) { + (*callbackPtr)(resp); + }), + [this, + &routerItem, + wsKey = std::move(wsKey), + req, + callbackPtr, + wsConnPtr]() mutable { + doControllerHandler( + routerItem, wsKey, req, std::move(*callbackPtr), wsConnPtr); + }); + } +} + void WebsocketControllersRouter::route( const HttpRequestImplPtr &req, std::function &&callback, @@ -149,7 +238,7 @@ void WebsocketControllersRouter::route( wsConnPtr, &ctrlInfo, this]() mutable { - doControllerHandler( + doPreHandlingAdvices( ctrlInfo, wsKey, req, @@ -159,7 +248,7 @@ void WebsocketControllersRouter::route( } else { - doControllerHandler( + doPreHandlingAdvices( ctrlInfo, wsKey, req, std::move(callback), wsConnPtr); } return; @@ -169,46 +258,47 @@ void WebsocketControllersRouter::route( auto callbackPtr = std::make_shared< std::function>( std::move(callback)); - doAdvicesChain( - postRoutingAdvices_, - 0, - req, - callbackPtr, - [callbackPtr, - &filters, - req, - &ctrlInfo, - this, - wsKey = std::move(wsKey), - wsConnPtr]() mutable { - if (!filters.empty()) - { - filters_function::doFilters( - filters, + doAdvicesChain(postRoutingAdvices_, + 0, + req, + callbackPtr, + [callbackPtr, + &filters, req, - callbackPtr, - [this, - wsKey = std::move(wsKey), - callbackPtr, - wsConnPtr = std::move(wsConnPtr), - req, - &ctrlInfo]() mutable { - doControllerHandler(ctrlInfo, - wsKey, - req, - std::move(*callbackPtr), - wsConnPtr); - }); - } - else - { - doControllerHandler(ctrlInfo, - wsKey, - req, - std::move(*callbackPtr), - wsConnPtr); - } - }); + &ctrlInfo, + this, + wsKey = std::move(wsKey), + wsConnPtr]() mutable { + if (!filters.empty()) + { + filters_function::doFilters( + filters, + req, + callbackPtr, + [this, + wsKey = std::move(wsKey), + callbackPtr, + wsConnPtr = std::move(wsConnPtr), + req, + &ctrlInfo]() mutable { + doPreHandlingAdvices( + ctrlInfo, + wsKey, + req, + std::move(*callbackPtr), + wsConnPtr); + }); + } + else + { + doPreHandlingAdvices(ctrlInfo, + wsKey, + req, + std::move( + *callbackPtr), + wsConnPtr); + } + }); } return; } @@ -247,52 +337,6 @@ void WebsocketControllersRouter::doControllerHandler( std::function &&callback, const WebSocketConnectionImplPtr &wsConnPtr) { - if (req->method() == Options) - { - auto resp = HttpResponse::newHttpResponse(); - resp->setContentTypeCode(ContentType::CT_TEXT_PLAIN); - std::string methods = "OPTIONS,"; - if (routerItem.binders_[Get] && routerItem.binders_[Get]->isCORS_) - { - methods.append("GET,HEAD,"); - } - if (routerItem.binders_[Post] && routerItem.binders_[Post]->isCORS_) - { - methods.append("POST,"); - } - if (routerItem.binders_[Put] && routerItem.binders_[Put]->isCORS_) - { - methods.append("PUT,"); - } - if (routerItem.binders_[Delete] && routerItem.binders_[Delete]->isCORS_) - { - methods.append("DELETE,"); - } - if (routerItem.binders_[Patch] && routerItem.binders_[Patch]->isCORS_) - { - methods.append("PATCH,"); - } - methods.resize(methods.length() - 1); - resp->addHeader("ALLOW", methods); - - auto &origin = req->getHeader("Origin"); - if (origin.empty()) - { - resp->addHeader("Access-Control-Allow-Origin", "*"); - } - else - { - resp->addHeader("Access-Control-Allow-Origin", origin); - } - resp->addHeader("Access-Control-Allow-Methods", methods); - auto &headers = req->getHeaderBy("access-control-request-headers"); - if (!headers.empty()) - { - resp->addHeader("Access-Control-Allow-Headers", headers); - } - callback(resp); - return; - } wsKey.append("258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); unsigned char accKey[SHA_DIGEST_LENGTH]; SHA1(reinterpret_cast(wsKey.c_str()), @@ -304,6 +348,10 @@ void WebsocketControllersRouter::doControllerHandler( resp->addHeader("Upgrade", "websocket"); resp->addHeader("Connection", "Upgrade"); resp->addHeader("Sec-WebSocket-Accept", base64Key); + for (auto &advice : postHandlingAdvices_) + { + advice(req, resp); + } callback(resp); auto ctrlPtr = routerItem.binders_[req->method()]->controller_; wsConnPtr->setMessageCallback( @@ -335,4 +383,4 @@ void WebsocketControllersRouter::init() } } } -} \ No newline at end of file +} diff --git a/lib/src/WebsocketControllersRouter.h b/lib/src/WebsocketControllersRouter.h index 2e7f6abe..7ddb88b3 100644 --- a/lib/src/WebsocketControllersRouter.h +++ b/lib/src/WebsocketControllersRouter.h @@ -38,9 +38,21 @@ class WebsocketControllersRouter : public trantor::NonCopyable AdviceChainCallback &&)>> &postRoutingAdvices, const std::vector> - &postRoutingObservers) + &postRoutingObservers, + const std::vector> + &preHandlingAdvices, + const std::vector> + &preHandlingObservers, + const std::vector> + &postHandlingAdvices) : postRoutingAdvices_(postRoutingAdvices), - postRoutingObservers_(postRoutingObservers) + postRoutingObservers_(postRoutingObservers), + preHandlingAdvices_(preHandlingAdvices), + preHandlingObservers_(preHandlingObservers), + postHandlingAdvices_(postHandlingAdvices) { } void registerWebSocketController( @@ -78,11 +90,26 @@ class WebsocketControllersRouter : public trantor::NonCopyable &postRoutingAdvices_; const std::vector> &postRoutingObservers_; + const std::vector> + &preHandlingAdvices_; + const std::vector> + &preHandlingObservers_; + const std::vector< + std::function> + &postHandlingAdvices_; void doControllerHandler( const WebSocketControllerRouterItem &routerItem, std::string &wsKey, const HttpRequestImplPtr &req, std::function &&callback, const WebSocketConnectionImplPtr &wsConnPtr); + void doPreHandlingAdvices( + const WebSocketControllerRouterItem &routerItem, + std::string &wsKey, + const HttpRequestImplPtr &req, + std::function &&callback, + const WebSocketConnectionImplPtr &wsConnPtr); }; } // namespace drogon