Add rate limiter (#1409)

This commit is contained in:
An Tao 2022-10-27 22:49:16 +08:00 committed by GitHub
parent 37a10318ff
commit 164972e2d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 838 additions and 94 deletions

View File

@ -271,7 +271,12 @@ set(DROGON_SOURCES
lib/src/Utilities.cc
lib/src/WebSocketClientImpl.cc
lib/src/WebSocketConnectionImpl.cc
lib/src/WebsocketControllersRouter.cc)
lib/src/WebsocketControllersRouter.cc
lib/src/RateLimiter.cc
lib/src/FixedWindowRateLimiter.cc
lib/src/SlidingWindowRateLimiter.cc
lib/src/TokenBucketRateLimiter.cc
lib/src/Hodor.cc)
set(private_headers
lib/src/AOPAdvice.h
lib/src/CacheFile.h
@ -300,7 +305,10 @@ set(private_headers
lib/src/TaskTimeoutFlag.h
lib/src/WebSocketClientImpl.h
lib/src/WebSocketConnectionImpl.h
lib/src/WebsocketControllersRouter.h)
lib/src/WebsocketControllersRouter.h
lib/src/FixedWindowRateLimiter.h
lib/src/SlidingWindowRateLimiter.h
lib/src/TokenBucketRateLimiter.h)
if (NOT WIN32)
set(DROGON_SOURCES
@ -556,6 +564,7 @@ set(DROGON_HEADERS
lib/inc/drogon/drogon_callbacks.h
lib/inc/drogon/PubSubService.h
lib/inc/drogon/drogon_test.h
lib/inc/drogon/RateLimiter.h
${CMAKE_CURRENT_BINARY_DIR}/exports/drogon/exports.h)
set(private_headers
${private_headers}
@ -697,7 +706,8 @@ set(DROGON_PLUGIN_HEADERS
lib/inc/drogon/plugins/Plugin.h
lib/inc/drogon/plugins/SecureSSLRedirector.h
lib/inc/drogon/plugins/AccessLogger.h
lib/inc/drogon/plugins/RealIpResolver.h)
lib/inc/drogon/plugins/RealIpResolver.h
lib/inc/drogon/plugins/Hodor.h)
install(FILES ${DROGON_PLUGIN_HEADERS}
DESTINATION ${INSTALL_INCLUDE_DIR}/drogon/plugins)

View File

@ -0,0 +1,67 @@
#pragma once
#include <drogon/exports.h>
#include <memory>
#include <chrono>
#include <mutex>
namespace drogon
{
enum class DROGON_EXPORT RateLimiterType
{
kFixedWindow,
kSlidingWindow,
kTokenBucket
};
inline RateLimiterType stringToRateLimiterType(const std::string &type)
{
if (type == "fixedWindow" || type == "fixed_window")
return RateLimiterType::kFixedWindow;
else if (type == "slidingWindow" || type == "sliding_window")
return RateLimiterType::kSlidingWindow;
return RateLimiterType::kTokenBucket;
}
class DROGON_EXPORT RateLimiter;
using RateLimiterPtr = std::shared_ptr<RateLimiter>;
/**
* @brief This class is used to limit the number of requests per second
*
* */
class DROGON_EXPORT RateLimiter
{
public:
/**
* @brief Create a rate limiter
* @param type The type of the rate limiter
* @param capacity The maximum number of requests in the time unit.
* @param timeUnit The time unit of the rate limiter.
* @return A rate limiter pointer
*/
static RateLimiterPtr newRateLimiter(
RateLimiterType type,
size_t capacity,
std::chrono::duration<double> timeUnit = std::chrono::seconds(60));
/**
* @brief Check if a request is allowed
*
* @return true The request is allowed
* @return false The request is not allowed
*/
virtual bool isAllowed() = 0;
};
class DROGON_EXPORT SafeRateLimiter : public RateLimiter
{
public:
SafeRateLimiter(RateLimiterPtr limiter) : limiter_(limiter)
{
}
bool isAllowed() override
{
std::lock_guard<std::mutex> lock(mutex_);
return limiter_->isAllowed();
}
private:
RateLimiterPtr limiter_;
std::mutex mutex_;
};
} // namespace drogon

View File

@ -30,6 +30,7 @@
#include <drogon/plugins/SecureSSLRedirector.h>
#include <drogon/plugins/AccessLogger.h>
#include <drogon/plugins/RealIpResolver.h>
#include <drogon/plugins/Hodor.h>
#include <drogon/Cookie.h>
#include <drogon/Session.h>
#include <drogon/IOThreadStorage.h>

View File

@ -0,0 +1,146 @@
/**
* @file Hodor.h
* @author An Tao
*
* Copyright 2018, An Tao. All rights reserved.
* https://github.com/an-tao/drogon
* Use of this source code is governed by a MIT license
* that can be found in the License file.
*
* Drogon
*
*/
#pragma once
#include <drogon/RateLimiter.h>
#include <drogon/plugins/Plugin.h>
#include <drogon/plugins/RealIpResolver.h>
#include <drogon/HttpAppFramework.h>
#include <drogon/utils/optional.h>
#include <drogon/CacheMap.h>
#include <regex>
namespace drogon
{
namespace plugin
{
/**
* @brief The Hodor plugin implements a global rate limiter that limits the
* number of requests in a particular time unit.
* The json configuration is as follows:
*
* @code
{
"name": "drogon::plugin::Hodor",
"dependencies": [],
"config": {
// The algorithm used to limit the number of requests.
// The default value is "token_bucket". other values are "fixed_window"
or "sliding_window".
"algorithm": "token_bucket",
// a regular expression (for matching the path of a request) list for
URLs that have to be limited. if the list is empty, all URLs are limited.
"urls": ["^/api/.*", ...],
// The time unit in seconds. the default value is 60.
"time_unit": 60,
// The maximum number of requests in a time unit. the default value 0
means no limit.
"capacity": 0,
// The maximum number of requests in a time unit for a single IP. the
default value 0 means no limit.
"ip_capacity": 0,
// The maximum number of requests in a time unit for a single user.
a function must be provided to the plugin to get the user id from the request.
the default value 0 means no limit.
"user_capacity": 0,
// Use the RealIpResolver plugin to get the real IP address of the
request. if this option is true, the RealIpResolver plugin should be added to
the dependencies list. the default value is false.
"use_real_ip_resolver": false,
// Multiple threads mode: the default value is true. if this option is
true, some mutexes are used for thread-safe.
"multi_threads": true,
// The message body of the response when the request is rejected.
"rejection_message": "Too many requests",
// In seconds, the minimum expiration time of the limiters for different
IPs or users. the default value is 600.
"limiter_expire_time": 600,
"sub_limits": [
{
"urls": ["^/api/1/.*", ...],
"capacity": 0,
"ip_capacity": 0,
"user_capacity": 0
},...
]
}
}
@endcode
*
* Enable the plugin by adding the configuration to the list of plugins in the
* configuration file.
* */
class DROGON_EXPORT Hodor : public drogon::Plugin<Hodor>
{
public:
Hodor()
{
}
void initAndStart(const Json::Value &config) override;
void shutdown() override;
/**
* @brief the method is used to set a function to get the user id from the
* request. users should call this method after calling the app().run()
* method. etc. use the beginning advice of AOP.
* */
void setUserIdGetter(
std::function<optional<std::string>(const HttpRequestPtr &)> func)
{
userIdGetter_ = std::move(func);
}
/**
* @brief the method is used to set a function to create the response when
* the rate limit is exceeded. users should call this method after calling
* the app().run() method. etc. use the beginning advice of AOP.
* */
void setRejectResponseFactory(
std::function<HttpResponsePtr(const HttpRequestPtr &)> func)
{
rejectResponseFactory_ = std::move(func);
}
private:
struct LimitStrategy
{
std::regex urlsRegex;
size_t capacity{0};
size_t ipCapacity{0};
size_t userCapacity{0};
bool regexFlag{false};
RateLimiterPtr globalLimiterPtr;
std::unique_ptr<CacheMap<std::string, RateLimiterPtr>> ipLimiterMapPtr;
std::unique_ptr<CacheMap<std::string, RateLimiterPtr>>
userLimiterMapPtr;
};
LimitStrategy makeLimitStrategy(const Json::Value &config);
std::vector<LimitStrategy> limitStrategies_;
RateLimiterType algorithm_{RateLimiterType::kTokenBucket};
std::chrono::duration<double> timeUnit_{1.0};
bool multiThreads_{true};
bool useRealIpResolver_{false};
size_t limiterExpireTime_{600};
std::function<optional<std::string>(const drogon::HttpRequestPtr &)>
userIdGetter_;
std::function<HttpResponsePtr(const HttpRequestPtr &)>
rejectResponseFactory_;
void onHttpRequest(const HttpRequestPtr &,
AdviceCallback &&,
AdviceChainCallback &&);
bool checkLimit(const HttpRequestPtr &req,
const LimitStrategy &strategy,
const std::string &ip,
const optional<std::string> &userId);
HttpResponsePtr rejectResponse_;
};
} // namespace plugin
} // namespace drogon

View File

@ -8,6 +8,7 @@
#include <drogon/plugins/Plugin.h>
#include <trantor/net/InetAddress.h>
#include <drogon/HttpRequest.h>
#include <vector>
namespace drogon

View File

@ -0,0 +1,31 @@
#include "FixedWindowRateLimiter.h"
using namespace drogon;
FixedWindowRateLimiter::FixedWindowRateLimiter(
size_t capacity,
std::chrono::duration<double> timeUnit)
: capacity_(capacity),
lastTime_(std::chrono::steady_clock::now()),
timeUnit_(timeUnit)
{
}
// implementation of the fixed window algorithm
bool FixedWindowRateLimiter::isAllowed()
{
auto now = std::chrono::steady_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::duration<double>>(
now - lastTime_);
if (duration >= timeUnit_)
{
currentRequests_ = 0;
lastTime_ = now;
}
if (currentRequests_ < capacity_)
{
currentRequests_++;
return true;
}
return false;
}

View File

@ -0,0 +1,20 @@
#pragma once
#include <drogon/RateLimiter.h>
#include <chrono>
namespace drogon
{
class FixedWindowRateLimiter : public RateLimiter
{
public:
FixedWindowRateLimiter(size_t capacity,
std::chrono::duration<double> timeUnit);
bool isAllowed() override;
private:
size_t capacity_;
size_t currentRequests_{0};
std::chrono::steady_clock::time_point lastTime_;
std::chrono::duration<double> timeUnit_;
};
} // namespace drogon

236
lib/src/Hodor.cc Normal file
View File

@ -0,0 +1,236 @@
#include <drogon/plugins/Hodor.h>
#include <drogon/plugins/RealIpResolver.h>
using namespace drogon::plugin;
Hodor::LimitStrategy Hodor::makeLimitStrategy(const Json::Value &config)
{
LimitStrategy strategy;
strategy.capacity = config.get("capacity", 0).asUInt();
if (config.isMember("urls") && config["urls"].isArray())
{
std::string regexString;
for (auto &str : config["urls"])
{
assert(str.isString());
regexString.append("(").append(str.asString()).append(")|");
}
if (!regexString.empty())
{
regexString.resize(regexString.length() - 1);
strategy.urlsRegex = std::regex(regexString);
strategy.regexFlag = true;
}
}
if (strategy.capacity > 0)
{
if (multiThreads_)
{
strategy.globalLimiterPtr = std::make_shared<SafeRateLimiter>(
RateLimiter::newRateLimiter(algorithm_,
strategy.capacity,
timeUnit_));
}
else
{
strategy.globalLimiterPtr =
RateLimiter::newRateLimiter(algorithm_,
strategy.capacity,
timeUnit_);
}
}
strategy.ipCapacity = config.get("ip_capacity", 0).asUInt();
if (strategy.ipCapacity > 0)
{
strategy.ipLimiterMapPtr =
std::make_unique<CacheMap<std::string, RateLimiterPtr>>(
drogon::app().getLoop(),
timeUnit_.count() / 60 < 1 ? 1 : timeUnit_.count() / 60,
2,
100);
}
strategy.userCapacity = config.get("user_capacity", 0).asUInt();
if (strategy.userCapacity > 0)
{
strategy.userLimiterMapPtr =
std::make_unique<CacheMap<std::string, RateLimiterPtr>>(
drogon::app().getLoop(),
timeUnit_.count() / 60 < 1 ? 1 : timeUnit_.count() / 60,
2,
100);
}
return strategy;
}
void Hodor::initAndStart(const Json::Value &config)
{
algorithm_ = stringToRateLimiterType(
config.get("algorithm", "token_bucket").asString());
timeUnit_ = std::chrono::seconds(config.get("time_unit", 60).asUInt());
multiThreads_ = config.get("multi_threads", true).asBool();
useRealIpResolver_ = config.get("use_real_ip_resolver", false).asBool();
rejectResponse_ = HttpResponse::newHttpResponse();
rejectResponse_->setStatusCode(k429TooManyRequests);
rejectResponse_->setBody(
config.get("rejection_message", "Too many requests").asString());
rejectResponse_->setCloseConnection(true);
limiterExpireTime_ =
(std::min)(static_cast<size_t>(
config.get("limiter_expire_time", 600).asUInt()),
static_cast<size_t>(timeUnit_.count() * 3));
limitStrategies_.emplace_back(makeLimitStrategy(config));
if (config.isMember("sub_limits") && config["sub_limits"].isArray())
{
for (auto &subLimit : config["sub_limits"])
{
assert(subLimit.isObject());
if (!subLimit["urls"].isArray() || subLimit["urls"].size() == 0)
{
LOG_ERROR
<< "The urls of sub_limits must be an array and not empty!";
continue;
}
if (subLimit["capacity"].asUInt() == 0 &&
subLimit["ip_capacity"].asUInt() == 0 &&
subLimit["user_capacity"].asUInt() == 0)
{
LOG_ERROR << "At least one capacity of sub_limits must be "
"greater than 0!";
continue;
}
limitStrategies_.emplace_back(makeLimitStrategy(subLimit));
}
}
app().registerPreHandlingAdvice([this](const HttpRequestPtr &req,
AdviceCallback &&acb,
AdviceChainCallback &&accb) {
onHttpRequest(req, std::move(acb), std::move(accb));
});
}
void Hodor::shutdown()
{
LOG_TRACE << "Hodor plugin is shutdown!";
}
bool Hodor::checkLimit(const HttpRequestPtr &req,
const LimitStrategy &strategy,
const std::string &ip,
const optional<std::string> &userId)
{
if (strategy.regexFlag)
{
if (!std::regex_match(req->path(), strategy.urlsRegex))
{
return true;
}
}
if (strategy.globalLimiterPtr)
{
if (!strategy.globalLimiterPtr->isAllowed())
{
return false;
}
}
if (strategy.ipCapacity > 0)
{
RateLimiterPtr limiterPtr;
strategy.ipLimiterMapPtr->modify(
ip,
[this, &limiterPtr, &strategy](RateLimiterPtr &ptr) {
if (!ptr)
{
if (multiThreads_)
{
ptr = std::make_shared<SafeRateLimiter>(
RateLimiter::newRateLimiter(algorithm_,
strategy.ipCapacity,
timeUnit_));
}
else
{
ptr = RateLimiter::newRateLimiter(algorithm_,
strategy.ipCapacity,
timeUnit_);
}
}
limiterPtr = ptr;
},
limiterExpireTime_);
if (!limiterPtr->isAllowed())
{
return false;
}
}
if (strategy.userCapacity > 0)
{
if (!userId.has_value())
{
return true;
}
RateLimiterPtr limiterPtr;
strategy.userLimiterMapPtr->modify(
*userId,
[this, &strategy, &limiterPtr](RateLimiterPtr &ptr) {
if (!ptr)
{
if (multiThreads_)
{
ptr = std::make_shared<SafeRateLimiter>(
RateLimiter::newRateLimiter(algorithm_,
strategy.userCapacity,
timeUnit_));
}
else
{
ptr = RateLimiter::newRateLimiter(algorithm_,
strategy.userCapacity,
timeUnit_);
}
}
limiterPtr = ptr;
},
limiterExpireTime_);
if (!limiterPtr->isAllowed())
{
return false;
}
}
return true;
}
void Hodor::onHttpRequest(const drogon::HttpRequestPtr &req,
drogon::AdviceCallback &&adviceCallback,
drogon::AdviceChainCallback &&chainCallback)
{
std::string ip;
if (useRealIpResolver_)
{
ip = drogon::plugin::RealIpResolver::GetRealAddr(req).toIp();
}
else
{
ip = req->peerAddr().toIp();
}
optional<std::string> userId;
if (userIdGetter_)
{
userId = userIdGetter_(req);
}
for (auto &strategy : limitStrategies_)
{
if (!checkLimit(req, strategy, ip, userId))
{
if (rejectResponseFactory_)
{
adviceCallback(rejectResponseFactory_(req));
}
else
{
adviceCallback(rejectResponse_);
}
return;
}
}
chainCallback();
}

View File

@ -92,7 +92,10 @@ HttpAppFrameworkImpl::HttpAppFrameworkImpl()
postHandlingAdvices_)),
websockCtrlsRouterPtr_(
new WebsocketControllersRouter(postRoutingAdvices_,
postRoutingObservers_)),
postRoutingObservers_,
preHandlingAdvices_,
preHandlingObservers_,
postHandlingAdvices_)),
listenerManagerPtr_(new ListenerManager),
pluginsManagerPtr_(new PluginsManager),
dbClientManagerPtr_(new orm::DbClientManager),

24
lib/src/RateLimiter.cc Normal file
View File

@ -0,0 +1,24 @@
#include <drogon/RateLimiter.h>
#include "FixedWindowRateLimiter.h"
#include "SlidingWindowRateLimiter.h"
#include "TokenBucketRateLimiter.h"
using namespace drogon;
RateLimiterPtr RateLimiter::newRateLimiter(
RateLimiterType type,
size_t capacity,
std::chrono::duration<double> timeUnit)
{
switch (type)
{
case RateLimiterType::kFixedWindow:
return std::make_shared<FixedWindowRateLimiter>(capacity, timeUnit);
case RateLimiterType::kSlidingWindow:
return std::make_shared<SlidingWindowRateLimiter>(capacity,
timeUnit);
case RateLimiterType::kTokenBucket:
return std::make_shared<TokenBucketRateLimiter>(capacity, timeUnit);
}
return std::make_shared<TokenBucketRateLimiter>(capacity, timeUnit);
}

View File

@ -0,0 +1,58 @@
#include "SlidingWindowRateLimiter.h"
#include <assert.h>
using namespace drogon;
SlidingWindowRateLimiter::SlidingWindowRateLimiter(
size_t capacity,
std::chrono::duration<double> timeUnit)
: capacity_(capacity),
unitStartTime_(std::chrono::steady_clock::now()),
lastTime_(unitStartTime_),
timeUnit_(timeUnit)
{
}
// implementation of the sliding window algorithm
bool SlidingWindowRateLimiter::isAllowed()
{
auto now = std::chrono::steady_clock::now();
unitStartTime_ =
unitStartTime_ +
std::chrono::duration_cast<decltype(unitStartTime_)::duration>(
std::chrono::duration<double>(
static_cast<double>((uint64_t)(
std::chrono::duration_cast<std::chrono::duration<double>>(
now - unitStartTime_)
.count() /
timeUnit_.count())) *
timeUnit_.count()));
if (unitStartTime_ > lastTime_)
{
auto duration =
std::chrono::duration_cast<std::chrono::duration<double>>(
unitStartTime_ - lastTime_);
auto startTime = lastTime_;
if (duration >= timeUnit_)
{
previousRequests_ = 0;
}
else
{
previousRequests_ = currentRequests_;
}
currentRequests_ = 0;
}
auto coef = std::chrono::duration_cast<std::chrono::duration<double>>(
now - unitStartTime_) /
timeUnit_;
assert(coef <= 1.0);
auto count = previousRequests_ * (1.0 - coef) + currentRequests_;
if (count < capacity_)
{
currentRequests_++;
lastTime_ = now;
return true;
}
return false;
}

View File

@ -0,0 +1,22 @@
#pragma once
#include <drogon/RateLimiter.h>
#include <chrono>
namespace drogon
{
class SlidingWindowRateLimiter : public RateLimiter
{
public:
SlidingWindowRateLimiter(size_t capacity,
std::chrono::duration<double> timeUnit);
bool isAllowed() override;
private:
size_t capacity_;
size_t currentRequests_{0};
size_t previousRequests_{0};
std::chrono::steady_clock::time_point unitStartTime_;
std::chrono::steady_clock::time_point lastTime_;
std::chrono::duration<double> timeUnit_;
};
} // namespace drogon

View File

@ -0,0 +1,30 @@
#include "TokenBucketRateLimiter.h"
using namespace drogon;
TokenBucketRateLimiter::TokenBucketRateLimiter(
size_t capacity,
std::chrono::duration<double> timeUnit)
: capacity_(capacity),
lastTime_(std::chrono::steady_clock::now()),
timeUnit_(timeUnit)
{
}
// implementation of the token bucket algorithm
bool TokenBucketRateLimiter::isAllowed()
{
auto now = std::chrono::steady_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::duration<double>>(
now - lastTime_);
tokens_ += capacity_ * (duration / timeUnit_);
if (tokens_ > capacity_)
tokens_ = capacity_;
lastTime_ = now;
if (tokens_ > 1.0)
{
tokens_ -= 1.0;
return true;
}
return false;
}

View File

@ -0,0 +1,20 @@
#pragma once
#include <drogon/RateLimiter.h>
namespace drogon
{
class TokenBucketRateLimiter : public RateLimiter
{
public:
TokenBucketRateLimiter(size_t capacity,
std::chrono::duration<double> timeUnit);
bool isAllowed() override;
private:
size_t capacity_;
std::chrono::steady_clock::time_point lastTime_;
std::chrono::duration<double> timeUnit_;
double tokens_{0};
};
} // namespace drogon

View File

@ -91,6 +91,95 @@ void WebsocketControllersRouter::registerWebSocketController(
}
}
void WebsocketControllersRouter::doPreHandlingAdvices(
const WebSocketControllerRouterItem &routerItem,
std::string &wsKey,
const HttpRequestImplPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const WebSocketConnectionImplPtr &wsConnPtr)
{
if (req->method() == Options)
{
auto resp = HttpResponse::newHttpResponse();
resp->setContentTypeCode(ContentType::CT_TEXT_PLAIN);
std::string methods = "OPTIONS,";
if (routerItem.binders_[Get] && routerItem.binders_[Get]->isCORS_)
{
methods.append("GET,HEAD,");
}
if (routerItem.binders_[Post] && routerItem.binders_[Post]->isCORS_)
{
methods.append("POST,");
}
if (routerItem.binders_[Put] && routerItem.binders_[Put]->isCORS_)
{
methods.append("PUT,");
}
if (routerItem.binders_[Delete] && routerItem.binders_[Delete]->isCORS_)
{
methods.append("DELETE,");
}
if (routerItem.binders_[Patch] && routerItem.binders_[Patch]->isCORS_)
{
methods.append("PATCH,");
}
methods.resize(methods.length() - 1);
resp->addHeader("ALLOW", methods);
auto &origin = req->getHeader("Origin");
if (origin.empty())
{
resp->addHeader("Access-Control-Allow-Origin", "*");
}
else
{
resp->addHeader("Access-Control-Allow-Origin", origin);
}
resp->addHeader("Access-Control-Allow-Methods", methods);
auto &headers = req->getHeaderBy("access-control-request-headers");
if (!headers.empty())
{
resp->addHeader("Access-Control-Allow-Headers", headers);
}
callback(resp);
return;
}
if (!preHandlingObservers_.empty())
{
for (auto &observer : preHandlingObservers_)
{
observer(req);
}
}
if (preHandlingAdvices_.empty())
{
doControllerHandler(
routerItem, wsKey, req, std::move(callback), wsConnPtr);
}
else
{
auto callbackPtr =
std::make_shared<std::function<void(const HttpResponsePtr &)>>(
std::move(callback));
doAdvicesChain(
preHandlingAdvices_,
0,
req,
std::make_shared<std::function<void(const HttpResponsePtr &)>>(
[callbackPtr](const HttpResponsePtr &resp) {
(*callbackPtr)(resp);
}),
[this,
&routerItem,
wsKey = std::move(wsKey),
req,
callbackPtr,
wsConnPtr]() mutable {
doControllerHandler(
routerItem, wsKey, req, std::move(*callbackPtr), wsConnPtr);
});
}
}
void WebsocketControllersRouter::route(
const HttpRequestImplPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
@ -149,7 +238,7 @@ void WebsocketControllersRouter::route(
wsConnPtr,
&ctrlInfo,
this]() mutable {
doControllerHandler(
doPreHandlingAdvices(
ctrlInfo,
wsKey,
req,
@ -159,7 +248,7 @@ void WebsocketControllersRouter::route(
}
else
{
doControllerHandler(
doPreHandlingAdvices(
ctrlInfo, wsKey, req, std::move(callback), wsConnPtr);
}
return;
@ -169,46 +258,47 @@ void WebsocketControllersRouter::route(
auto callbackPtr = std::make_shared<
std::function<void(const HttpResponsePtr &)>>(
std::move(callback));
doAdvicesChain(
postRoutingAdvices_,
0,
req,
callbackPtr,
[callbackPtr,
&filters,
req,
&ctrlInfo,
this,
wsKey = std::move(wsKey),
wsConnPtr]() mutable {
if (!filters.empty())
{
filters_function::doFilters(
filters,
doAdvicesChain(postRoutingAdvices_,
0,
req,
callbackPtr,
[callbackPtr,
&filters,
req,
callbackPtr,
[this,
wsKey = std::move(wsKey),
callbackPtr,
wsConnPtr = std::move(wsConnPtr),
req,
&ctrlInfo]() mutable {
doControllerHandler(ctrlInfo,
wsKey,
req,
std::move(*callbackPtr),
wsConnPtr);
});
}
else
{
doControllerHandler(ctrlInfo,
wsKey,
req,
std::move(*callbackPtr),
wsConnPtr);
}
});
&ctrlInfo,
this,
wsKey = std::move(wsKey),
wsConnPtr]() mutable {
if (!filters.empty())
{
filters_function::doFilters(
filters,
req,
callbackPtr,
[this,
wsKey = std::move(wsKey),
callbackPtr,
wsConnPtr = std::move(wsConnPtr),
req,
&ctrlInfo]() mutable {
doPreHandlingAdvices(
ctrlInfo,
wsKey,
req,
std::move(*callbackPtr),
wsConnPtr);
});
}
else
{
doPreHandlingAdvices(ctrlInfo,
wsKey,
req,
std::move(
*callbackPtr),
wsConnPtr);
}
});
}
return;
}
@ -247,52 +337,6 @@ void WebsocketControllersRouter::doControllerHandler(
std::function<void(const HttpResponsePtr &)> &&callback,
const WebSocketConnectionImplPtr &wsConnPtr)
{
if (req->method() == Options)
{
auto resp = HttpResponse::newHttpResponse();
resp->setContentTypeCode(ContentType::CT_TEXT_PLAIN);
std::string methods = "OPTIONS,";
if (routerItem.binders_[Get] && routerItem.binders_[Get]->isCORS_)
{
methods.append("GET,HEAD,");
}
if (routerItem.binders_[Post] && routerItem.binders_[Post]->isCORS_)
{
methods.append("POST,");
}
if (routerItem.binders_[Put] && routerItem.binders_[Put]->isCORS_)
{
methods.append("PUT,");
}
if (routerItem.binders_[Delete] && routerItem.binders_[Delete]->isCORS_)
{
methods.append("DELETE,");
}
if (routerItem.binders_[Patch] && routerItem.binders_[Patch]->isCORS_)
{
methods.append("PATCH,");
}
methods.resize(methods.length() - 1);
resp->addHeader("ALLOW", methods);
auto &origin = req->getHeader("Origin");
if (origin.empty())
{
resp->addHeader("Access-Control-Allow-Origin", "*");
}
else
{
resp->addHeader("Access-Control-Allow-Origin", origin);
}
resp->addHeader("Access-Control-Allow-Methods", methods);
auto &headers = req->getHeaderBy("access-control-request-headers");
if (!headers.empty())
{
resp->addHeader("Access-Control-Allow-Headers", headers);
}
callback(resp);
return;
}
wsKey.append("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
unsigned char accKey[SHA_DIGEST_LENGTH];
SHA1(reinterpret_cast<const unsigned char *>(wsKey.c_str()),
@ -304,6 +348,10 @@ void WebsocketControllersRouter::doControllerHandler(
resp->addHeader("Upgrade", "websocket");
resp->addHeader("Connection", "Upgrade");
resp->addHeader("Sec-WebSocket-Accept", base64Key);
for (auto &advice : postHandlingAdvices_)
{
advice(req, resp);
}
callback(resp);
auto ctrlPtr = routerItem.binders_[req->method()]->controller_;
wsConnPtr->setMessageCallback(
@ -335,4 +383,4 @@ void WebsocketControllersRouter::init()
}
}
}
}
}

View File

@ -38,9 +38,21 @@ class WebsocketControllersRouter : public trantor::NonCopyable
AdviceChainCallback &&)>>
&postRoutingAdvices,
const std::vector<std::function<void(const HttpRequestPtr &)>>
&postRoutingObservers)
&postRoutingObservers,
const std::vector<std::function<void(const HttpRequestPtr &,
AdviceCallback &&,
AdviceChainCallback &&)>>
&preHandlingAdvices,
const std::vector<std::function<void(const HttpRequestPtr &)>>
&preHandlingObservers,
const std::vector<std::function<void(const HttpRequestPtr &,
const HttpResponsePtr &)>>
&postHandlingAdvices)
: postRoutingAdvices_(postRoutingAdvices),
postRoutingObservers_(postRoutingObservers)
postRoutingObservers_(postRoutingObservers),
preHandlingAdvices_(preHandlingAdvices),
preHandlingObservers_(preHandlingObservers),
postHandlingAdvices_(postHandlingAdvices)
{
}
void registerWebSocketController(
@ -78,11 +90,26 @@ class WebsocketControllersRouter : public trantor::NonCopyable
&postRoutingAdvices_;
const std::vector<std::function<void(const HttpRequestPtr &)>>
&postRoutingObservers_;
const std::vector<std::function<void(const HttpRequestPtr &,
AdviceCallback &&,
AdviceChainCallback &&)>>
&preHandlingAdvices_;
const std::vector<std::function<void(const HttpRequestPtr &)>>
&preHandlingObservers_;
const std::vector<
std::function<void(const HttpRequestPtr &, const HttpResponsePtr &)>>
&postHandlingAdvices_;
void doControllerHandler(
const WebSocketControllerRouterItem &routerItem,
std::string &wsKey,
const HttpRequestImplPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const WebSocketConnectionImplPtr &wsConnPtr);
void doPreHandlingAdvices(
const WebSocketControllerRouterItem &routerItem,
std::string &wsKey,
const HttpRequestImplPtr &req,
std::function<void(const HttpResponsePtr &)> &&callback,
const WebSocketConnectionImplPtr &wsConnPtr);
};
} // namespace drogon