diff --git a/CMakeLists.txt b/CMakeLists.txt index bbb529b796a933973f9ff2b5836db5183ff8e9b5..c614126e39fcaebb2e0db493b7e0515c556095fa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -78,6 +78,7 @@ list (APPEND opendht_HEADERS include/opendht/node_cache.h include/opendht/network_engine.h include/opendht/scheduler.h + include/opendht/rate_limiter.h include/opendht/securedht.h include/opendht/log.h include/opendht.h diff --git a/include/opendht/network_engine.h b/include/opendht/network_engine.h index 388401827cdf16ffe904ed69020704ea777e0c4d..13f30e55eddb7faa36e4d08b40f2db2895813149 100644 --- a/include/opendht/network_engine.h +++ b/include/opendht/network_engine.h @@ -28,6 +28,7 @@ #include "scheduler.h" #include "utils.h" #include "rng.h" +#include "rate_limiter.h" #include <vector> #include <string> @@ -382,11 +383,12 @@ private: /* The maximum number of nodes that we snub. There is probably little reason to increase this value. */ static constexpr unsigned BLACKLISTED_MAX {10}; - /* TODO */ + static const std::string my_v; + static std::mt19937 rd_device; - bool rateLimit(); + bool rateLimit(const SockAddr& addr); static bool isMartian(const SockAddr& addr); bool isNodeBlacklisted(const SockAddr& addr) const; @@ -462,8 +464,40 @@ private: const Logger& DHT_LOG; NodeCache cache {}; - std::queue<time_point> rate_limit_time {}; - static std::mt19937 rd_device; + + /** + * A comparator to classify IP addresses, only considering the + * first 64 bits in IPv6. + */ + struct cmpSockAddr { + bool operator()(const SockAddr& a, const SockAddr& b) { + if (a.second != b.second) + return a.second < b.second; + socklen_t start, len; + switch(a.getFamily()) { + case AF_INET: + start = offsetof(sockaddr_in, sin_addr); + len = sizeof(in_addr); + break; + case AF_INET6: + start = offsetof(sockaddr_in6, sin6_addr); + // don't consider more than 64 bits (IPv6) + len = 8; + break; + default: + start = 0; + len = a.second; + break; + } + + return std::memcmp((uint8_t*)&a.first+start, (uint8_t*)&b.first+start, len) < 0; + } + }; + // global limiting should be triggered by at least 8 different IPs + using IpLimiter = RateLimiter<MAX_REQUESTS_PER_SEC/8>; + using IpLimiterMap = std::map<SockAddr, IpLimiter, cmpSockAddr>; + IpLimiterMap address_rate_limiter {}; + RateLimiter<MAX_REQUESTS_PER_SEC> rate_limiter {}; // requests handling uint16_t transaction_id {1}; diff --git a/include/opendht/rate_limiter.h b/include/opendht/rate_limiter.h new file mode 100644 index 0000000000000000000000000000000000000000..3c0befa0cbb65ccf03a659ca864a8de803b5d6f8 --- /dev/null +++ b/include/opendht/rate_limiter.h @@ -0,0 +1,50 @@ +/* + * Copyright (C) 2016 Savoir-faire Linux Inc. + * Author : Adrien Béraud <adrien.beraud@savoirfairelinux.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +#pragma once + +#include "utils.h" +#include <queue> + +namespace dht { + +template<size_t Quota, unsigned long Perdiod=1> +class RateLimiter { +public: + /** Clear outdated records and return current quota usage */ + size_t maintain(const time_point& now) { + using namespace std::chrono; + while (not records.empty() and duration_cast<seconds>(now - records.front()) >= seconds(Perdiod)) + records.pop(); + return records.size(); + } + /** Return false if quota is reached, insert record and return true otherwise. */ + bool limit(const time_point& now) { + if (maintain(now) >= Quota) + return false; + records.emplace(now); + return true; + } + bool empty() const { + return records.empty(); + } +private: + std::queue<time_point> records {}; +}; + +} diff --git a/src/Makefile.am b/src/Makefile.am index 92affd084f5b71d72d869d05cf4cb6ec989b123c..2dde1aa9a10a13902ad501e1d4cf4c0358c593c1 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -35,6 +35,7 @@ nobase_include_HEADERS = \ ../include/opendht/routing_table.h \ ../include/opendht/network_engine.h \ ../include/opendht/scheduler.h \ + ../include/opendht/rate_limiter.h \ ../include/opendht/utils.h \ ../include/opendht/sockaddr.h \ ../include/opendht/infohash.h \ diff --git a/src/network_engine.cpp b/src/network_engine.cpp index 14aef4eced9953bdc840f0d72917420c5e24f04a..8d69afa5702700bfe065b628f8c48cf34c75d031 100644 --- a/src/network_engine.cpp +++ b/src/network_engine.cpp @@ -191,18 +191,28 @@ NetworkEngine::sendRequest(std::shared_ptr<Request>& request) /* Rate control for requests we receive. */ bool -NetworkEngine::rateLimit() +NetworkEngine::rateLimit(const SockAddr& addr) { - using namespace std::chrono; const auto& now = scheduler.time(); - while (not rate_limit_time.empty() and duration_cast<seconds>(now - rate_limit_time.front()) > seconds(1)) - rate_limit_time.pop(); - if (rate_limit_time.size() >= MAX_REQUESTS_PER_SEC) + // occasional IP limiter maintenance + std::bernoulli_distribution rand_trial(1./128.); + if (rand_trial(rd_device)) { + for (auto it = address_rate_limiter.begin(); it != address_rate_limiter.end();) { + if (it->second.maintain(now) == 0) + address_rate_limiter.erase(it++); + else + ++it; + } + } + + // invoke per IP rate limiter + auto it = address_rate_limiter.emplace(addr, IpLimiter{}); + if (not it.first->second.limit(now)) return false; - rate_limit_time.emplace(now); - return true; + // invoke global limiter + return rate_limiter.limit(now); } bool @@ -299,7 +309,7 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const SockAddr& if (msg.type > MessageType::Reply) { /* Rate limit requests. */ - if (!rateLimit()) { + if (!rateLimit(from)) { DHT_LOG.WARN("Dropping request due to rate limiting."); return; }