Add rate limiter (#1409)
This commit is contained in:
parent
37a10318ff
commit
164972e2d3
|
@ -271,7 +271,12 @@ set(DROGON_SOURCES
|
|||
lib/src/Utilities.cc
|
||||
lib/src/WebSocketClientImpl.cc
|
||||
lib/src/WebSocketConnectionImpl.cc
|
||||
lib/src/WebsocketControllersRouter.cc)
|
||||
lib/src/WebsocketControllersRouter.cc
|
||||
lib/src/RateLimiter.cc
|
||||
lib/src/FixedWindowRateLimiter.cc
|
||||
lib/src/SlidingWindowRateLimiter.cc
|
||||
lib/src/TokenBucketRateLimiter.cc
|
||||
lib/src/Hodor.cc)
|
||||
set(private_headers
|
||||
lib/src/AOPAdvice.h
|
||||
lib/src/CacheFile.h
|
||||
|
@ -300,7 +305,10 @@ set(private_headers
|
|||
lib/src/TaskTimeoutFlag.h
|
||||
lib/src/WebSocketClientImpl.h
|
||||
lib/src/WebSocketConnectionImpl.h
|
||||
lib/src/WebsocketControllersRouter.h)
|
||||
lib/src/WebsocketControllersRouter.h
|
||||
lib/src/FixedWindowRateLimiter.h
|
||||
lib/src/SlidingWindowRateLimiter.h
|
||||
lib/src/TokenBucketRateLimiter.h)
|
||||
|
||||
if (NOT WIN32)
|
||||
set(DROGON_SOURCES
|
||||
|
@ -556,6 +564,7 @@ set(DROGON_HEADERS
|
|||
lib/inc/drogon/drogon_callbacks.h
|
||||
lib/inc/drogon/PubSubService.h
|
||||
lib/inc/drogon/drogon_test.h
|
||||
lib/inc/drogon/RateLimiter.h
|
||||
${CMAKE_CURRENT_BINARY_DIR}/exports/drogon/exports.h)
|
||||
set(private_headers
|
||||
${private_headers}
|
||||
|
@ -697,7 +706,8 @@ set(DROGON_PLUGIN_HEADERS
|
|||
lib/inc/drogon/plugins/Plugin.h
|
||||
lib/inc/drogon/plugins/SecureSSLRedirector.h
|
||||
lib/inc/drogon/plugins/AccessLogger.h
|
||||
lib/inc/drogon/plugins/RealIpResolver.h)
|
||||
lib/inc/drogon/plugins/RealIpResolver.h
|
||||
lib/inc/drogon/plugins/Hodor.h)
|
||||
install(FILES ${DROGON_PLUGIN_HEADERS}
|
||||
DESTINATION ${INSTALL_INCLUDE_DIR}/drogon/plugins)
|
||||
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
#pragma once
|
||||
#include <drogon/exports.h>
|
||||
#include <memory>
|
||||
#include <chrono>
|
||||
#include <mutex>
|
||||
|
||||
namespace drogon
|
||||
{
|
||||
enum class DROGON_EXPORT RateLimiterType
|
||||
{
|
||||
kFixedWindow,
|
||||
kSlidingWindow,
|
||||
kTokenBucket
|
||||
};
|
||||
inline RateLimiterType stringToRateLimiterType(const std::string &type)
|
||||
{
|
||||
if (type == "fixedWindow" || type == "fixed_window")
|
||||
return RateLimiterType::kFixedWindow;
|
||||
else if (type == "slidingWindow" || type == "sliding_window")
|
||||
return RateLimiterType::kSlidingWindow;
|
||||
return RateLimiterType::kTokenBucket;
|
||||
}
|
||||
class DROGON_EXPORT RateLimiter;
|
||||
using RateLimiterPtr = std::shared_ptr<RateLimiter>;
|
||||
/**
|
||||
* @brief This class is used to limit the number of requests per second
|
||||
*
|
||||
* */
|
||||
class DROGON_EXPORT RateLimiter
|
||||
{
|
||||
public:
|
||||
/**
|
||||
* @brief Create a rate limiter
|
||||
* @param type The type of the rate limiter
|
||||
* @param capacity The maximum number of requests in the time unit.
|
||||
* @param timeUnit The time unit of the rate limiter.
|
||||
* @return A rate limiter pointer
|
||||
*/
|
||||
static RateLimiterPtr newRateLimiter(
|
||||
RateLimiterType type,
|
||||
size_t capacity,
|
||||
std::chrono::duration<double> timeUnit = std::chrono::seconds(60));
|
||||
/**
|
||||
* @brief Check if a request is allowed
|
||||
*
|
||||
* @return true The request is allowed
|
||||
* @return false The request is not allowed
|
||||
*/
|
||||
virtual bool isAllowed() = 0;
|
||||
};
|
||||
class DROGON_EXPORT SafeRateLimiter : public RateLimiter
|
||||
{
|
||||
public:
|
||||
SafeRateLimiter(RateLimiterPtr limiter) : limiter_(limiter)
|
||||
{
|
||||
}
|
||||
bool isAllowed() override
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return limiter_->isAllowed();
|
||||
}
|
||||
|
||||
private:
|
||||
RateLimiterPtr limiter_;
|
||||
std::mutex mutex_;
|
||||
};
|
||||
} // namespace drogon
|
|
@ -30,6 +30,7 @@
|
|||
#include <drogon/plugins/SecureSSLRedirector.h>
|
||||
#include <drogon/plugins/AccessLogger.h>
|
||||
#include <drogon/plugins/RealIpResolver.h>
|
||||
#include <drogon/plugins/Hodor.h>
|
||||
#include <drogon/Cookie.h>
|
||||
#include <drogon/Session.h>
|
||||
#include <drogon/IOThreadStorage.h>
|
||||
|
|
|
@ -0,0 +1,146 @@
|
|||
/**
|
||||
* @file Hodor.h
|
||||
* @author An Tao
|
||||
*
|
||||
* Copyright 2018, An Tao. All rights reserved.
|
||||
* https://github.com/an-tao/drogon
|
||||
* Use of this source code is governed by a MIT license
|
||||
* that can be found in the License file.
|
||||
*
|
||||
* Drogon
|
||||
*
|
||||
*/
|
||||
#pragma once
|
||||
#include <drogon/RateLimiter.h>
|
||||
#include <drogon/plugins/Plugin.h>
|
||||
#include <drogon/plugins/RealIpResolver.h>
|
||||
#include <drogon/HttpAppFramework.h>
|
||||
#include <drogon/utils/optional.h>
|
||||
#include <drogon/CacheMap.h>
|
||||
#include <regex>
|
||||
|
||||
namespace drogon
|
||||
{
|
||||
namespace plugin
|
||||
{
|
||||
/**
|
||||
* @brief The Hodor plugin implements a global rate limiter that limits the
|
||||
* number of requests in a particular time unit.
|
||||
* The json configuration is as follows:
|
||||
*
|
||||
* @code
|
||||
{
|
||||
"name": "drogon::plugin::Hodor",
|
||||
"dependencies": [],
|
||||
"config": {
|
||||
// The algorithm used to limit the number of requests.
|
||||
// The default value is "token_bucket". other values are "fixed_window"
|
||||
or "sliding_window".
|
||||
"algorithm": "token_bucket",
|
||||
// a regular expression (for matching the path of a request) list for
|
||||
URLs that have to be limited. if the list is empty, all URLs are limited.
|
||||
"urls": ["^/api/.*", ...],
|
||||
// The time unit in seconds. the default value is 60.
|
||||
"time_unit": 60,
|
||||
// The maximum number of requests in a time unit. the default value 0
|
||||
means no limit.
|
||||
"capacity": 0,
|
||||
// The maximum number of requests in a time unit for a single IP. the
|
||||
default value 0 means no limit.
|
||||
"ip_capacity": 0,
|
||||
// The maximum number of requests in a time unit for a single user.
|
||||
a function must be provided to the plugin to get the user id from the request.
|
||||
the default value 0 means no limit.
|
||||
"user_capacity": 0,
|
||||
// Use the RealIpResolver plugin to get the real IP address of the
|
||||
request. if this option is true, the RealIpResolver plugin should be added to
|
||||
the dependencies list. the default value is false.
|
||||
"use_real_ip_resolver": false,
|
||||
// Multiple threads mode: the default value is true. if this option is
|
||||
true, some mutexes are used for thread-safe.
|
||||
"multi_threads": true,
|
||||
// The message body of the response when the request is rejected.
|
||||
"rejection_message": "Too many requests",
|
||||
// In seconds, the minimum expiration time of the limiters for different
|
||||
IPs or users. the default value is 600.
|
||||
"limiter_expire_time": 600,
|
||||
"sub_limits": [
|
||||
{
|
||||
"urls": ["^/api/1/.*", ...],
|
||||
"capacity": 0,
|
||||
"ip_capacity": 0,
|
||||
"user_capacity": 0
|
||||
},...
|
||||
]
|
||||
}
|
||||
}
|
||||
@endcode
|
||||
*
|
||||
* Enable the plugin by adding the configuration to the list of plugins in the
|
||||
* configuration file.
|
||||
* */
|
||||
class DROGON_EXPORT Hodor : public drogon::Plugin<Hodor>
|
||||
{
|
||||
public:
|
||||
Hodor()
|
||||
{
|
||||
}
|
||||
void initAndStart(const Json::Value &config) override;
|
||||
void shutdown() override;
|
||||
/**
|
||||
* @brief the method is used to set a function to get the user id from the
|
||||
* request. users should call this method after calling the app().run()
|
||||
* method. etc. use the beginning advice of AOP.
|
||||
* */
|
||||
void setUserIdGetter(
|
||||
std::function<optional<std::string>(const HttpRequestPtr &)> func)
|
||||
{
|
||||
userIdGetter_ = std::move(func);
|
||||
}
|
||||
/**
|
||||
* @brief the method is used to set a function to create the response when
|
||||
* the rate limit is exceeded. users should call this method after calling
|
||||
* the app().run() method. etc. use the beginning advice of AOP.
|
||||
* */
|
||||
void setRejectResponseFactory(
|
||||
std::function<HttpResponsePtr(const HttpRequestPtr &)> func)
|
||||
{
|
||||
rejectResponseFactory_ = std::move(func);
|
||||
}
|
||||
|
||||
private:
|
||||
struct LimitStrategy
|
||||
{
|
||||
std::regex urlsRegex;
|
||||
size_t capacity{0};
|
||||
size_t ipCapacity{0};
|
||||
size_t userCapacity{0};
|
||||
bool regexFlag{false};
|
||||
RateLimiterPtr globalLimiterPtr;
|
||||
std::unique_ptr<CacheMap<std::string, RateLimiterPtr>> ipLimiterMapPtr;
|
||||
std::unique_ptr<CacheMap<std::string, RateLimiterPtr>>
|
||||
userLimiterMapPtr;
|
||||
};
|
||||
LimitStrategy makeLimitStrategy(const Json::Value &config);
|
||||
std::vector<LimitStrategy> limitStrategies_;
|
||||
RateLimiterType algorithm_{RateLimiterType::kTokenBucket};
|
||||
std::chrono::duration<double> timeUnit_{1.0};
|
||||
bool multiThreads_{true};
|
||||
bool useRealIpResolver_{false};
|
||||
size_t limiterExpireTime_{600};
|
||||
std::function<optional<std::string>(const drogon::HttpRequestPtr &)>
|
||||
userIdGetter_;
|
||||
std::function<HttpResponsePtr(const HttpRequestPtr &)>
|
||||
rejectResponseFactory_;
|
||||
|
||||
void onHttpRequest(const HttpRequestPtr &,
|
||||
AdviceCallback &&,
|
||||
AdviceChainCallback &&);
|
||||
bool checkLimit(const HttpRequestPtr &req,
|
||||
const LimitStrategy &strategy,
|
||||
const std::string &ip,
|
||||
const optional<std::string> &userId);
|
||||
HttpResponsePtr rejectResponse_;
|
||||
};
|
||||
} // namespace plugin
|
||||
} // namespace drogon
|
|
@ -8,6 +8,7 @@
|
|||
|
||||
#include <drogon/plugins/Plugin.h>
|
||||
#include <trantor/net/InetAddress.h>
|
||||
#include <drogon/HttpRequest.h>
|
||||
#include <vector>
|
||||
|
||||
namespace drogon
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
#include "FixedWindowRateLimiter.h"
|
||||
|
||||
using namespace drogon;
|
||||
|
||||
FixedWindowRateLimiter::FixedWindowRateLimiter(
|
||||
size_t capacity,
|
||||
std::chrono::duration<double> timeUnit)
|
||||
: capacity_(capacity),
|
||||
lastTime_(std::chrono::steady_clock::now()),
|
||||
timeUnit_(timeUnit)
|
||||
{
|
||||
}
|
||||
// implementation of the fixed window algorithm
|
||||
|
||||
bool FixedWindowRateLimiter::isAllowed()
|
||||
{
|
||||
auto now = std::chrono::steady_clock::now();
|
||||
auto duration = std::chrono::duration_cast<std::chrono::duration<double>>(
|
||||
now - lastTime_);
|
||||
if (duration >= timeUnit_)
|
||||
{
|
||||
currentRequests_ = 0;
|
||||
lastTime_ = now;
|
||||
}
|
||||
if (currentRequests_ < capacity_)
|
||||
{
|
||||
currentRequests_++;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
#pragma once
|
||||
|
||||
#include <drogon/RateLimiter.h>
|
||||
#include <chrono>
|
||||
namespace drogon
|
||||
{
|
||||
class FixedWindowRateLimiter : public RateLimiter
|
||||
{
|
||||
public:
|
||||
FixedWindowRateLimiter(size_t capacity,
|
||||
std::chrono::duration<double> timeUnit);
|
||||
bool isAllowed() override;
|
||||
|
||||
private:
|
||||
size_t capacity_;
|
||||
size_t currentRequests_{0};
|
||||
std::chrono::steady_clock::time_point lastTime_;
|
||||
std::chrono::duration<double> timeUnit_;
|
||||
};
|
||||
} // namespace drogon
|
|
@ -0,0 +1,236 @@
|
|||
#include <drogon/plugins/Hodor.h>
|
||||
#include <drogon/plugins/RealIpResolver.h>
|
||||
|
||||
using namespace drogon::plugin;
|
||||
Hodor::LimitStrategy Hodor::makeLimitStrategy(const Json::Value &config)
|
||||
{
|
||||
LimitStrategy strategy;
|
||||
strategy.capacity = config.get("capacity", 0).asUInt();
|
||||
if (config.isMember("urls") && config["urls"].isArray())
|
||||
{
|
||||
std::string regexString;
|
||||
for (auto &str : config["urls"])
|
||||
{
|
||||
assert(str.isString());
|
||||
regexString.append("(").append(str.asString()).append(")|");
|
||||
}
|
||||
if (!regexString.empty())
|
||||
{
|
||||
regexString.resize(regexString.length() - 1);
|
||||
strategy.urlsRegex = std::regex(regexString);
|
||||
strategy.regexFlag = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (strategy.capacity > 0)
|
||||
{
|
||||
if (multiThreads_)
|
||||
{
|
||||
strategy.globalLimiterPtr = std::make_shared<SafeRateLimiter>(
|
||||
RateLimiter::newRateLimiter(algorithm_,
|
||||
strategy.capacity,
|
||||
timeUnit_));
|
||||
}
|
||||
else
|
||||
{
|
||||
strategy.globalLimiterPtr =
|
||||
RateLimiter::newRateLimiter(algorithm_,
|
||||
strategy.capacity,
|
||||
timeUnit_);
|
||||
}
|
||||
}
|
||||
strategy.ipCapacity = config.get("ip_capacity", 0).asUInt();
|
||||
if (strategy.ipCapacity > 0)
|
||||
{
|
||||
strategy.ipLimiterMapPtr =
|
||||
std::make_unique<CacheMap<std::string, RateLimiterPtr>>(
|
||||
drogon::app().getLoop(),
|
||||
timeUnit_.count() / 60 < 1 ? 1 : timeUnit_.count() / 60,
|
||||
2,
|
||||
100);
|
||||
}
|
||||
|
||||
strategy.userCapacity = config.get("user_capacity", 0).asUInt();
|
||||
if (strategy.userCapacity > 0)
|
||||
{
|
||||
strategy.userLimiterMapPtr =
|
||||
std::make_unique<CacheMap<std::string, RateLimiterPtr>>(
|
||||
drogon::app().getLoop(),
|
||||
timeUnit_.count() / 60 < 1 ? 1 : timeUnit_.count() / 60,
|
||||
2,
|
||||
100);
|
||||
}
|
||||
return strategy;
|
||||
}
|
||||
void Hodor::initAndStart(const Json::Value &config)
|
||||
{
|
||||
algorithm_ = stringToRateLimiterType(
|
||||
config.get("algorithm", "token_bucket").asString());
|
||||
timeUnit_ = std::chrono::seconds(config.get("time_unit", 60).asUInt());
|
||||
|
||||
multiThreads_ = config.get("multi_threads", true).asBool();
|
||||
|
||||
useRealIpResolver_ = config.get("use_real_ip_resolver", false).asBool();
|
||||
rejectResponse_ = HttpResponse::newHttpResponse();
|
||||
rejectResponse_->setStatusCode(k429TooManyRequests);
|
||||
rejectResponse_->setBody(
|
||||
config.get("rejection_message", "Too many requests").asString());
|
||||
rejectResponse_->setCloseConnection(true);
|
||||
limiterExpireTime_ =
|
||||
(std::min)(static_cast<size_t>(
|
||||
config.get("limiter_expire_time", 600).asUInt()),
|
||||
static_cast<size_t>(timeUnit_.count() * 3));
|
||||
limitStrategies_.emplace_back(makeLimitStrategy(config));
|
||||
if (config.isMember("sub_limits") && config["sub_limits"].isArray())
|
||||
{
|
||||
for (auto &subLimit : config["sub_limits"])
|
||||
{
|
||||
assert(subLimit.isObject());
|
||||
if (!subLimit["urls"].isArray() || subLimit["urls"].size() == 0)
|
||||
{
|
||||
LOG_ERROR
|
||||
<< "The urls of sub_limits must be an array and not empty!";
|
||||
continue;
|
||||
}
|
||||
if (subLimit["capacity"].asUInt() == 0 &&
|
||||
subLimit["ip_capacity"].asUInt() == 0 &&
|
||||
subLimit["user_capacity"].asUInt() == 0)
|
||||
{
|
||||
LOG_ERROR << "At least one capacity of sub_limits must be "
|
||||
"greater than 0!";
|
||||
continue;
|
||||
}
|
||||
limitStrategies_.emplace_back(makeLimitStrategy(subLimit));
|
||||
}
|
||||
}
|
||||
app().registerPreHandlingAdvice([this](const HttpRequestPtr &req,
|
||||
AdviceCallback &&acb,
|
||||
AdviceChainCallback &&accb) {
|
||||
onHttpRequest(req, std::move(acb), std::move(accb));
|
||||
});
|
||||
}
|
||||
|
||||
void Hodor::shutdown()
|
||||
{
|
||||
LOG_TRACE << "Hodor plugin is shutdown!";
|
||||
}
|
||||
bool Hodor::checkLimit(const HttpRequestPtr &req,
|
||||
const LimitStrategy &strategy,
|
||||
const std::string &ip,
|
||||
const optional<std::string> &userId)
|
||||
{
|
||||
if (strategy.regexFlag)
|
||||
{
|
||||
if (!std::regex_match(req->path(), strategy.urlsRegex))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
if (strategy.globalLimiterPtr)
|
||||
{
|
||||
if (!strategy.globalLimiterPtr->isAllowed())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (strategy.ipCapacity > 0)
|
||||
{
|
||||
RateLimiterPtr limiterPtr;
|
||||
strategy.ipLimiterMapPtr->modify(
|
||||
ip,
|
||||
[this, &limiterPtr, &strategy](RateLimiterPtr &ptr) {
|
||||
if (!ptr)
|
||||
{
|
||||
if (multiThreads_)
|
||||
{
|
||||
ptr = std::make_shared<SafeRateLimiter>(
|
||||
RateLimiter::newRateLimiter(algorithm_,
|
||||
strategy.ipCapacity,
|
||||
timeUnit_));
|
||||
}
|
||||
else
|
||||
{
|
||||
ptr = RateLimiter::newRateLimiter(algorithm_,
|
||||
strategy.ipCapacity,
|
||||
timeUnit_);
|
||||
}
|
||||
}
|
||||
limiterPtr = ptr;
|
||||
},
|
||||
limiterExpireTime_);
|
||||
if (!limiterPtr->isAllowed())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (strategy.userCapacity > 0)
|
||||
{
|
||||
if (!userId.has_value())
|
||||
{
|
||||
return true;
|
||||
}
|
||||
RateLimiterPtr limiterPtr;
|
||||
strategy.userLimiterMapPtr->modify(
|
||||
*userId,
|
||||
[this, &strategy, &limiterPtr](RateLimiterPtr &ptr) {
|
||||
if (!ptr)
|
||||
{
|
||||
if (multiThreads_)
|
||||
{
|
||||
ptr = std::make_shared<SafeRateLimiter>(
|
||||
RateLimiter::newRateLimiter(algorithm_,
|
||||
strategy.userCapacity,
|
||||
timeUnit_));
|
||||
}
|
||||
else
|
||||
{
|
||||
ptr = RateLimiter::newRateLimiter(algorithm_,
|
||||
strategy.userCapacity,
|
||||
timeUnit_);
|
||||
}
|
||||
}
|
||||
limiterPtr = ptr;
|
||||
},
|
||||
limiterExpireTime_);
|
||||
if (!limiterPtr->isAllowed())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
void Hodor::onHttpRequest(const drogon::HttpRequestPtr &req,
|
||||
drogon::AdviceCallback &&adviceCallback,
|
||||
drogon::AdviceChainCallback &&chainCallback)
|
||||
{
|
||||
std::string ip;
|
||||
if (useRealIpResolver_)
|
||||
{
|
||||
ip = drogon::plugin::RealIpResolver::GetRealAddr(req).toIp();
|
||||
}
|
||||
else
|
||||
{
|
||||
ip = req->peerAddr().toIp();
|
||||
}
|
||||
optional<std::string> userId;
|
||||
if (userIdGetter_)
|
||||
{
|
||||
userId = userIdGetter_(req);
|
||||
}
|
||||
for (auto &strategy : limitStrategies_)
|
||||
{
|
||||
if (!checkLimit(req, strategy, ip, userId))
|
||||
{
|
||||
if (rejectResponseFactory_)
|
||||
{
|
||||
adviceCallback(rejectResponseFactory_(req));
|
||||
}
|
||||
else
|
||||
{
|
||||
adviceCallback(rejectResponse_);
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
chainCallback();
|
||||
}
|
|
@ -92,7 +92,10 @@ HttpAppFrameworkImpl::HttpAppFrameworkImpl()
|
|||
postHandlingAdvices_)),
|
||||
websockCtrlsRouterPtr_(
|
||||
new WebsocketControllersRouter(postRoutingAdvices_,
|
||||
postRoutingObservers_)),
|
||||
postRoutingObservers_,
|
||||
preHandlingAdvices_,
|
||||
preHandlingObservers_,
|
||||
postHandlingAdvices_)),
|
||||
listenerManagerPtr_(new ListenerManager),
|
||||
pluginsManagerPtr_(new PluginsManager),
|
||||
dbClientManagerPtr_(new orm::DbClientManager),
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
#include <drogon/RateLimiter.h>
|
||||
#include "FixedWindowRateLimiter.h"
|
||||
#include "SlidingWindowRateLimiter.h"
|
||||
#include "TokenBucketRateLimiter.h"
|
||||
|
||||
using namespace drogon;
|
||||
|
||||
RateLimiterPtr RateLimiter::newRateLimiter(
|
||||
RateLimiterType type,
|
||||
size_t capacity,
|
||||
std::chrono::duration<double> timeUnit)
|
||||
{
|
||||
switch (type)
|
||||
{
|
||||
case RateLimiterType::kFixedWindow:
|
||||
return std::make_shared<FixedWindowRateLimiter>(capacity, timeUnit);
|
||||
case RateLimiterType::kSlidingWindow:
|
||||
return std::make_shared<SlidingWindowRateLimiter>(capacity,
|
||||
timeUnit);
|
||||
case RateLimiterType::kTokenBucket:
|
||||
return std::make_shared<TokenBucketRateLimiter>(capacity, timeUnit);
|
||||
}
|
||||
return std::make_shared<TokenBucketRateLimiter>(capacity, timeUnit);
|
||||
}
|
|
@ -0,0 +1,58 @@
|
|||
#include "SlidingWindowRateLimiter.h"
|
||||
#include <assert.h>
|
||||
|
||||
using namespace drogon;
|
||||
|
||||
SlidingWindowRateLimiter::SlidingWindowRateLimiter(
|
||||
size_t capacity,
|
||||
std::chrono::duration<double> timeUnit)
|
||||
: capacity_(capacity),
|
||||
unitStartTime_(std::chrono::steady_clock::now()),
|
||||
lastTime_(unitStartTime_),
|
||||
timeUnit_(timeUnit)
|
||||
{
|
||||
}
|
||||
// implementation of the sliding window algorithm
|
||||
bool SlidingWindowRateLimiter::isAllowed()
|
||||
{
|
||||
auto now = std::chrono::steady_clock::now();
|
||||
unitStartTime_ =
|
||||
unitStartTime_ +
|
||||
std::chrono::duration_cast<decltype(unitStartTime_)::duration>(
|
||||
std::chrono::duration<double>(
|
||||
static_cast<double>((uint64_t)(
|
||||
std::chrono::duration_cast<std::chrono::duration<double>>(
|
||||
now - unitStartTime_)
|
||||
.count() /
|
||||
timeUnit_.count())) *
|
||||
timeUnit_.count()));
|
||||
|
||||
if (unitStartTime_ > lastTime_)
|
||||
{
|
||||
auto duration =
|
||||
std::chrono::duration_cast<std::chrono::duration<double>>(
|
||||
unitStartTime_ - lastTime_);
|
||||
auto startTime = lastTime_;
|
||||
if (duration >= timeUnit_)
|
||||
{
|
||||
previousRequests_ = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
previousRequests_ = currentRequests_;
|
||||
}
|
||||
currentRequests_ = 0;
|
||||
}
|
||||
auto coef = std::chrono::duration_cast<std::chrono::duration<double>>(
|
||||
now - unitStartTime_) /
|
||||
timeUnit_;
|
||||
assert(coef <= 1.0);
|
||||
auto count = previousRequests_ * (1.0 - coef) + currentRequests_;
|
||||
if (count < capacity_)
|
||||
{
|
||||
currentRequests_++;
|
||||
lastTime_ = now;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
|
@ -0,0 +1,22 @@
|
|||
#pragma once
|
||||
#include <drogon/RateLimiter.h>
|
||||
#include <chrono>
|
||||
|
||||
namespace drogon
|
||||
{
|
||||
class SlidingWindowRateLimiter : public RateLimiter
|
||||
{
|
||||
public:
|
||||
SlidingWindowRateLimiter(size_t capacity,
|
||||
std::chrono::duration<double> timeUnit);
|
||||
bool isAllowed() override;
|
||||
|
||||
private:
|
||||
size_t capacity_;
|
||||
size_t currentRequests_{0};
|
||||
size_t previousRequests_{0};
|
||||
std::chrono::steady_clock::time_point unitStartTime_;
|
||||
std::chrono::steady_clock::time_point lastTime_;
|
||||
std::chrono::duration<double> timeUnit_;
|
||||
};
|
||||
} // namespace drogon
|
|
@ -0,0 +1,30 @@
|
|||
#include "TokenBucketRateLimiter.h"
|
||||
|
||||
using namespace drogon;
|
||||
|
||||
TokenBucketRateLimiter::TokenBucketRateLimiter(
|
||||
size_t capacity,
|
||||
std::chrono::duration<double> timeUnit)
|
||||
: capacity_(capacity),
|
||||
lastTime_(std::chrono::steady_clock::now()),
|
||||
timeUnit_(timeUnit)
|
||||
{
|
||||
}
|
||||
|
||||
// implementation of the token bucket algorithm
|
||||
bool TokenBucketRateLimiter::isAllowed()
|
||||
{
|
||||
auto now = std::chrono::steady_clock::now();
|
||||
auto duration = std::chrono::duration_cast<std::chrono::duration<double>>(
|
||||
now - lastTime_);
|
||||
tokens_ += capacity_ * (duration / timeUnit_);
|
||||
if (tokens_ > capacity_)
|
||||
tokens_ = capacity_;
|
||||
lastTime_ = now;
|
||||
if (tokens_ > 1.0)
|
||||
{
|
||||
tokens_ -= 1.0;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
#pragma once
|
||||
|
||||
#include <drogon/RateLimiter.h>
|
||||
|
||||
namespace drogon
|
||||
{
|
||||
class TokenBucketRateLimiter : public RateLimiter
|
||||
{
|
||||
public:
|
||||
TokenBucketRateLimiter(size_t capacity,
|
||||
std::chrono::duration<double> timeUnit);
|
||||
bool isAllowed() override;
|
||||
|
||||
private:
|
||||
size_t capacity_;
|
||||
std::chrono::steady_clock::time_point lastTime_;
|
||||
std::chrono::duration<double> timeUnit_;
|
||||
double tokens_{0};
|
||||
};
|
||||
} // namespace drogon
|
|
@ -91,6 +91,95 @@ void WebsocketControllersRouter::registerWebSocketController(
|
|||
}
|
||||
}
|
||||
|
||||
void WebsocketControllersRouter::doPreHandlingAdvices(
|
||||
const WebSocketControllerRouterItem &routerItem,
|
||||
std::string &wsKey,
|
||||
const HttpRequestImplPtr &req,
|
||||
std::function<void(const HttpResponsePtr &)> &&callback,
|
||||
const WebSocketConnectionImplPtr &wsConnPtr)
|
||||
{
|
||||
if (req->method() == Options)
|
||||
{
|
||||
auto resp = HttpResponse::newHttpResponse();
|
||||
resp->setContentTypeCode(ContentType::CT_TEXT_PLAIN);
|
||||
std::string methods = "OPTIONS,";
|
||||
if (routerItem.binders_[Get] && routerItem.binders_[Get]->isCORS_)
|
||||
{
|
||||
methods.append("GET,HEAD,");
|
||||
}
|
||||
if (routerItem.binders_[Post] && routerItem.binders_[Post]->isCORS_)
|
||||
{
|
||||
methods.append("POST,");
|
||||
}
|
||||
if (routerItem.binders_[Put] && routerItem.binders_[Put]->isCORS_)
|
||||
{
|
||||
methods.append("PUT,");
|
||||
}
|
||||
if (routerItem.binders_[Delete] && routerItem.binders_[Delete]->isCORS_)
|
||||
{
|
||||
methods.append("DELETE,");
|
||||
}
|
||||
if (routerItem.binders_[Patch] && routerItem.binders_[Patch]->isCORS_)
|
||||
{
|
||||
methods.append("PATCH,");
|
||||
}
|
||||
methods.resize(methods.length() - 1);
|
||||
resp->addHeader("ALLOW", methods);
|
||||
auto &origin = req->getHeader("Origin");
|
||||
if (origin.empty())
|
||||
{
|
||||
resp->addHeader("Access-Control-Allow-Origin", "*");
|
||||
}
|
||||
else
|
||||
{
|
||||
resp->addHeader("Access-Control-Allow-Origin", origin);
|
||||
}
|
||||
resp->addHeader("Access-Control-Allow-Methods", methods);
|
||||
auto &headers = req->getHeaderBy("access-control-request-headers");
|
||||
if (!headers.empty())
|
||||
{
|
||||
resp->addHeader("Access-Control-Allow-Headers", headers);
|
||||
}
|
||||
callback(resp);
|
||||
return;
|
||||
}
|
||||
if (!preHandlingObservers_.empty())
|
||||
{
|
||||
for (auto &observer : preHandlingObservers_)
|
||||
{
|
||||
observer(req);
|
||||
}
|
||||
}
|
||||
if (preHandlingAdvices_.empty())
|
||||
{
|
||||
doControllerHandler(
|
||||
routerItem, wsKey, req, std::move(callback), wsConnPtr);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto callbackPtr =
|
||||
std::make_shared<std::function<void(const HttpResponsePtr &)>>(
|
||||
std::move(callback));
|
||||
doAdvicesChain(
|
||||
preHandlingAdvices_,
|
||||
0,
|
||||
req,
|
||||
std::make_shared<std::function<void(const HttpResponsePtr &)>>(
|
||||
[callbackPtr](const HttpResponsePtr &resp) {
|
||||
(*callbackPtr)(resp);
|
||||
}),
|
||||
[this,
|
||||
&routerItem,
|
||||
wsKey = std::move(wsKey),
|
||||
req,
|
||||
callbackPtr,
|
||||
wsConnPtr]() mutable {
|
||||
doControllerHandler(
|
||||
routerItem, wsKey, req, std::move(*callbackPtr), wsConnPtr);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void WebsocketControllersRouter::route(
|
||||
const HttpRequestImplPtr &req,
|
||||
std::function<void(const HttpResponsePtr &)> &&callback,
|
||||
|
@ -149,7 +238,7 @@ void WebsocketControllersRouter::route(
|
|||
wsConnPtr,
|
||||
&ctrlInfo,
|
||||
this]() mutable {
|
||||
doControllerHandler(
|
||||
doPreHandlingAdvices(
|
||||
ctrlInfo,
|
||||
wsKey,
|
||||
req,
|
||||
|
@ -159,7 +248,7 @@ void WebsocketControllersRouter::route(
|
|||
}
|
||||
else
|
||||
{
|
||||
doControllerHandler(
|
||||
doPreHandlingAdvices(
|
||||
ctrlInfo, wsKey, req, std::move(callback), wsConnPtr);
|
||||
}
|
||||
return;
|
||||
|
@ -169,46 +258,47 @@ void WebsocketControllersRouter::route(
|
|||
auto callbackPtr = std::make_shared<
|
||||
std::function<void(const HttpResponsePtr &)>>(
|
||||
std::move(callback));
|
||||
doAdvicesChain(
|
||||
postRoutingAdvices_,
|
||||
0,
|
||||
req,
|
||||
callbackPtr,
|
||||
[callbackPtr,
|
||||
&filters,
|
||||
req,
|
||||
&ctrlInfo,
|
||||
this,
|
||||
wsKey = std::move(wsKey),
|
||||
wsConnPtr]() mutable {
|
||||
if (!filters.empty())
|
||||
{
|
||||
filters_function::doFilters(
|
||||
filters,
|
||||
doAdvicesChain(postRoutingAdvices_,
|
||||
0,
|
||||
req,
|
||||
callbackPtr,
|
||||
[callbackPtr,
|
||||
&filters,
|
||||
req,
|
||||
callbackPtr,
|
||||
[this,
|
||||
wsKey = std::move(wsKey),
|
||||
callbackPtr,
|
||||
wsConnPtr = std::move(wsConnPtr),
|
||||
req,
|
||||
&ctrlInfo]() mutable {
|
||||
doControllerHandler(ctrlInfo,
|
||||
wsKey,
|
||||
req,
|
||||
std::move(*callbackPtr),
|
||||
wsConnPtr);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
doControllerHandler(ctrlInfo,
|
||||
wsKey,
|
||||
req,
|
||||
std::move(*callbackPtr),
|
||||
wsConnPtr);
|
||||
}
|
||||
});
|
||||
&ctrlInfo,
|
||||
this,
|
||||
wsKey = std::move(wsKey),
|
||||
wsConnPtr]() mutable {
|
||||
if (!filters.empty())
|
||||
{
|
||||
filters_function::doFilters(
|
||||
filters,
|
||||
req,
|
||||
callbackPtr,
|
||||
[this,
|
||||
wsKey = std::move(wsKey),
|
||||
callbackPtr,
|
||||
wsConnPtr = std::move(wsConnPtr),
|
||||
req,
|
||||
&ctrlInfo]() mutable {
|
||||
doPreHandlingAdvices(
|
||||
ctrlInfo,
|
||||
wsKey,
|
||||
req,
|
||||
std::move(*callbackPtr),
|
||||
wsConnPtr);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
doPreHandlingAdvices(ctrlInfo,
|
||||
wsKey,
|
||||
req,
|
||||
std::move(
|
||||
*callbackPtr),
|
||||
wsConnPtr);
|
||||
}
|
||||
});
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
@ -247,52 +337,6 @@ void WebsocketControllersRouter::doControllerHandler(
|
|||
std::function<void(const HttpResponsePtr &)> &&callback,
|
||||
const WebSocketConnectionImplPtr &wsConnPtr)
|
||||
{
|
||||
if (req->method() == Options)
|
||||
{
|
||||
auto resp = HttpResponse::newHttpResponse();
|
||||
resp->setContentTypeCode(ContentType::CT_TEXT_PLAIN);
|
||||
std::string methods = "OPTIONS,";
|
||||
if (routerItem.binders_[Get] && routerItem.binders_[Get]->isCORS_)
|
||||
{
|
||||
methods.append("GET,HEAD,");
|
||||
}
|
||||
if (routerItem.binders_[Post] && routerItem.binders_[Post]->isCORS_)
|
||||
{
|
||||
methods.append("POST,");
|
||||
}
|
||||
if (routerItem.binders_[Put] && routerItem.binders_[Put]->isCORS_)
|
||||
{
|
||||
methods.append("PUT,");
|
||||
}
|
||||
if (routerItem.binders_[Delete] && routerItem.binders_[Delete]->isCORS_)
|
||||
{
|
||||
methods.append("DELETE,");
|
||||
}
|
||||
if (routerItem.binders_[Patch] && routerItem.binders_[Patch]->isCORS_)
|
||||
{
|
||||
methods.append("PATCH,");
|
||||
}
|
||||
methods.resize(methods.length() - 1);
|
||||
resp->addHeader("ALLOW", methods);
|
||||
|
||||
auto &origin = req->getHeader("Origin");
|
||||
if (origin.empty())
|
||||
{
|
||||
resp->addHeader("Access-Control-Allow-Origin", "*");
|
||||
}
|
||||
else
|
||||
{
|
||||
resp->addHeader("Access-Control-Allow-Origin", origin);
|
||||
}
|
||||
resp->addHeader("Access-Control-Allow-Methods", methods);
|
||||
auto &headers = req->getHeaderBy("access-control-request-headers");
|
||||
if (!headers.empty())
|
||||
{
|
||||
resp->addHeader("Access-Control-Allow-Headers", headers);
|
||||
}
|
||||
callback(resp);
|
||||
return;
|
||||
}
|
||||
wsKey.append("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
|
||||
unsigned char accKey[SHA_DIGEST_LENGTH];
|
||||
SHA1(reinterpret_cast<const unsigned char *>(wsKey.c_str()),
|
||||
|
@ -304,6 +348,10 @@ void WebsocketControllersRouter::doControllerHandler(
|
|||
resp->addHeader("Upgrade", "websocket");
|
||||
resp->addHeader("Connection", "Upgrade");
|
||||
resp->addHeader("Sec-WebSocket-Accept", base64Key);
|
||||
for (auto &advice : postHandlingAdvices_)
|
||||
{
|
||||
advice(req, resp);
|
||||
}
|
||||
callback(resp);
|
||||
auto ctrlPtr = routerItem.binders_[req->method()]->controller_;
|
||||
wsConnPtr->setMessageCallback(
|
||||
|
@ -335,4 +383,4 @@ void WebsocketControllersRouter::init()
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -38,9 +38,21 @@ class WebsocketControllersRouter : public trantor::NonCopyable
|
|||
AdviceChainCallback &&)>>
|
||||
&postRoutingAdvices,
|
||||
const std::vector<std::function<void(const HttpRequestPtr &)>>
|
||||
&postRoutingObservers)
|
||||
&postRoutingObservers,
|
||||
const std::vector<std::function<void(const HttpRequestPtr &,
|
||||
AdviceCallback &&,
|
||||
AdviceChainCallback &&)>>
|
||||
&preHandlingAdvices,
|
||||
const std::vector<std::function<void(const HttpRequestPtr &)>>
|
||||
&preHandlingObservers,
|
||||
const std::vector<std::function<void(const HttpRequestPtr &,
|
||||
const HttpResponsePtr &)>>
|
||||
&postHandlingAdvices)
|
||||
: postRoutingAdvices_(postRoutingAdvices),
|
||||
postRoutingObservers_(postRoutingObservers)
|
||||
postRoutingObservers_(postRoutingObservers),
|
||||
preHandlingAdvices_(preHandlingAdvices),
|
||||
preHandlingObservers_(preHandlingObservers),
|
||||
postHandlingAdvices_(postHandlingAdvices)
|
||||
{
|
||||
}
|
||||
void registerWebSocketController(
|
||||
|
@ -78,11 +90,26 @@ class WebsocketControllersRouter : public trantor::NonCopyable
|
|||
&postRoutingAdvices_;
|
||||
const std::vector<std::function<void(const HttpRequestPtr &)>>
|
||||
&postRoutingObservers_;
|
||||
const std::vector<std::function<void(const HttpRequestPtr &,
|
||||
AdviceCallback &&,
|
||||
AdviceChainCallback &&)>>
|
||||
&preHandlingAdvices_;
|
||||
const std::vector<std::function<void(const HttpRequestPtr &)>>
|
||||
&preHandlingObservers_;
|
||||
const std::vector<
|
||||
std::function<void(const HttpRequestPtr &, const HttpResponsePtr &)>>
|
||||
&postHandlingAdvices_;
|
||||
void doControllerHandler(
|
||||
const WebSocketControllerRouterItem &routerItem,
|
||||
std::string &wsKey,
|
||||
const HttpRequestImplPtr &req,
|
||||
std::function<void(const HttpResponsePtr &)> &&callback,
|
||||
const WebSocketConnectionImplPtr &wsConnPtr);
|
||||
void doPreHandlingAdvices(
|
||||
const WebSocketControllerRouterItem &routerItem,
|
||||
std::string &wsKey,
|
||||
const HttpRequestImplPtr &req,
|
||||
std::function<void(const HttpResponsePtr &)> &&callback,
|
||||
const WebSocketConnectionImplPtr &wsConnPtr);
|
||||
};
|
||||
} // namespace drogon
|
||||
|
|
Loading…
Reference in New Issue