diff --git a/include/opendht/dht.h b/include/opendht/dht.h index 4672f68ab5745f2d2429ff6d4ab0d67e67bddc03..d294404be969998efb6e6dccb4218ba931a5bc17 100644 --- a/include/opendht/dht.h +++ b/include/opendht/dht.h @@ -41,6 +41,8 @@ namespace dht { +struct Request; + /** * Main Dht class. * Provides a Distributed Hash Table node. @@ -292,24 +294,19 @@ private: struct SearchNode { SearchNode(std::shared_ptr<Node> node) : node(node) {} - using AnnounceStatusMap = std::map<Value::Id, std::shared_ptr<NetworkEngine::Request>>; + using AnnounceStatusMap = std::map<Value::Id, std::shared_ptr<Request>>; /** * Can we use this node to listen/announce now ? */ bool isSynced(time_point now) const { - /*if (not getStatus) - return false;*/ - return not node->isExpired(now) and + return not node->isExpired() and not token.empty() and last_get_reply >= now - Node::NODE_EXPIRE_TIME; } bool canGet(time_point now, time_point update) const { - /*if (not getStatus) - return true;*/ - return not node->isExpired(now) and + return not node->isExpired() and (now > last_get_reply + Node::NODE_EXPIRE_TIME or update > last_get_reply) and (not getStatus or not getStatus->pending()); - // and now > getStatus->last_try + Node::MAX_RESPONSE_TIME; } bool isAnnounced(Value::Id vid, const ValueType& type, time_point now) const { @@ -342,15 +339,15 @@ private: return listenStatus->pending() ? time_point::max() : listenStatus->reply_time + LISTEN_EXPIRE_TIME - REANNOUNCE_MARGIN; } - bool isBad(const time_point& now) const { - return !node || node->isExpired(now) || candidate; + bool isBad() const { + return !node || node->isExpired() || candidate; } std::shared_ptr<Node> node {}; time_point last_get_reply {time_point::min()}; /* last time received valid token */ - std::shared_ptr<NetworkEngine::Request> getStatus {}; /* get/sync status */ - std::shared_ptr<NetworkEngine::Request> listenStatus {}; + std::shared_ptr<Request> getStatus {}; /* get/sync status */ + std::shared_ptr<Request> listenStatus {}; AnnounceStatusMap acked {}; /* announcement status for a given value id */ Blob token {}; @@ -423,7 +420,7 @@ private: bool insertNode(std::shared_ptr<Node> n, time_point now, const Blob& token={}); unsigned insertBucket(const Bucket&, time_point now); - SearchNode* getNode(std::shared_ptr<Node>& n) { + SearchNode* getNode(const std::shared_ptr<Node>& n) { auto srn = std::find_if(nodes.begin(), nodes.end(), [&](SearchNode& sn) { return n == sn.node; }); @@ -450,7 +447,7 @@ private: /** * @return The number of non-good search nodes. */ - unsigned getNumberOfBadNodes(time_point now); + unsigned getNumberOfBadNodes(); /** * ret = 0 : no announce required. @@ -721,24 +718,24 @@ private: void processMessage(const uint8_t *buf, size_t buflen, const sockaddr *from, socklen_t fromlen); - void onError(std::shared_ptr<NetworkEngine::Request> node, DhtProtocolException e); + void onError(std::shared_ptr<Request> node, DhtProtocolException e); /* when our address is reported by a distant peer. */ void onReportedAddr(const InfoHash& id, sockaddr* sa , socklen_t salen); /* when we receive a ping request */ NetworkEngine::RequestAnswer onPing(std::shared_ptr<Node> node); /* when we receive a "find node" request */ NetworkEngine::RequestAnswer onFindNode(std::shared_ptr<Node> node, InfoHash& hash, want_t want); - void onFindNodeDone(std::shared_ptr<NetworkEngine::Request> status, NetworkEngine::RequestAnswer& a, std::shared_ptr<Search> sr); + void onFindNodeDone(const Request& status, NetworkEngine::RequestAnswer& a, std::shared_ptr<Search> sr); /* when we receive a "get values" request */ NetworkEngine::RequestAnswer onGetValues(std::shared_ptr<Node> node, InfoHash& hash, want_t want); - void onGetValuesDone(std::shared_ptr<NetworkEngine::Request> status, NetworkEngine::RequestAnswer& a, std::shared_ptr<Search> sr); + void onGetValuesDone(const Request& status, NetworkEngine::RequestAnswer& a, std::shared_ptr<Search> sr); /* when we receive a listen request */ NetworkEngine::RequestAnswer onListen(std::shared_ptr<Node> node, InfoHash& hash, Blob& token, size_t rid); - void onListenDone(std::shared_ptr<NetworkEngine::Request>& status, NetworkEngine::RequestAnswer& a, std::shared_ptr<Search>& sr); + void onListenDone(const Request& status, NetworkEngine::RequestAnswer& a, std::shared_ptr<Search>& sr); /* when we receive an announce request */ NetworkEngine::RequestAnswer onAnnounce(std::shared_ptr<Node> node, InfoHash& hash, Blob& token, std::vector<std::shared_ptr<Value>> v, time_point created); - void onAnnounceDone(std::shared_ptr<NetworkEngine::Request>& status, NetworkEngine::RequestAnswer& a, std::shared_ptr<Search>& sr); + void onAnnounceDone(const Request& status, NetworkEngine::RequestAnswer& a, std::shared_ptr<Search>& sr); }; } diff --git a/include/opendht/network_engine.h b/include/opendht/network_engine.h index 6cb501f656133aa71097528e04ba80250b86e7ae..fa32a822354b779db23a875d0f50734919c009c4 100644 --- a/include/opendht/network_engine.h +++ b/include/opendht/network_engine.h @@ -28,6 +28,7 @@ #include "scheduler.h" #include "utils.h" #include "rng.h" +#include "request.h" #include <vector> #include <string> @@ -73,6 +74,8 @@ private: const InfoHash failing_node_id; }; +struct ParsedMessage; + /*! * @class NetworkEngine * @brief An abstraction of communication protocol on the network. @@ -99,6 +102,7 @@ class NetworkEngine final { static const TransPrefix ANNOUNCE_VALUES; static const TransPrefix LISTEN; }; +public: /* Transaction-ids are 4-bytes long, with the first two bytes identifying * the kind of request, and the remaining two a sequence number in @@ -139,36 +143,6 @@ class NetworkEngine final { unsigned length {4}; }; - enum class MessageType { - Error = 0, - Reply, - Ping, - FindNode, - GetValues, - AnnounceValue, - Listen - }; - - struct ParsedMessage { - MessageType type; - InfoHash id; /* the id of the sender */ - InfoHash info_hash; /* hash for which values are requested */ - InfoHash target; /* target id around which to find nodes */ - TransId tid; /* transaction id */ - Blob token; /* security token */ - Value::Id value_id; /* the value id */ - time_point created { time_point::max() }; /* time when value was first created */ - Blob nodes4_raw, nodes6_raw; /* IPv4 nodes in response to a 'find' request */ - std::vector<std::shared_ptr<Node>> nodes4, nodes6; - std::vector<std::shared_ptr<Value>> values; /* values for a 'get' request */ - want_t want; /* states if ipv4 or ipv6 request */ - uint16_t error_code; /* error code in case of error */ - std::string ua; - Address addr; /* reported address by the distant node */ - void msgpack_unpack(msgpack::object o); - }; - -public: /*! * @class RequestAnswer * @brief Answer for a request. @@ -183,80 +157,10 @@ public: std::vector<std::shared_ptr<Node>> nodes4 {}; std::vector<std::shared_ptr<Node>> nodes6 {}; RequestAnswer() {} - RequestAnswer(ParsedMessage&& msg) - : ntoken(std::move(msg.token)), values(std::move(msg.values)), nodes4(std::move(msg.nodes4)), nodes6(std::move(msg.nodes6)) {} + RequestAnswer(ParsedMessage&& msg); }; - /*! - * @class Request - * @brief An atomic request destined to a node. - * @details - * A request contains data used by the NetworkEngine to process a request - * desitned to specific node and std::function callbacks to execute when the - * request is done. - */ - struct Request { - friend class dht::NetworkEngine; - - static const constexpr size_t MAX_ATTEMPT_COUNT {3}; - - std::shared_ptr<Node> node {}; /* the node to whom the request is destined. */ - time_point reply_time {time_point::min()}; /* time when we received the response to the request. */ - - bool expired() const { return expired_; } - bool completed() const { return completed_; } - bool cancelled() const { return cancelled_; } - bool pending() const { - return not cancelled_ - and not completed_ - and not expired_; - } - bool over() const { return not pending(); } - - Request() {} - - private: - Request(uint16_t tid, - std::shared_ptr<Node> node, - Blob &&msg, - std::function<void(std::shared_ptr<Request> req_status, ParsedMessage&&)> on_done, - std::function<void(std::shared_ptr<Request> req_status, bool)> on_expired, bool persistent = false) : - node(node), on_done(on_done), on_expired(on_expired), tid(tid), msg(std::move(msg)), persistent(persistent) { } - - bool isExpired(time_point now) const { - return now > last_try + Node::MAX_RESPONSE_TIME and attempt_count >= Request::MAX_ATTEMPT_COUNT - and not completed_ and not cancelled_; - } - - void cancel() { - if (not completed_ and not expired_) { - cancelled_ = true; - clear(); - } - } - - void clear() { - on_done = {}; - on_expired = {}; - msg.clear(); - } - - bool cancelled_ {false}; /* whether the request is canceled before done. */ - bool completed_ {false}; /* whether the request is completed. */ - bool expired_ {false}; - unsigned attempt_count {0}; /* number of attempt to process the request. */ - time_point start {time_point::min()}; /* time when the request is created. */ - time_point last_try {time_point::min()}; /* time of the last attempt to process the request. */ - - std::function<void(std::shared_ptr<Request> req_status, ParsedMessage&&)> on_done {}; - std::function<void(std::shared_ptr<Request> req_status, bool)> on_expired {}; - - const uint16_t tid {0}; /* the request id. */ - Blob msg {}; /* the serialized message. */ - const bool persistent {false}; /* the request is not erased upon completion. */ - }; - /** * Cancel a request. Setting req->cancelled = true is not enough in the case * a request is "persistent". @@ -354,8 +258,8 @@ private: time_point)> onAnnounce {}; public: - using RequestCb = std::function<void(std::shared_ptr<Request>, RequestAnswer&&)>; - using RequestExpiredCb = std::function<void(std::shared_ptr<Request>, bool)>; + using RequestCb = std::function<void(const Request&, RequestAnswer&&)>; + using RequestExpiredCb = std::function<void(const Request&, bool)>; NetworkEngine(Logger& log, Scheduler& scheduler) : myid(zeroes), DHT_LOG(log), scheduler(scheduler) {} NetworkEngine(InfoHash& myid, int s, int s6, Logger& log, Scheduler& scheduler, @@ -455,28 +359,13 @@ public: } std::vector<unsigned> getNodeMessageStats(bool in) { - auto stats = in ? std::vector<unsigned>{in_stats.ping, in_stats.find, in_stats.get, in_stats.listen, in_stats.put} - : std::vector<unsigned>{out_stats.ping, out_stats.find, out_stats.get, out_stats.listen, out_stats.put}; - if (in) { in_stats = {}; } - else { out_stats = {}; } - + auto& st = in ? in_stats : out_stats; + std::vector<unsigned> stats {st.ping, st.find, st.get, st.listen, st.put}; + st = {}; return stats; } - void blacklistNode(const std::shared_ptr<Node>& n) { - for (auto rit = requests.begin(); rit != requests.end();) { - if (rit->second->node == n) { - rit->second->cancel(); - requests.erase(rit++); - } else { - ++rit; - } - } - //blacklistedNodes.emplace(n); - memcpy(&blacklist[next_blacklisted], &n->ss, n->sslen); - next_blacklisted = (next_blacklisted + 1) % BLACKLISTED_MAX; - //blacklistNode(&n->id, (const sockaddr*)&n->ss, n->sslen); - } + void blacklistNode(const std::shared_ptr<Node>& n); private: /*************** @@ -508,30 +397,24 @@ private: bool rateLimit(); static bool isMartian(const sockaddr* sa, socklen_t len); - //void blacklistNode(const InfoHash* id, const sockaddr*, socklen_t); bool isNodeBlacklisted(const sockaddr*, socklen_t) const; - void pinged(Node&); - void requestStep(std::shared_ptr<Request> req) { - if (req->over()) + if (not req->pending()) return; auto now = scheduler.time(); - if (req->node->isExpired(now) or req->isExpired(now)) { - req->expired_ = true; - req->on_expired(req, true); - req->clear(); + if (req->isExpired(now)) { + req->node->setExpired(); requests.erase(req->tid); return; } else if (req->attempt_count == 1) { - req->on_expired(req, false); + req->on_expired(*req, false); } send((char*)req->msg.data(), req->msg.size(), (req->node->reply_time >= now - UDP_REPLY_TIME) ? 0 : MSG_CONFIRM, (sockaddr*)&req->node->ss, req->node->sslen); - pinged(*req->node); ++req->attempt_count; req->last_try = now; std::weak_ptr<Request> wreq = req; @@ -552,6 +435,7 @@ private: if (!e.second) { DHT_LOG.ERROR("Request already existed !"); } + request->node->requested(request); requestStep(request); } diff --git a/include/opendht/node.h b/include/opendht/node.h index dfe00c6921b1b97c044650e7716c011846ad9b7d..adf9e26bbeb86253db3cee4dcc703b0a579d23ae 100644 --- a/include/opendht/node.h +++ b/include/opendht/node.h @@ -26,8 +26,12 @@ #include <arpa/inet.h> +#include <list> + namespace dht { +class Request; + struct Node { friend class NetworkEngine; @@ -36,7 +40,6 @@ struct Node { socklen_t sslen {0}; time_point time {time_point::min()}; /* last time eared about */ time_point reply_time {time_point::min()}; /* time of last correct reply received */ - time_point pinged_time {time_point::min()}; /* time of last message sent */ Node() : ss() { std::fill_n((uint8_t*)&ss, sizeof(ss), 0); @@ -56,26 +59,23 @@ struct Node { std::string getAddrStr() const { return print_addr(ss, sslen); } - bool isExpired(time_point now) const; - bool isExpired() const { return isExpired(clock::now()); } + bool isExpired() const { return expired_; } bool isGood(time_point now) const; - bool isMessagePending(time_point now) const; + bool isMessagePending() const; NodeExport exportNode() const { return NodeExport {id, ss, sslen}; } sa_family_t getFamily() const { return ss.ss_family; } void update(const sockaddr* sa, socklen_t salen); - /** To be called when a message was sent to the node */ - void requested(time_point now); + void requested(std::shared_ptr<Request>& req); + void received(time_point now, std::shared_ptr<Request> req); - /** To be called when a message was received from the node. - Answer should be true if the message was an aswer to a request we made*/ - void received(time_point now, bool answer); + void setExpired(); /** * Resets the state of the node so it's not expired anymore. */ - void reset() { pinged = 0; } + void reset() { expired_ = false; } std::string toString() const; @@ -90,7 +90,16 @@ struct Node { static constexpr const std::chrono::seconds MAX_RESPONSE_TIME {3}; private: - unsigned pinged {0}; /* how many requests we sent since last reply */ + + std::list<std::weak_ptr<Request>> requests_ {}; + bool expired_ {false}; + + void clearPendingQueue() { + requests_.remove_if([](std::weak_ptr<Request>& w) { + return w.expired(); + }); + } + }; } diff --git a/include/opendht/request.h b/include/opendht/request.h new file mode 100644 index 0000000000000000000000000000000000000000..edf3a62de208a459bbd28d63fcb88a947300d847 --- /dev/null +++ b/include/opendht/request.h @@ -0,0 +1,112 @@ +/* + * Copyright (C) 2016 Savoir-faire Linux Inc. + * Author(s) : Adrien Béraud <adrien.beraud@savoirfairelinux.com> + * Simon Désaulniers <sim.desaulniers@gmail.com> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +namespace dht { + +class NetworkEngine; +class ParsedMessage; + +/*! + * @class Request + * @brief An atomic request destined to a node. + * @details + * A request contains data used by the NetworkEngine to process a request + * desitned to specific node and std::function callbacks to execute when the + * request is done. + */ +struct Request { + friend class dht::NetworkEngine; + std::shared_ptr<Node> node {}; /* the node to whom the request is destined. */ + time_point reply_time {time_point::min()}; /* time when we received the response to the request. */ + + enum class State + { + PENDING, + CANCELLED, + EXPIRED, + COMPLETED + }; + + bool expired() const { return state_ == State::EXPIRED; } + bool completed() const { return state_ == State::COMPLETED; } + bool cancelled() const { return state_ == State::CANCELLED; } + bool pending() const { return state_ == State::PENDING; } + bool over() const { return not pending(); } + State getState() const { return state_; } + + Request() {} + Request(uint16_t tid, + std::shared_ptr<Node> node, + Blob&& msg, + std::function<void(const Request& req_status, ParsedMessage&&)> on_done, + std::function<void(const Request& req_status, bool)> on_expired, + bool persistent = false) : + node(node), on_done(on_done), on_expired(on_expired), tid(tid), msg(std::move(msg)), persistent(persistent) { } + + void setExpired() { + if (pending()) { + state_ = Request::State::EXPIRED; + on_expired(*this, true); + clear(); + } + } + void setDone(ParsedMessage&& msg) { + if (pending() or persistent) { + state_ = Request::State::COMPLETED; + on_done(*this, std::forward<ParsedMessage>(msg)); + if (not persistent) + clear(); + } + } +private: + static const constexpr size_t MAX_ATTEMPT_COUNT {3}; + + bool isExpired(time_point now) const { + return pending() and now > last_try + Node::MAX_RESPONSE_TIME and attempt_count >= Request::MAX_ATTEMPT_COUNT; + } + + void cancel() { + if (pending()) { + state_ = State::CANCELLED; + clear(); + } + } + + void clear() { + on_done = {}; + on_expired = {}; + msg.clear(); + } + + State state_ {State::PENDING}; + + unsigned attempt_count {0}; /* number of attempt to process the request. */ + time_point start {time_point::min()}; /* time when the request is created. */ + time_point last_try {time_point::min()}; /* time of the last attempt to process the request. */ + + std::function<void(const Request& req_status, ParsedMessage&&)> on_done {}; + std::function<void(const Request& req_status, bool)> on_expired {}; + + const uint16_t tid {0}; /* the request id. */ + Blob msg {}; /* the serialized message. */ + const bool persistent {false}; /* the request is not erased upon completion. */ +}; + +} diff --git a/src/dht.cpp b/src/dht.cpp index f250207e36f86215a4c52379ddc682841e05dc3d..610db100b472e99b41155e75badaea347330d0cb 100644 --- a/src/dht.cpp +++ b/src/dht.cpp @@ -70,17 +70,6 @@ set_nonblocking(int fd, int nonblocking) static std::mt19937 rd {dht::crypto::random_device{}()}; static std::uniform_int_distribution<uint8_t> rand_byte; -static std::string -to_hex(const uint8_t *buf, size_t buflen) -{ - std::stringstream s; - s << std::hex; - for (size_t i = 0; i < buflen; i++) - s << std::setfill('0') << std::setw(2) << (unsigned)buf[i]; - s << std::dec; - return s.str(); -} - namespace dht { using namespace std::placeholders; @@ -251,7 +240,7 @@ Dht::newNode(const std::shared_ptr<Node>& node, int confirm) /* Try to get rid of an expired node. */ for (auto& n : b->nodes) { - if (not n->isExpired(now)) + if (not n->isExpired()) continue; n = node; return n; @@ -267,7 +256,7 @@ Dht::newNode(const std::shared_ptr<Node>& node, int confirm) of bad nodes fast. */ if (not n->isGood(now)) { dubious = true; - if (n->pinged_time + Node::MAX_RESPONSE_TIME < now) { + if (not n->isMessagePending()) { DHT_LOG.DEBUG("Sending ping to dubious node %s.", n->toString().c_str()); network_engine.sendPing(n, nullptr, nullptr); break; @@ -302,7 +291,7 @@ Dht::expireBuckets(RoutingTable& list) for (auto& b : list) { bool changed = false; b.nodes.remove_if([this,&changed](const std::shared_ptr<Node>& n) { - if (n->isExpired(scheduler.time())) { + if (n->isExpired()) { changed = true; return true; } @@ -320,7 +309,7 @@ Dht::Search::removeExpiredNode(time_point now) while (e != nodes.cbegin()) { e = std::prev(e); const Node& n = *e->node; - if (n.isExpired(now) and n.time + Node::NODE_EXPIRE_TIME < now) { + if (n.isExpired() and n.time + Node::NODE_EXPIRE_TIME < now) { //std::cout << "Removing expired node " << n.id << " from IPv" << (af==AF_INET?'4':'6') << " search " << id << std::endl; nodes.erase(e); return true; @@ -338,7 +327,7 @@ Dht::Search::insertNode(std::shared_ptr<Node> node, time_point now, const Blob& if (expired and nodes.empty()) return false; - if (node->ss.ss_family != af) { + if (node->getFamily() != af) { //DHT_LOG.DEBUG("Attempted to insert node in the wrong family."); return false; } @@ -346,11 +335,11 @@ Dht::Search::insertNode(std::shared_ptr<Node> node, time_point now, const Blob& const auto& nid = node->id; // Fast track for the case where the node is not relevant for this search - if (nodes.size() >= SEARCH_NODES && id.xorCmp(nid, nodes.back().node->id) > 0 && node->isExpired(now)) + if (nodes.size() >= SEARCH_NODES && id.xorCmp(nid, nodes.back().node->id) > 0 && node->isExpired()) return false; bool found = false; - unsigned num_bad_nodes = getNumberOfBadNodes(now); + unsigned num_bad_nodes = getNumberOfBadNodes(); auto n = std::find_if(nodes.begin(), nodes.end(), [&](const SearchNode& sn) { if (sn.node == node) { found = true; @@ -362,7 +351,7 @@ Dht::Search::insertNode(std::shared_ptr<Node> node, time_point now, const Blob& bool new_search_node = false; if (!found) { if (nodes.size()-num_bad_nodes >= SEARCH_NODES) { - if (node->isExpired(now)) + if (node->isExpired()) return false; if (n == nodes.end()) return false; @@ -389,7 +378,7 @@ Dht::Search::insertNode(std::shared_ptr<Node> node, time_point now, const Blob& num_bad_nodes--; auto to_remove = std::find_if(nodes.rbegin(), nodes.rend(), - [&](const SearchNode& n) { return not n.isBad(now)/* and not (n.getStatus and n.getStatus->pending(now))*/; } + [&](const SearchNode& n) { return not n.isBad()/* and not (n.getStatus and n.getStatus->pending(now))*/; } ); if (to_remove != nodes.rend()) { if (to_remove->getStatus and to_remove->getStatus->pending()) @@ -459,31 +448,29 @@ Dht::searchSendGetValues(std::shared_ptr<Search> sr, SearchNode* pn, bool update return nullptr; } - DHT_LOG.DEBUG("[search %s IPv%c] [node %s %s] sending 'get'", + DHT_LOG.DEBUG("[search %s IPv%c] [node %s] sending 'get'", sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6', - n->node->id.toString().c_str(), - print_addr(n->node->ss, n->node->sslen).c_str()); + n->node->toString().c_str()); std::weak_ptr<Search> ws = sr; auto onDone = - [this,ws](std::shared_ptr<NetworkEngine::Request> status, NetworkEngine::RequestAnswer&& answer) mutable { + [this,ws](const Request& status, NetworkEngine::RequestAnswer&& answer) mutable { if (auto sr = ws.lock()) { - auto srn = sr->getNode(status->node); + auto srn = sr->getNode(status.node); if (srn and not srn->candidate) sr->current_get_requests--; - sr->insertNode(status->node, scheduler.time(), answer.ntoken); + sr->insertNode(status.node, scheduler.time(), answer.ntoken); onGetValuesDone(status, answer, sr); } }; auto onExpired = - [this,ws](std::shared_ptr<NetworkEngine::Request> status, bool over) mutable { + [this,ws](const Request& status, bool over) mutable { if (auto sr = ws.lock()) { - auto srn = sr->getNode(status->node); + auto srn = sr->getNode(status.node); if (srn and not srn->candidate) { - DHT_LOG.DEBUG("[search %s IPv%c] [node %s %s] 'get' expired", + DHT_LOG.DEBUG("[search %s IPv%c] [node %s] 'get' expired", sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6', - srn->node->id.toString().c_str(), - print_addr(srn->node->ss, srn->node->sslen).c_str()); + srn->node->toString().c_str()); if (not over) { srn->candidate = true; //DHT_LOG.DEBUG("[search %s] sn %s now candidate... %d", @@ -495,7 +482,7 @@ Dht::searchSendGetValues(std::shared_ptr<Search> sr, SearchNode* pn, bool update } }; sr->current_get_requests++; - std::shared_ptr<NetworkEngine::Request> rstatus; + std::shared_ptr<Request> rstatus; if (sr->callbacks.empty() and sr->listeners.empty()) rstatus = network_engine.sendFindNode(n->node, sr->id, -1, onDone, onExpired); else @@ -519,7 +506,7 @@ Dht::searchStep(std::shared_ptr<Search> sr) * The accurate delay between two refills has not been strongly determined. * TODO: Emprical analysis over refill timeout. */ - if (sr->refill_time + Node::NODE_EXPIRE_TIME < now and sr->nodes.size()-sr->getNumberOfBadNodes(now) < SEARCH_NODES) { + if (sr->refill_time + Node::NODE_EXPIRE_TIME < now and sr->nodes.size()-sr->getNumberOfBadNodes() < SEARCH_NODES) { auto added = sr->refill(sr->af == AF_INET ? buckets : buckets6, now); if (added) sr->refill_time = now; @@ -555,10 +542,9 @@ Dht::searchStep(std::shared_ptr<Search> sr) if (not n.isSynced(now) or (n.candidate and t >= LISTEN_NODES)) continue; if (n.getListenTime() <= now) { - DHT_LOG.WARN("[search %s IPv%c] [node %s %s] sending 'listen'", + DHT_LOG.WARN("[search %s IPv%c] [node %s] sending 'listen'", sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6', - n.node->id.toString().c_str(), - print_addr(n.node->ss, n.node->sslen).c_str()); + n.node->toString().c_str()); //std::cout << "Sending listen to " << n.node->id << " " << print_addr(n.node->ss, n.node->sslen) << std::endl; //network_engine.cancelRequest(n.listenStatus); @@ -566,7 +552,7 @@ Dht::searchStep(std::shared_ptr<Search> sr) std::weak_ptr<Search> ws = sr; n.listenStatus = network_engine.sendListen(n.node, sr->id, n.token, - [this,ws,ls](std::shared_ptr<NetworkEngine::Request> status, + [this,ws,ls](const Request& status, NetworkEngine::RequestAnswer&& answer) mutable { /* on done */ network_engine.cancelRequest(ls); @@ -575,7 +561,7 @@ Dht::searchStep(std::shared_ptr<Search> sr) searchStep(sr); } }, - [this,ws,ls](std::shared_ptr<NetworkEngine::Request>, bool) mutable + [this,ws,ls](const Request&, bool) mutable { /* on expired */ network_engine.cancelRequest(ls); if (auto sr = ws.lock()) { @@ -606,21 +592,20 @@ Dht::searchStep(std::shared_ptr<Search> sr) continue; auto at = n.getAnnounceTime(vid, type); if ( at <= now ) { - DHT_LOG.WARN("[search %s IPv%c] [node %s %s] sending 'put' (vid: %d)", - sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6', n.node->id.toString().c_str(), - print_addr(n.node->ss, n.node->sslen).c_str(), vid); + DHT_LOG.WARN("[search %s IPv%c] [node %s] sending 'put' (vid: %d)", + sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6', n.node->toString().c_str(), vid); //std::cout << "Sending announce_value to " << n.node->id << " " << print_addr(n.node->ss, n.node->sslen) << std::endl; std::weak_ptr<Search> ws = sr; n.acked[vid] = network_engine.sendAnnounceValue(n.node, sr->id, *a.value, a.created, n.token, - [this,ws](std::shared_ptr<NetworkEngine::Request> status, NetworkEngine::RequestAnswer&& answer) mutable + [this,ws](const Request& status, NetworkEngine::RequestAnswer&& answer) mutable { /* on done */ if (auto sr = ws.lock()) { onAnnounceDone(status, answer, sr); searchStep(sr); } }, - [this,ws](std::shared_ptr<NetworkEngine::Request>, bool) mutable + [this,ws](const Request&, bool) mutable { /* on expired */ if (auto sr = ws.lock()) { searchStep(sr); } } @@ -647,9 +632,10 @@ Dht::searchStep(std::shared_ptr<Search> sr) DHT_LOG.DEBUG("[search %s IPv%c] step: sent %u requests.", sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6', i); - if (i == 0 && (size_t)std::count_if(sr->nodes.begin(), sr->nodes.end(), [&](const SearchNode& sn) { - return sn.candidate or sn.node->isExpired(now); - }) == sr->nodes.size()) + auto expiredn = (size_t)std::count_if(sr->nodes.begin(), sr->nodes.end(), [&](const SearchNode& sn) { + return sn.candidate or sn.node->isExpired(); + }); + if (i == 0 && expiredn == sr->nodes.size()) { DHT_LOG.ERROR("[search %s IPv%c] expired", sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6'); // no nodes or all expired nodes @@ -718,7 +704,7 @@ Dht::Search::insertBucket(const Bucket& b, time_point now) { unsigned inserted = 0; for (auto& n : b.nodes) { - if (not n->isExpired(now) and insertNode(n, now)) + if (not n->isExpired() and insertNode(n, now)) inserted++; } return inserted; @@ -729,7 +715,7 @@ Dht::Search::isSynced(time_point now) const { unsigned i = 0; for (const auto& n : nodes) { - if (n.node->isExpired(now) or n.candidate) + if (n.node->isExpired() or n.candidate) continue; if (not n.isSynced(now)) return false; @@ -739,10 +725,10 @@ Dht::Search::isSynced(time_point now) const return i > 0; } -unsigned Dht::Search::getNumberOfBadNodes(time_point now) { +unsigned Dht::Search::getNumberOfBadNodes() { return std::count_if(nodes.begin(), nodes.end(), [=](const SearchNode& sn) { - return sn.isBad(now); + return sn.isBad(); }); } @@ -761,7 +747,7 @@ Dht::Search::isDone(const Get& get, time_point now) const unsigned i = 0; const auto limit = std::max(get.start, now - Node::NODE_EXPIRE_TIME); for (const auto& sn : nodes) { - if (sn.node->isExpired(now) or sn.candidate) + if (sn.node->isExpired() or sn.candidate) continue; if (sn.last_get_reply < limit) return false; @@ -778,7 +764,7 @@ Dht::Search::getUpdateTime(time_point now) const const auto last_get = getLastGetTime(); unsigned i = 0, t = 0, d = 0; for (const auto& sn : nodes) { - if (sn.node->isExpired(now) or (sn.candidate and t >= TARGET_NODES)) + if (sn.node->isExpired() or (sn.candidate and t >= TARGET_NODES)) continue; bool pending = sn.getStatus and sn.getStatus->pending(); if (sn.last_get_reply < std::max(now - Node::NODE_EXPIRE_TIME, last_get) or pending) { @@ -809,7 +795,7 @@ Dht::Search::isAnnounced(Value::Id id, const ValueType& type, time_point now) co return false; unsigned i = 0; for (const auto& n : nodes) { - if (n.candidate or n.node->isExpired(now)) + if (n.candidate or n.node->isExpired()) continue; if (not n.isAnnounced(id, type, now)) return false; @@ -826,7 +812,7 @@ Dht::Search::isListening(time_point now) const return false; unsigned i = 0; for (const auto& n : nodes) { - if (n.candidate or n.node->isExpired(now)) + if (n.candidate or n.node->isExpired()) continue; if (!n.isListening(now)) return false; @@ -940,7 +926,7 @@ Dht::Search::refill(const RoutingTable& r, time_point now) { if (r.isEmpty() or r.front().af != af) return 0; unsigned added = 0; - auto num_bad_nodes = getNumberOfBadNodes(now); + auto num_bad_nodes = getNumberOfBadNodes(); auto b = r.findBucket(id); auto n = b; while (nodes.size()-num_bad_nodes < SEARCH_NODES && (std::next(n) != r.end() || b != r.begin())) { @@ -1380,8 +1366,7 @@ Dht::storageChanged(Storage& st, ValueStorage& v) } for (const auto& l : st.listeners) { - DHT_LOG.DEBUG("Storage changed. Sending update to %s %s.", - l.first->id.toString().c_str(), print_addr((sockaddr*)&l.first->ss, l.first->sslen).c_str()); + DHT_LOG.DEBUG("Storage changed. Sending update to %s.", l.first->toString().c_str()); std::vector<std::shared_ptr<Value>> vals; vals.push_back(v.data); Blob ntoken = makeToken((const sockaddr*)&l.first->ss, false); @@ -1651,12 +1636,12 @@ Dht::dumpBucket(const Bucket& b, std::ostream& out) const out << " (cached)"; out << std::endl; for (auto& n : b.nodes) { - out << " Node " << n->id << " " << print_addr((sockaddr*)&n->ss, n->sslen); + out << " Node " << n->toString(); if (n->time != n->reply_time) out << " age " << duration_cast<seconds>(now - n->time).count() << ", reply: " << duration_cast<seconds>(now - n->reply_time).count(); else out << " age " << duration_cast<seconds>(now - n->time).count(); - if (n->isExpired(now)) + if (n->isExpired()) out << " [expired]"; else if (n->isGood(now)) out << " [good]"; @@ -1695,9 +1680,9 @@ Dht::dumpSearch(const Search& sr, std::ostream& out) const out << ' ' << (findNode(n.node->id, AF_INET) || findNode(n.node->id, AF_INET6) ? '*' : ' '); out << ' ' << (n.candidate ? 'c' : ' '); out << " [" - << (n.node->isMessagePending(now) ? 'f':' '); + << (n.node->isMessagePending() ? 'f':' '); out << ' '; - out << (n.node->isExpired(now) ? 'x' : ' ') << "]"; + out << (n.node->isExpired() ? 'x' : ' ') << "]"; { bool pending {false}, expired {false}; @@ -1791,8 +1776,9 @@ Dht::getStorageLog() const std::stringstream out; for (const auto& st : store) { out << "Storage " << st.id << " " << st.listeners.size() << " list., " << st.valueCount() << " values (" << st.totalSize() << " bytes)" << std::endl; + out << " " << st.local_listeners.size() << " local listeners" << std::endl; for (const auto& l : st.listeners) { - out << " " << "Listener " << l.first->id << " " << print_addr((sockaddr*)&l.first->ss, l.first->sslen); + out << " " << "Listener " << l.first->toString(); auto since = duration_cast<seconds>(now - l.second.time); auto expires = duration_cast<seconds>(l.second.time + Node::NODE_EXPIRE_TIME - now); out << " (since " << since.count() << "s, exp in " << expires.count() << "s)" << std::endl; @@ -2218,11 +2204,11 @@ Dht::pingNode(const sockaddr *sa, socklen_t salen) } void -Dht::onError(std::shared_ptr<NetworkEngine::Request> req, DhtProtocolException e) { +Dht::onError(std::shared_ptr<Request> req, DhtProtocolException e) { if (e.getCode() == DhtProtocolException::UNAUTHORIZED) { network_engine.cancelRequest(req); unsigned cleared = 0; - for (auto& srp : req->node->ss.ss_family == AF_INET ? searches4 : searches6) { + for (auto& srp : req->node->getFamily() == AF_INET ? searches4 : searches6) { auto& sr = srp.second; for (auto& n : sr->nodes) { if (n.node != req->node) continue; @@ -2233,8 +2219,7 @@ Dht::onError(std::shared_ptr<NetworkEngine::Request> req, DhtProtocolException e break; } } - DHT_LOG.WARN("[node %s %s] token flush (%d searches affected)", - req->node->id.toString().c_str(), print_addr((sockaddr*)&req->node->ss, req->node->sslen).c_str(), cleared); + DHT_LOG.WARN("[node %s] token flush (%d searches affected)", req->node->toString().c_str(), cleared); } } @@ -2270,8 +2255,7 @@ NetworkEngine::RequestAnswer Dht::onGetValues(std::shared_ptr<Node> node, InfoHash& hash, want_t) { if (hash == zeroes) { - DHT_LOG.WARN("[node %s %s] Eek! Got get_values with no info_hash.", - node->id.toString().c_str(), print_addr(node->ss, node->sslen).c_str()); + DHT_LOG.WARN("[node %s] Eek! Got get_values with no info_hash.", node->toString().c_str()); throw DhtProtocolException {DhtProtocolException::NON_AUTHORITATIVE_INFORMATION, DhtProtocolException::GET_NO_INFOHASH}; } const auto& now = scheduler.time(); @@ -2286,17 +2270,15 @@ Dht::onGetValues(std::shared_ptr<Node> node, InfoHash& hash, want_t) std::transform(values.begin(), values.end(), answer.values.begin(), [](const ValueStorage& vs) { return vs.data; }); - DHT_LOG.DEBUG("[node %s %s] sending %u values.", - node->id.toString().c_str(), print_addr(node->ss, node->sslen).c_str(), answer.values.size()); + DHT_LOG.DEBUG("[node %s] sending %u values.", node->toString().c_str(), answer.values.size()); } else { - DHT_LOG.DEBUG("[node %s %s] sending nodes.", - node->id.toString().c_str(), print_addr(node->ss, node->sslen).c_str()); + DHT_LOG.DEBUG("[node %s] sending nodes.", node->toString().c_str()); } return answer; } void -Dht::onGetValuesDone(std::shared_ptr<NetworkEngine::Request> status, +Dht::onGetValuesDone(const Request& status, NetworkEngine::RequestAnswer& a, std::shared_ptr<Search> sr) { if (not sr) { @@ -2304,7 +2286,7 @@ Dht::onGetValuesDone(std::shared_ptr<NetworkEngine::Request> status, return; } - DHT_LOG.DEBUG("[search %s IPv%c] got reply to 'get' from %s with %u nodes", sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6', status->node->toString().c_str(), a.nodes4.size()); + DHT_LOG.DEBUG("[search %s IPv%c] got reply to 'get' from %s with %u nodes", sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6', status.node->toString().c_str(), a.nodes4.size()); if (not a.ntoken.empty()) { if (!a.values.empty()) { @@ -2338,9 +2320,8 @@ Dht::onGetValuesDone(std::shared_ptr<NetworkEngine::Request> status, l.first(l.second); } } else { - DHT_LOG.WARN("[node %s %s] no token provided. Ignoring response content.", - status->node->id.toString().c_str(), print_addr(status->node->ss, status->node->sslen).c_str()); - network_engine.blacklistNode(status->node); + DHT_LOG.WARN("[node %s] no token provided. Ignoring response content.", status.node->toString().c_str()); + network_engine.blacklistNode(status.node); } if (not sr->done) { @@ -2362,9 +2343,7 @@ Dht::onListen(std::shared_ptr<Node> node, InfoHash& hash, Blob& token, size_t ri }; } if (!tokenMatch(token, (sockaddr*)&node->ss)) { - DHT_LOG.WARN("[node %s %s] incorrect token %s for 'listen'.", - node->id.toString().c_str(), print_addr(node->ss, node->sslen).c_str(), - hash.toString().c_str(), to_hex(token.data(), token.size()).c_str()); + DHT_LOG.WARN("[node %s] incorrect token %s for 'listen'.", node->toString().c_str(), hash.toString().c_str()); throw DhtProtocolException {DhtProtocolException::UNAUTHORIZED, DhtProtocolException::LISTEN_WRONG_TOKEN}; } storageAddListener(hash, node, rid); @@ -2372,7 +2351,7 @@ Dht::onListen(std::shared_ptr<Node> node, InfoHash& hash, Blob& token, size_t ri } void -Dht::onListenDone(std::shared_ptr<NetworkEngine::Request>& status, NetworkEngine::RequestAnswer& answer, std::shared_ptr<Search>& sr) +Dht::onListenDone(const Request& status, NetworkEngine::RequestAnswer& answer, std::shared_ptr<Search>& sr) { DHT_LOG.DEBUG("[search %s] Got reply to listen.", sr->id.toString().c_str()); if (sr) { @@ -2402,20 +2381,16 @@ Dht::onAnnounce(std::shared_ptr<Node> node, InfoHash& hash, Blob& token, std::ve }; } if (!tokenMatch(token, (sockaddr*)&node->ss)) { - DHT_LOG.WARN("[node %s %s] incorrect token %s for 'put'.", - node->id.toString().c_str(), print_addr(node->ss, node->sslen).c_str(), - hash.toString().c_str(), to_hex(token.data(), token.size()).c_str()); + DHT_LOG.WARN("[node %s] incorrect token %s for 'put'.", node->toString().c_str(), hash.toString().c_str()); throw DhtProtocolException {DhtProtocolException::UNAUTHORIZED, DhtProtocolException::PUT_WRONG_TOKEN}; } { // We store a value only if we think we're part of the // SEARCH_NODES nodes around the target id. - auto closest_nodes = ( - ((sockaddr*)&node->ss)->sa_family == AF_INET ? buckets : buckets6 - ).findClosestNodes(hash, scheduler.time(), SEARCH_NODES); + auto closest_nodes = (node->getFamily() == AF_INET ? buckets : buckets6) + .findClosestNodes(hash, scheduler.time(), SEARCH_NODES); if (closest_nodes.size() >= TARGET_NODES and hash.xorCmp(closest_nodes.back()->id, myid) < 0) { - DHT_LOG.WARN("[node %s %s] announce too far from the target id. Dropping value.", - node->id.toString().c_str(), print_addr(node->ss, node->sslen).c_str()); + DHT_LOG.WARN("[node %s] announce too far from the target. Dropping value.", node->toString().c_str()); return {}; } } @@ -2460,7 +2435,7 @@ Dht::onAnnounce(std::shared_ptr<Node> node, InfoHash& hash, Blob& token, std::ve } void -Dht::onAnnounceDone(std::shared_ptr<NetworkEngine::Request>&, NetworkEngine::RequestAnswer& answer, +Dht::onAnnounceDone(const Request&, NetworkEngine::RequestAnswer& answer, std::shared_ptr<Search>& sr) { const auto& now = scheduler.time(); diff --git a/src/network_engine.cpp b/src/network_engine.cpp index 18e6f79e48fe8c84c96eceb628e59a0db8eb2fbc..254be1c29cf994786b622277a050b678406fce78 100644 --- a/src/network_engine.cpp +++ b/src/network_engine.cpp @@ -48,15 +48,49 @@ static const uint8_t v4prefix[16] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0, 0, 0, 0 }; + +enum class MessageType { + Error = 0, + Reply, + Ping, + FindNode, + GetValues, + AnnounceValue, + Listen +}; + +struct ParsedMessage { + MessageType type; + InfoHash id; /* the id of the sender */ + InfoHash info_hash; /* hash for which values are requested */ + InfoHash target; /* target id around which to find nodes */ + NetworkEngine::TransId tid; /* transaction id */ + Blob token; /* security token */ + Value::Id value_id; /* the value id */ + time_point created { time_point::max() }; /* time when value was first created */ + Blob nodes4_raw, nodes6_raw; /* IPv4 nodes in response to a 'find' request */ + std::vector<std::shared_ptr<Node>> nodes4, nodes6; + std::vector<std::shared_ptr<Value>> values; /* values for a 'get' request */ + want_t want; /* states if ipv4 or ipv6 request */ + uint16_t error_code; /* error code in case of error */ + std::string ua; + Address addr; /* reported address by the distant node */ + void msgpack_unpack(msgpack::object o); +}; + +NetworkEngine::RequestAnswer::RequestAnswer(ParsedMessage&& msg) + : ntoken(std::move(msg.token)), values(std::move(msg.values)), nodes4(std::move(msg.nodes4)), nodes6(std::move(msg.nodes6)) {} + + /* Called whenever we send a request to a node, increases the ping count and, if that reaches 3, sends a ping to a new candidate. */ -void +/*void NetworkEngine::pinged(Node& n) { const auto& now = scheduler.time(); if (not n.isExpired(now)) n.requested(now); -} +}*/ void NetworkEngine::tellListener(std::shared_ptr<Node> node, uint16_t rid, InfoHash hash, want_t want, @@ -98,13 +132,14 @@ bool NetworkEngine::rateLimit() { using namespace std::chrono; - while (not rate_limit_time.empty() and duration_cast<seconds>(scheduler.time() - rate_limit_time.front()) > seconds(1)) + const auto& now = scheduler.time(); + while (not rate_limit_time.empty() and duration_cast<seconds>(now - rate_limit_time.front()) > seconds(1)) rate_limit_time.pop(); if (rate_limit_time.size() >= MAX_REQUESTS_PER_SEC) return false; - rate_limit_time.emplace(scheduler.time()); + rate_limit_time.emplace(now); return true; } @@ -144,29 +179,20 @@ NetworkEngine::isMartian(const sockaddr* sa, socklen_t len) /* The internal blacklist is an LRU cache of nodes that have sent incorrect messages. */ -/*void -NetworkEngine::blacklistNode(const InfoHash* id, const sockaddr *sa, socklen_t salen) +void +NetworkEngine::blacklistNode(const std::shared_ptr<Node>& n) { - DHT_LOG.WARN("Blacklisting broken node."); - - if (id) { - auto n = cache.getNode(*id, sa, salen, 0);//findNode(); - for () - // Discard it from any searches in progress. - auto black_list_in = [&](std::map<InfoHash, std::shared_ptr<Search>>& srs) { - for (auto& srp : srs) { - auto& sr = srp.second; - sr->nodes.erase(std::partition(sr->nodes.begin(), sr->nodes.end(), [&](SearchNode& sn) { - return sn.node != n; - }), sr->nodes.end()); - } - }; - black_list_in(searches4); - black_list_in(searches6); + for (auto rit = requests.begin(); rit != requests.end();) { + if (rit->second->node == n) { + rit->second->cancel(); + requests.erase(rit++); + } else { + ++rit; + } } - // And make sure we don't hear from it again. - -}*/ + memcpy(&blacklist[next_blacklisted], &n->ss, n->sslen); + next_blacklisted = (next_blacklisted + 1) % BLACKLISTED_MAX; +} bool NetworkEngine::isNodeBlacklisted(const sockaddr *sa, socklen_t salen) const @@ -186,7 +212,7 @@ NetworkEngine::isNodeBlacklisted(const sockaddr *sa, socklen_t salen) const } void -NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const sockaddr *from, socklen_t fromlen) +NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const sockaddr* from, socklen_t fromlen) { if (isMartian(from, fromlen)) return; @@ -233,13 +259,25 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const sockaddr uint16_t ttid = 0; if (msg.type == MessageType::Error or msg.type == MessageType::Reply) { auto reqp = requests.find(msg.tid.getTid()); - if (reqp == requests.end()) + if (reqp == requests.end()) { throw DhtProtocolException {DhtProtocolException::UNKNOWN_TID, "Can't find transaction", msg.id}; + } auto req = reqp->second; - if (req->cancelled()) + + auto node = req->node;//cache.getNode(msg.id, from, fromlen, now, 2); + node->received(now, req); + if (node->id == zeroes) { + // reply to a message sent when we didn't know the node ID. + node = cache.getNode(msg.id, from, fromlen, now, 2); + req->node = node; + } else + node->update(from, fromlen); + + if (req->cancelled()) { + DHT_LOG.ERROR("Request is cancelled: %d", msg.tid); return; + } - auto node = cache.getNode(msg.id, from, fromlen, now, 2); onNewNode(node, 2); onReportedAddr(msg.id, (sockaddr*)&msg.addr.first, msg.addr.second); switch (msg.type) { @@ -266,16 +304,14 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const sockaddr req->reply_time = scheduler.time(); deserializeNodesValues(msg); - req->completed_ = true; - req->on_done(req, std::move(msg)); - if (not req->persistent) - req->clear(); + req->setDone(std::move(msg)); break; default: break; } } else { auto node = cache.getNode(msg.id, from, fromlen, now, 1); + node->received(now, {}); onNewNode(node, 1); try { switch (msg.type) { @@ -374,7 +410,7 @@ NetworkEngine::send(const char *buf, size_t len, int flags, const sockaddr *sa, return sendto(s, buf, len, flags, sa, salen); } -std::shared_ptr<NetworkEngine::Request> +std::shared_ptr<Request> NetworkEngine::sendPing(std::shared_ptr<Node> node, RequestCb on_done, RequestExpiredCb on_expired) { auto tid = TransId {TransPrefix::PING, getNewTid()}; msgpack::sbuffer buffer; @@ -392,13 +428,13 @@ NetworkEngine::sendPing(std::shared_ptr<Node> node, RequestCb on_done, RequestEx Blob b {buffer.data(), buffer.data() + buffer.size()}; std::shared_ptr<Request> req(new Request {tid.getTid(), node, std::move(b), - [=](std::shared_ptr<Request> req_status, ParsedMessage&&) { - DHT_LOG.DEBUG("Got pong from %s", print_addr(req_status->node->ss, req_status->node->sslen).c_str()); + [=](const Request& req_status, ParsedMessage&&) { + DHT_LOG.DEBUG("Got pong from %s", req_status.node->toString().c_str()); if (on_done) { on_done(req_status, {}); } }, - [=](std::shared_ptr<Request> req_status, bool) { /* on expired */ + [=](const Request& req_status, bool) { /* on expired */ if (on_expired) { on_expired(req_status, {}); } @@ -427,7 +463,7 @@ NetworkEngine::sendPong(const sockaddr* sa, socklen_t salen, TransId tid) { send(buffer.data(), buffer.size(), 0, sa, salen); } -std::shared_ptr<NetworkEngine::Request> +std::shared_ptr<Request> NetworkEngine::sendFindNode(std::shared_ptr<Node> n, const InfoHash& target, want_t want, RequestCb on_done, RequestExpiredCb on_expired) { auto tid = TransId {TransPrefix::FIND_NODE, getNewTid()}; @@ -454,12 +490,12 @@ NetworkEngine::sendFindNode(std::shared_ptr<Node> n, const InfoHash& target, wan Blob b {buffer.data(), buffer.data() + buffer.size()}; std::shared_ptr<Request> req(new Request {tid.getTid(), n, std::move(b), - [=](std::shared_ptr<Request> req_status, ParsedMessage&& msg) { /* on done */ + [=](const Request& req_status, ParsedMessage&& msg) { /* on done */ if (on_done) { on_done(req_status, {std::forward<ParsedMessage>(msg)}); } }, - [=](std::shared_ptr<Request> req_status, bool) { /* on expired */ + [=](const Request& req_status, bool) { /* on expired */ if (on_expired) { on_expired(req_status, {}); } @@ -471,7 +507,7 @@ NetworkEngine::sendFindNode(std::shared_ptr<Node> n, const InfoHash& target, wan } -std::shared_ptr<NetworkEngine::Request> +std::shared_ptr<Request> NetworkEngine::sendGetValues(std::shared_ptr<Node> n, const InfoHash& info_hash, want_t want, RequestCb on_done, RequestExpiredCb on_expired) { auto tid = TransId {TransPrefix::GET_VALUES, getNewTid()}; @@ -497,12 +533,12 @@ NetworkEngine::sendGetValues(std::shared_ptr<Node> n, const InfoHash& info_hash, Blob b {buffer.data(), buffer.data() + buffer.size()}; std::shared_ptr<Request> req(new Request {tid.getTid(), n, std::move(b), - [=](std::shared_ptr<Request> req_status, ParsedMessage&& msg) { /* on done */ + [=](const Request& req_status, ParsedMessage&& msg) { /* on done */ if (on_done) { on_done(req_status, {std::forward<ParsedMessage>(msg)}); } }, - [=](std::shared_ptr<Request> req_status, bool) { /* on expired */ + [=](const Request& req_status, bool) { /* on expired */ if (on_expired) { on_expired(req_status, {}); } @@ -687,7 +723,7 @@ NetworkEngine::bufferNodes(sa_family_t af, const InfoHash& id, want_t want, return {std::move(bnodes4), std::move(bnodes6)}; } -std::shared_ptr<NetworkEngine::Request> +std::shared_ptr<Request> NetworkEngine::sendListen(std::shared_ptr<Node> n, const InfoHash& infohash, const Blob& token, RequestCb on_done, RequestExpiredCb on_expired) { auto tid = TransId {TransPrefix::LISTEN, getNewTid()}; @@ -709,11 +745,11 @@ NetworkEngine::sendListen(std::shared_ptr<Node> n, const InfoHash& infohash, con Blob b {buffer.data(), buffer.data() + buffer.size()}; std::shared_ptr<Request> req(new Request {tid.getTid(), n, std::move(b), - [=](std::shared_ptr<Request> req_status, ParsedMessage&& msg) { /* on done */ + [=](const Request& req_status, ParsedMessage&& msg) { /* on done */ if (on_done) on_done(req_status, {std::forward<ParsedMessage>(msg)}); }, - [=](std::shared_ptr<Request> req_status, bool) { /* on expired */ + [=](const Request& req_status, bool) { /* on expired */ if (on_expired) on_expired(req_status, {}); }, @@ -742,7 +778,7 @@ NetworkEngine::sendListenConfirmation(const sockaddr* sa, socklen_t salen, Trans send(buffer.data(), buffer.size(), 0, sa, salen); } -std::shared_ptr<NetworkEngine::Request> +std::shared_ptr<Request> NetworkEngine::sendAnnounceValue(std::shared_ptr<Node> n, const InfoHash& infohash, const Value& value, time_point created, const Blob& token, RequestCb on_done, RequestExpiredCb on_expired) { auto tid = TransId {TransPrefix::ANNOUNCE_VALUES, getNewTid()}; @@ -768,7 +804,7 @@ NetworkEngine::sendAnnounceValue(std::shared_ptr<Node> n, const InfoHash& infoha Blob b {buffer.data(), buffer.data() + buffer.size()}; std::shared_ptr<Request> req(new Request {tid.getTid(), n, std::move(b), - [=](std::shared_ptr<Request> req_status, ParsedMessage&& msg) { /* on done */ + [=](const Request& req_status, ParsedMessage&& msg) { /* on done */ if (msg.value_id == Value::INVALID_ID) { DHT_LOG.DEBUG("Unknown search or announce!"); } else { @@ -779,7 +815,7 @@ NetworkEngine::sendAnnounceValue(std::shared_ptr<Node> n, const InfoHash& infoha } } }, - [=](std::shared_ptr<Request> req_status, bool) { /* on expired */ + [=](const Request& req_status, bool) { /* on expired */ if (on_expired) { on_expired(req_status, {}); } @@ -852,7 +888,7 @@ findMapValue(msgpack::object& map, const std::string& key) { } void -NetworkEngine::ParsedMessage::msgpack_unpack(msgpack::object msg) +ParsedMessage::msgpack_unpack(msgpack::object msg) { auto y = findMapValue(msg, "y"); auto a = findMapValue(msg, "a"); diff --git a/src/node.cpp b/src/node.cpp index 0da55411abdbcf3a6aa04b9a0831c2d71b41d3d0..9beeca2e4ff26f5423801b76ba496d4a17659825 100644 --- a/src/node.cpp +++ b/src/node.cpp @@ -20,6 +20,7 @@ #include "node.h" +#include "request.h" #include <sstream> @@ -33,22 +34,21 @@ constexpr std::chrono::seconds Node::MAX_RESPONSE_TIME; bool Node::isGood(time_point now) const { - return - not isExpired(now) && + return not expired_ && reply_time >= now - NODE_GOOD_TIME && time >= now - NODE_EXPIRE_TIME; } bool -Node::isExpired(time_point now) const +Node::isMessagePending() const { - return pinged >= 3 && reply_time < pinged_time && pinged_time + MAX_RESPONSE_TIME < now; -} - -bool -Node::isMessagePending(time_point now) const -{ - return reply_time < pinged_time && pinged_time + MAX_RESPONSE_TIME > now; + for (auto w : requests_) { + if (auto r = w.lock()) { + if (r->pending()) + return true; + } + } + return false; } void @@ -60,25 +60,42 @@ Node::update(const sockaddr* sa, socklen_t salen) /** To be called when a message was sent to the node */ void -Node::requested(time_point now) +Node::requested(std::shared_ptr<Request>& req) { - pinged++; - if (reply_time > pinged_time || pinged_time + MAX_RESPONSE_TIME < now) - pinged_time = now; + requests_.emplace_back(req); } /** To be called when a message was received from the node. - Answer should be true if the message was an aswer to a request we made*/ + Req should be true if the message was an aswer to a request we made*/ void -Node::received(time_point now, bool answer) +Node::received(time_point now, std::shared_ptr<Request> req) { time = now; - if (answer) { - pinged = 0; + if (req) { reply_time = now; + expired_ = false; + for (auto it = requests_.begin(); it != requests_.end();) { + auto r = it->lock(); + if (not r or r == req) + it = requests_.erase(it); + else + ++it; + } } } +void +Node::setExpired() +{ + expired_ = true; + for (auto w : requests_) { + if (auto r = w.lock()) + r->setExpired(); + } + requests_.clear(); +} + + std::string Node::toString() const { diff --git a/src/node_cache.cpp b/src/node_cache.cpp index 7ed542e524ccd8a1f21ea1cbf93eb3bde3ac8bdd..a579f416b39511a4e71a09fa8a61413226456a4e 100644 --- a/src/node_cache.cpp +++ b/src/node_cache.cpp @@ -100,8 +100,8 @@ NodeCache::NodeTree::get(const InfoHash& id, const sockaddr* sa, socklen_t sa_le } else if (confirm || node->time < now - Node::NODE_EXPIRE_TIME) { node->update(sa, sa_len); } - if (confirm) - node->received(now, confirm >= 2); + /*if (confirm) + node->received(now, confirm >= 2);*/ return node; }