diff --git a/include/opendht/callbacks.h b/include/opendht/callbacks.h index 5318c08c336385f584d15e484c5cc32a5c96ccca..ed75c0d6e830926007243087e9673f70c95ecee0 100644 --- a/include/opendht/callbacks.h +++ b/include/opendht/callbacks.h @@ -70,6 +70,7 @@ static constexpr size_t DEFAULT_STORAGE_LIMIT {1024 * 1024 * 64}; using ValuesExport = std::pair<InfoHash, Blob>; +using QueryCallback = std::function<bool(const std::vector<std::shared_ptr<FieldValueIndex>>& fields)>; using GetCallback = std::function<bool(const std::vector<std::shared_ptr<Value>>& values)>; using GetCallbackSimple = std::function<bool(std::shared_ptr<Value> value)>; using ShutdownCallback = std::function<void()>; diff --git a/include/opendht/dht.h b/include/opendht/dht.h index a5409e367b715deea7412f61ced4dd92fb2a7e9e..d0a89eb261a28adb6423feb42449b01a173a5e40 100644 --- a/include/opendht/dht.h +++ b/include/opendht/dht.h @@ -133,18 +133,32 @@ public: * @param cb a function called when new values are found on the network. * It should return false to stop the operation. * @param donecb a function called when the operation is complete. - cb and donecb won't be called again afterward. + cb and donecb won't be called again afterward. * @param f a filter function used to prefilter values. */ - virtual void get(const InfoHash& key, GetCallback cb, DoneCallback donecb={}, Value::Filter&& f={}); - virtual void get(const InfoHash& key, GetCallback cb, DoneCallbackSimple donecb={}, Value::Filter&& f={}) { - get(key, cb, bindDoneCb(donecb), std::forward<Value::Filter>(f)); + virtual void get(const InfoHash& key, GetCallback cb, DoneCallback donecb={}, Value::Filter&& f={}, Where&& w = {}); + virtual void get(const InfoHash& key, GetCallback cb, DoneCallbackSimple donecb={}, Value::Filter&& f={}, Where&& w = {}) { + get(key, cb, bindDoneCb(donecb), std::forward<Value::Filter>(f), std::forward<Where>(w)); } - virtual void get(const InfoHash& key, GetCallbackSimple cb, DoneCallback donecb={}, Value::Filter&& f={}) { - get(key, bindGetCb(cb), donecb, std::forward<Value::Filter>(f)); + virtual void get(const InfoHash& key, GetCallbackSimple cb, DoneCallback donecb={}, Value::Filter&& f={}, Where&& w = {}) { + get(key, bindGetCb(cb), donecb, std::forward<Value::Filter>(f), std::forward<Where>(w)); } - virtual void get(const InfoHash& key, GetCallbackSimple cb, DoneCallbackSimple donecb, Value::Filter&& f={}) { - get(key, bindGetCb(cb), bindDoneCb(donecb), std::forward<Value::Filter>(f)); + virtual void get(const InfoHash& key, GetCallbackSimple cb, DoneCallbackSimple donecb, Value::Filter&& f={}, Where&& w = {}) { + get(key, bindGetCb(cb), bindDoneCb(donecb), std::forward<Value::Filter>(f), std::forward<Where>(w)); + } + /** + * Similar to Dht::get, but sends a Query to filter data remotely. + * @param key the key for which to query data for. + * @param cb a function called when new values are found on the network. + * It should return false to stop the operation. + * @param donecb a function called when the operation is complete. + cb and donecb won't be called again afterward. + * @param q a query used to filter values on the remotes before they send a + * response. + */ + virtual void query(const InfoHash& key, QueryCallback cb, DoneCallback done_cb = {}, Query&& q = {}); + virtual void query(const InfoHash& key, QueryCallback cb, DoneCallbackSimple done_cb = {}, Query&& q = {}) { + query(key, cb, bindDoneCb(done_cb), std::forward<Query>(q)); } /** @@ -158,14 +172,10 @@ public: std::shared_ptr<Value> getLocalById(const InfoHash& key, Value::Id vid) const; /** - * Announce a value on all available protocols (IPv4, IPv6), and - * automatically re-announce when it's about to expire. + * Announce a value on all available protocols (IPv4, IPv6). + * * The operation will start as soon as the node is connected to the network. * The done callback will be called once, when the first announce succeeds, or fails. - * - * A "put" operation will never end by itself because the value will need to be - * reannounced on a regular basis. - * User can call #cancelPut(InfoHash, Value::Id) to cancel a put operation. */ void put(const InfoHash& key, std::shared_ptr<Value>, @@ -221,9 +231,9 @@ public: * * @return a token to cancel the listener later. */ - virtual size_t listen(const InfoHash&, GetCallback, Value::Filter&&={}); - virtual size_t listen(const InfoHash& key, GetCallbackSimple cb, Value::Filter f={}) { - return listen(key, bindGetCb(cb), std::forward<Value::Filter>(f)); + virtual size_t listen(const InfoHash&, GetCallback, Value::Filter&&={}, Where&& w = {}); + virtual size_t listen(const InfoHash& key, GetCallbackSimple cb, Value::Filter f={}, Where w = {}) { + return listen(key, bindGetCb(cb), std::forward<Value::Filter>(f), std::forward<Where>(w)); } virtual bool cancelListen(const InfoHash&, size_t token); @@ -317,6 +327,8 @@ private: struct Get { time_point start; Value::Filter filter; + std::shared_ptr<Query> query; + QueryCallback query_cb; GetCallback get_cb; DoneCallback done_cb; }; @@ -335,6 +347,7 @@ private: * A single "listen" operation data */ struct LocalListener { + std::shared_ptr<Query> query; Value::Filter filter; GetCallback get_cb; }; @@ -359,8 +372,9 @@ private: struct Listener { size_t rid {}; time_point time {}; + Query query {}; - /*constexpr*/ Listener(size_t rid, time_point t) : rid(rid), time(t) {} + /*constexpr*/ Listener(size_t rid, time_point t, Query&& q) : rid(rid), time(t), query(q) {} void refresh(size_t tid, time_point t) { rid = tid; @@ -427,7 +441,7 @@ private: decltype(Dht::store)::iterator findStorage(const InfoHash& id); decltype(Dht::store)::const_iterator findStorage(const InfoHash& id) const; - void storageAddListener(const InfoHash& id, const std::shared_ptr<Node>& node, size_t tid); + void storageAddListener(const InfoHash& id, const std::shared_ptr<Node>& node, size_t tid, Query&& = {}); bool storageStore(const InfoHash& id, const std::shared_ptr<Value>& value, time_point created); void expireStorage(); void storageChanged(Storage& st, ValueStorage&); @@ -472,13 +486,13 @@ private: // Searches /** - * Low-level method that will perform a search on the DHT for the - * specified infohash (id), using the specified IP version (IPv4 or IPv6). - * The values can be filtered by an arbitrary provided filter. + * Low-level method that will perform a search on the DHT for the specified + * infohash (id), using the specified IP version (IPv4 or IPv6). */ - std::shared_ptr<Search> search(const InfoHash& id, sa_family_t af, GetCallback = nullptr, DoneCallback = nullptr, Value::Filter = Value::AllFilter()); + std::shared_ptr<Search> search(const InfoHash& id, sa_family_t af, GetCallback = {}, QueryCallback = {}, DoneCallback = {}, Value::Filter = {}, Query q = {}); + void announce(const InfoHash& id, sa_family_t af, std::shared_ptr<Value> value, DoneCallback callback, time_point created=time_point::max(), bool permanent = false); - size_t listenTo(const InfoHash& id, sa_family_t af, GetCallback cb, Value::Filter f = Value::AllFilter()); + size_t listenTo(const InfoHash& id, sa_family_t af, GetCallback cb, Value::Filter f = Value::AllFilter(), const std::shared_ptr<Query>& q = {}); void bootstrapSearch(Search& sr); Search *findSearch(unsigned short tid, sa_family_t af); @@ -508,11 +522,14 @@ private: NetworkEngine::RequestAnswer onFindNode(std::shared_ptr<Node> node, InfoHash& hash, want_t want); void onFindNodeDone(const Request& status, NetworkEngine::RequestAnswer& a, std::shared_ptr<Search> sr); /* when we receive a "get values" request */ - NetworkEngine::RequestAnswer onGetValues(std::shared_ptr<Node> node, InfoHash& hash, want_t want); - void onGetValuesDone(const Request& status, NetworkEngine::RequestAnswer& a, std::shared_ptr<Search> sr); + NetworkEngine::RequestAnswer onGetValues(std::shared_ptr<Node> node, InfoHash& hash, want_t want, const Query& q); + void onGetValuesDone(const Request& status, NetworkEngine::RequestAnswer& a, std::shared_ptr<Search> sr, + const std::shared_ptr<Query>& orig_query); /* when we receive a listen request */ - NetworkEngine::RequestAnswer onListen(std::shared_ptr<Node> node, InfoHash& hash, Blob& token, size_t rid); - void onListenDone(const Request& status, NetworkEngine::RequestAnswer& a, std::shared_ptr<Search>& sr); + NetworkEngine::RequestAnswer onListen(std::shared_ptr<Node> node, InfoHash& hash, Blob& token, size_t rid, + Query&& query); + void onListenDone(const Request& status, NetworkEngine::RequestAnswer& a, + std::shared_ptr<Search>& sr, const std::shared_ptr<Query>& orig_query); /* when we receive an announce request */ NetworkEngine::RequestAnswer onAnnounce(std::shared_ptr<Node> node, InfoHash& hash, Blob& token, std::vector<std::shared_ptr<Value>> v, time_point created); diff --git a/include/opendht/network_engine.h b/include/opendht/network_engine.h index 3d336c354e3e04d76ce9e03d33d0e352eba16bf1..04afdecf392739656c67571f895935a2aa7bf66b 100644 --- a/include/opendht/network_engine.h +++ b/include/opendht/network_engine.h @@ -153,6 +153,7 @@ public: Blob ntoken {}; Value::Id vid {}; std::vector<std::shared_ptr<Value>> values {}; + std::vector<std::shared_ptr<FieldValueIndex>> fields {}; std::vector<std::shared_ptr<Node>> nodes4 {}; std::vector<std::shared_ptr<Node>> nodes6 {}; RequestAnswer() {} @@ -223,7 +224,8 @@ private: */ std::function<RequestAnswer(std::shared_ptr<Node>, InfoHash&, - want_t)> onGetValues {}; + want_t, + Query)> onGetValues {}; /** * @brief on listen request callback. * @@ -235,7 +237,8 @@ private: std::function<RequestAnswer(std::shared_ptr<Node>, InfoHash&, Blob&, - uint16_t)> onListen {}; + uint16_t, + Query)> onListen {}; /** * @brief on announce request callback. * @@ -290,9 +293,9 @@ public: * @param nodes6 The ipv6 closest nodes. * @param values The values to send. */ - void tellListener(std::shared_ptr<Node> n, uint16_t rid, InfoHash hash, want_t want, Blob ntoken, - std::vector<std::shared_ptr<Node>> nodes, std::vector<std::shared_ptr<Node>> nodes6, - std::vector<std::shared_ptr<Value>> values); + void tellListener(std::shared_ptr<Node> n, uint16_t rid, const InfoHash& hash, want_t want, const Blob& ntoken, + std::vector<std::shared_ptr<Node>>&& nodes, std::vector<std::shared_ptr<Node>>&& nodes6, + std::vector<std::shared_ptr<Value>>&& values, const Query& q); bool isRunning(sa_family_t af) const; inline want_t want () const { return dht_socket >= 0 && dht_socket6 >= 0 ? (WANT4 | WANT6) : -1; } @@ -314,13 +317,15 @@ public: RequestExpiredCb on_expired); std::shared_ptr<Request> sendGetValues(std::shared_ptr<Node> n, - const InfoHash& target, + const InfoHash& info_hash, + const Query& query, want_t want, RequestCb on_done, RequestExpiredCb on_expired); std::shared_ptr<Request> sendListen(std::shared_ptr<Node> n, const InfoHash& infohash, + const Query& query, const Blob& token, RequestCb on_done, RequestExpiredCb on_expired); @@ -432,6 +437,7 @@ private: const Blob& nodes, const Blob& nodes6, const std::vector<std::shared_ptr<Value>>& st, + const Query& query, const Blob& token); Blob bufferNodes(sa_family_t af, const InfoHash& id, std::vector<std::shared_ptr<Node>>& nodes); @@ -452,7 +458,7 @@ private: const std::string& message, bool include_id=false); - void deserializeNodesValues(ParsedMessage& msg); + void deserializeNodes(ParsedMessage& msg); std::queue<time_point> rate_limit_time {}; static std::mt19937 rd_device; diff --git a/include/opendht/securedht.h b/include/opendht/securedht.h index d5b5e6d040fd8c6f9d545cf067f761d8aa716bc0..11b92fb9d196354139ae19703a664b78eb7d5dac 100644 --- a/include/opendht/securedht.h +++ b/include/opendht/securedht.h @@ -85,18 +85,18 @@ public: * If the signature can't be checked, or if the data can't be decrypted, it is not returned. * Public, non-signed & non-encrypted data is retransmitted as-is. */ - virtual void get(const InfoHash& id, GetCallback cb, DoneCallback donecb={}, Value::Filter&& = {}) override; - virtual void get(const InfoHash& id, GetCallback cb, DoneCallbackSimple donecb={}, Value::Filter&& f = {}) override { - get(id, cb, bindDoneCb(donecb), std::forward<Value::Filter>(f)); + virtual void get(const InfoHash& id, GetCallback cb, DoneCallback donecb={}, Value::Filter&& = {}, Where&& w = {}) override; + virtual void get(const InfoHash& id, GetCallback cb, DoneCallbackSimple donecb={}, Value::Filter&& f = {}, Where&& w = {}) override { + get(id, cb, bindDoneCb(donecb), std::forward<Value::Filter>(f), std::forward<Where>(w)); } - virtual void get(const InfoHash& key, GetCallbackSimple cb, DoneCallback donecb={}, Value::Filter&& f={}) override { - get(key, bindGetCb(cb), donecb, std::forward<Value::Filter>(f)); + virtual void get(const InfoHash& key, GetCallbackSimple cb, DoneCallback donecb={}, Value::Filter&& f={}, Where&& w = {}) override { + get(key, bindGetCb(cb), donecb, std::forward<Value::Filter>(f), std::forward<Where>(w)); } - virtual void get(const InfoHash& key, GetCallbackSimple cb, DoneCallbackSimple donecb, Value::Filter&& f={}) override { - get(key, bindGetCb(cb), bindDoneCb(donecb), std::forward<Value::Filter>(f)); + virtual void get(const InfoHash& key, GetCallbackSimple cb, DoneCallbackSimple donecb, Value::Filter&& f={}, Where&& w = {}) override { + get(key, bindGetCb(cb), bindDoneCb(donecb), std::forward<Value::Filter>(f), std::forward<Where>(w)); } - virtual size_t listen(const InfoHash& id, GetCallback cb, Value::Filter&& = {}) override; + virtual size_t listen(const InfoHash& id, GetCallback cb, Value::Filter&& = {}, Where&& w = {}) override; /** * Will take ownership of the value, sign it using our private key and put it in the DHT. diff --git a/src/dht.cpp b/src/dht.cpp index 590bea804180f610b54b64d50782e3af0490963a..9ebc4f81c1a06f295ce7bda487b35f1c10cf94bd 100644 --- a/src/dht.cpp +++ b/src/dht.cpp @@ -176,6 +176,7 @@ struct Dht::SearchNode { SearchNode(std::shared_ptr<Node> node) : node(node) {} using AnnounceStatusMap = std::map<Value::Id, std::shared_ptr<Request>>; + using SyncStatusMap = std::map<std::shared_ptr<Query>, std::shared_ptr<Request>>; /** * Can we use this node to listen/announce now ? @@ -189,10 +190,31 @@ struct Dht::SearchNode { * Could a "get" request be sent to this node now ? * update: time of the last "get" op for the search. */ - bool canGet(time_point now, time_point update) const { + bool canGet(time_point now, time_point update, std::shared_ptr<Query> q = {}) const { + const auto& get_status = q ? getStatus.find(q) : getStatus.end(); + const auto& sq_status = std::find_if(getStatus.begin(), getStatus.end(), + [q](const SyncStatusMap::value_type& s) { + return q and s.first and q->isSatisfiedBy(*s.first); + } + ); return not node->isExpired() and (now > last_get_reply + Node::NODE_EXPIRE_TIME or update > last_get_reply) - and (not getStatus or not getStatus->pending()); + and ((get_status == getStatus.end() or not get_status->second or not get_status->second->pending()) and + (sq_status == getStatus.end() or not sq_status->second or not sq_status->second->pending())); + } + + bool expired(const SyncStatusMap& status) const { + return std::find_if(status.begin(), status.end(), + [](const SyncStatusMap::value_type& r){ + return r.second and r.second->expired(); + }) != status.end(); + } + + bool pending(const SyncStatusMap& status) const { + return std::find_if(status.begin(), status.end(), + [](const SyncStatusMap::value_type& r){ + return r.second and r.second->pending(); + }) != status.end(); } bool isAnnounced(Value::Id vid, const ValueType& type, time_point now) const { @@ -202,11 +224,27 @@ struct Dht::SearchNode { } return ack->second->reply_time + type.expiration > now; } + bool isListening(time_point now) const { - if (not listenStatus) + auto ls = listenStatus.begin(); + for ( ; ls != listenStatus.end() ; ++ls) { + if (isListening(now, ls)) { + break; + } + } + return ls != listenStatus.end(); + } + bool isListening(time_point now, const std::shared_ptr<Query>& q) const { + const auto& ls = listenStatus.find(q); + if (ls == listenStatus.end()) return false; - - return listenStatus->reply_time + LISTEN_EXPIRE_TIME > now; + else + return isListening(now, ls); + } + bool isListening(time_point now, SyncStatusMap::const_iterator listen_status) const { + if (listen_status == listenStatus.end()) + return false; + return listen_status->second->reply_time + LISTEN_EXPIRE_TIME > now; } /** @@ -226,10 +264,20 @@ struct Dht::SearchNode { * Assumng the node is synced, should a "listen" request be sent to this node now ? */ time_point getListenTime() const { - if (not listenStatus) + time_point t {time_point::max()}; + for (auto ls = listenStatus.begin(); ls != listenStatus.end() ; ++ls) { + t = std::min(t, getListenTime(ls)); + } + return t; + } + time_point getListenTime(const std::shared_ptr<Query>& q) const { + return getListenTime(listenStatus.find(q)); + } + time_point getListenTime(SyncStatusMap::const_iterator listen_status) const { + if (listen_status == listenStatus.end()) return time_point::min(); - - return listenStatus->pending() ? time_point::max() : listenStatus->reply_time + LISTEN_EXPIRE_TIME - REANNOUNCE_MARGIN; + return listen_status->second->pending() ? time_point::max() : + listen_status->second->reply_time + LISTEN_EXPIRE_TIME - REANNOUNCE_MARGIN; } /** @@ -241,10 +289,10 @@ struct Dht::SearchNode { std::shared_ptr<Node> node {}; - time_point last_get_reply {time_point::min()}; /* last time received valid token */ - std::shared_ptr<Request> getStatus {}; /* get/sync status */ - std::shared_ptr<Request> listenStatus {}; - AnnounceStatusMap acked {}; /* announcement status for a given value id */ + time_point last_get_reply {time_point::min()}; /* last time received valid token */ + SyncStatusMap getStatus {}; /* get/sync status */ + SyncStatusMap listenStatus {}; /* listen status */ + AnnounceStatusMap acked {}; /* announcement status for a given value id */ Blob token {}; @@ -272,7 +320,7 @@ struct Dht::Search { std::vector<Announce> announce {}; /* pending gets */ - std::vector<Get> callbacks {}; + std::multimap<time_point, Get> callbacks {}; /* listeners */ std::map<size_t, LocalListener> listeners {}; @@ -295,7 +343,7 @@ struct Dht::Search { unsigned currentGetRequests() const { unsigned count = 0; for (const auto& n : nodes) - if (not n.isBad() and n.getStatus and n.getStatus->pending()) + if (not n.isBad() and n.pending(n.getStatus)) count++; return count; } @@ -724,48 +772,64 @@ Dht::searchSendGetValues(std::shared_ptr<Search> sr, SearchNode* pn, bool update const auto& now = scheduler.time(); const time_point up = update ? sr->getLastGetTime() : time_point::min(); + SearchNode* n = nullptr; - if (pn) { - if (not pn->canGet(now, up)) - return nullptr; - n = pn; - } else { - for (auto& sn : sr->nodes) { - if (sn.canGet(now, up)) { - n = &sn; - break; + auto cb = sr->callbacks.begin(); + do { /* for all queries to send */ + + /* cases v 'get' v 'find_node' */ + auto query = cb != sr->callbacks.end() ? cb->second.query : std::make_shared<Query>(); + if (pn) { + if (not pn->canGet(now, up, query)) + return nullptr; + n = pn; + } else { + for (auto& sn : sr->nodes) { + if (sn.canGet(now, up, query)) { + n = &sn; + break; + } } + if (not n) + return nullptr; } - if (not n) - return nullptr; - } - /*DHT_LOG.DEBUG("[search %s IPv%c] [node %s] sending 'get'", - sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6', - n->node->toString().c_str());*/ + std::weak_ptr<Search> ws = sr; + auto onDone = + [this,ws,query](const Request& status, NetworkEngine::RequestAnswer&& answer) mutable { + if (auto sr = ws.lock()) { + sr->insertNode(status.node, scheduler.time(), answer.ntoken); + onGetValuesDone(status, answer, sr, query); + } + }; + auto onExpired = + [this,ws,query](const Request& status, bool over) mutable { + if (auto sr = ws.lock()) { + if (auto srn = sr->getNode(status.node)) { + srn->candidate = not over; + if (over) + srn->getStatus.erase(query); + } + scheduler.edit(sr->nextSearchStep, scheduler.time()); + } + }; + std::shared_ptr<Request> rstatus; + if (sr->callbacks.empty() and sr->listeners.empty()) { + DHT_LOG.WARN("[search %s IPv%c] [node %s] sending 'find_node'", + sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6', + n->node->toString().c_str()); + rstatus = network_engine.sendFindNode(n->node, sr->id, -1, onDone, onExpired); + } else { + DHT_LOG.WARN("[search %s IPv%c] [node %s] sending 'get'", + sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6', + n->node->toString().c_str()); + rstatus = network_engine.sendGetValues(n->node, sr->id, query ? *query : Query {}, -1, onDone, onExpired); + } + n->getStatus[query] = rstatus; - std::weak_ptr<Search> ws = sr; - auto onDone = - [this,ws](const Request& status, NetworkEngine::RequestAnswer&& answer) mutable { - if (auto sr = ws.lock()) { - sr->insertNode(status.node, scheduler.time(), answer.ntoken); - onGetValuesDone(status, answer, sr); - } - }; - auto onExpired = - [this,ws](const Request& status, bool over) mutable { - if (auto sr = ws.lock()) { - if (auto srn = sr->getNode(status.node)) - srn->candidate = not over; - scheduler.edit(sr->nextSearchStep, scheduler.time()); - } - }; - std::shared_ptr<Request> rstatus; - if (sr->callbacks.empty() and sr->listeners.empty()) - rstatus = network_engine.sendFindNode(n->node, sr->id, -1, onDone, onExpired); - else - rstatus = network_engine.sendGetValues(n->node, sr->id, -1, onDone, onExpired); - n->getStatus = rstatus; + if (not sr->isSynced(now) or cb == sr->callbacks.end()) + break; /* only trying to find nodes, only send the oldest query */ + } while (++cb != sr->callbacks.end()); return n; } @@ -777,13 +841,15 @@ Dht::searchStep(std::shared_ptr<Search> sr) if (not sr or sr->expired or sr->done) return; const auto& now = scheduler.time(); - DHT_LOG.DEBUG("[search %s IPv%c] step (%d requests)", sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6', sr->currentGetRequests()); + DHT_LOG.DEBUG("[search %s IPv%c] step (%d requests)", + sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6', sr->currentGetRequests()); sr->step_time = now; if (sr->refill_time + Node::NODE_EXPIRE_TIME < now and sr->nodes.size()-sr->getNumberOfBadNodes() < SEARCH_NODES) { if (auto added = sr->refill(sr->af == AF_INET ? buckets : buckets6, now)) { sr->refill_time = now; - DHT_LOG.DEBUG("[search %s IPv%c] refilled with %u nodes", sr->id.toString().c_str(), (sr->af == AF_INET) ? '4' : '6', added); + DHT_LOG.DEBUG("[search %s IPv%c] refilled with %u nodes", + sr->id.toString().c_str(), (sr->af == AF_INET) ? '4' : '6', added); } } @@ -793,9 +859,11 @@ Dht::searchStep(std::shared_ptr<Search> sr) // search is synced but some (newer) get operations are not complete // Call callbacks when done for (auto b = sr->callbacks.begin(); b != sr->callbacks.end();) { - if (sr->isDone(*b, now)) { - if (b->done_cb) - b->done_cb(true, sr->getNodes()); + if (sr->isDone(b->second, now)) { + if (b->second.done_cb) + b->second.done_cb(true, sr->getNodes()); + for (auto& n : sr->nodes) + n.getStatus.erase(b->second.query); b = sr->callbacks.erase(b); } else @@ -808,43 +876,50 @@ Dht::searchStep(std::shared_ptr<Search> sr) // true if this node is part of the target nodes cluter. bool in = sr->id.xorCmp(myid, sr->nodes.back().node->id) < 0; - DHT_LOG.DEBUG("[search %s IPv%c] synced%s", sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6', in ? ", in" : ""); + DHT_LOG.DEBUG("[search %s IPv%c] synced%s", + sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6', in ? ", in" : ""); if (not sr->listeners.empty()) { unsigned i = 0; for (auto& n : sr->nodes) { if (not n.isSynced(now)) continue; - if (n.getListenTime() <= now) { - DHT_LOG.WARN("[search %s IPv%c] [node %s] sending 'listen'", - sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6', - n.node->toString().c_str()); - //std::cout << "Sending listen to " << n.node->id << " " << print_addr(n.node->ss, n.node->sslen) << std::endl; - - //network_engine.cancelRequest(n.listenStatus); - auto ls = n.listenStatus; - - std::weak_ptr<Search> ws = sr; - n.listenStatus = network_engine.sendListen(n.node, sr->id, n.token, - [this,ws,ls](const Request& status, - NetworkEngine::RequestAnswer&& answer) mutable - { /* on done */ - // cancel previous request - network_engine.cancelRequest(ls); - if (auto sr = ws.lock()) { - onListenDone(status, answer, sr); - searchStep(sr); + for (const auto& l : sr->listeners) { + const auto& query = l.second.query; + if (n.getListenTime(query) <= now) { + DHT_LOG.WARN("[search %s IPv%c] [node %s] sending 'listen'", + sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6', + n.node->toString().c_str()); + //std::cout << "Sending listen to " << n.node->id << " " << print_addr(n.node->ss, n.node->sslen) << std::endl; + + const auto& r = n.listenStatus.find(query); + auto last_req = r != n.listenStatus.end() ? r->second : std::shared_ptr<Request> {}; + + std::weak_ptr<Search> ws = sr; + n.listenStatus[query] = network_engine.sendListen(n.node, sr->id, *query, n.token, + [this,ws,last_req,query](const Request& req, + NetworkEngine::RequestAnswer&& answer) mutable + { /* on done */ + network_engine.cancelRequest(last_req); + if (auto sr = ws.lock()) { + onListenDone(req, answer, sr, query); + searchStep(sr); + if (auto sn = sr->getNode(req.node)) + sn->listenStatus.erase(query); + } + }, + [this,ws,last_req,query](const Request& req, bool over) mutable + { /* on expired */ + network_engine.cancelRequest(last_req); + if (auto sr = ws.lock()) { + searchStep(sr); + if (over) + if (auto sn = sr->getNode(req.node)) + sn->listenStatus.erase(query); + } } - }, - [this,ws,ls](const Request&, bool over) mutable - { /* on expired */ - if (over) { - network_engine.cancelRequest(ls); - if (auto sr = ws.lock()) - scheduler.edit(sr->nextSearchStep, scheduler.time()); - } - } - ); + ); + } } if (not n.candidate and ++i == LISTEN_NODES) break; @@ -928,8 +1003,8 @@ Dht::searchStep(std::shared_ptr<Search> sr) { auto get_cbs = std::move(sr->callbacks); for (const auto& g : get_cbs) { - if (g.done_cb) - g.done_cb(false, {}); + if (g.second.done_cb) + g.second.done_cb(false, {}); } } { @@ -994,7 +1069,7 @@ Dht::Search::getLastGetTime() const { time_point last = time_point::min(); for (const auto& g : callbacks) - last = std::max(last, g.start); + last = std::max(last, g.second.start); return last; } @@ -1004,10 +1079,11 @@ Dht::Search::isDone(const Get& get, time_point now) const unsigned i = 0; const auto limit = std::max(get.start, now - Node::NODE_EXPIRE_TIME); for (const auto& sn : nodes) { + const auto& gs = sn.getStatus.find(get.query); if (sn.isBad()) continue; - if (sn.last_get_reply < limit) - return false; + if (gs == sn.getStatus.end() or not gs->second or gs->second->reply_time < limit) + return false; if (++i == TARGET_NODES) break; } @@ -1024,7 +1100,7 @@ Dht::Search::getUpdateTime(time_point now) const for (const auto& sn : nodes) { if (sn.node->isExpired() or (sn.candidate and t >= TARGET_NODES)) continue; - bool pending = sn.getStatus and sn.getStatus->pending(); + auto pending = sn.pending(sn.getStatus); if (sn.last_get_reply < std::max(now - Node::NODE_EXPIRE_TIME, last_get) or pending) { // not isSynced if (not pending and reqs < SEARCH_REQUESTS) @@ -1072,7 +1148,12 @@ Dht::Search::isListening(time_point now) const for (const auto& n : nodes) { if (n.isBad()) continue; - if (!n.isListening(now)) + SearchNode::SyncStatusMap::const_iterator ls {}; + for (ls = n.listenStatus.begin(); ls != n.listenStatus.end() ; ++ls) { + if (n.isListening(now, ls)) + break; + } + if (ls == n.listenStatus.end()) return false; if (++i == LISTEN_NODES) break; @@ -1203,12 +1284,12 @@ Dht::Search::refill(const RoutingTable& r, time_point now) { /* Start a search. */ std::shared_ptr<Dht::Search> -Dht::search(const InfoHash& id, sa_family_t af, GetCallback callback, DoneCallback done_callback, Value::Filter filter) +Dht::search(const InfoHash& id, sa_family_t af, GetCallback gcb, QueryCallback qcb, DoneCallback dcb, Value::Filter f, Query q) { if (!isRunning(af)) { DHT_LOG.ERR("[search %s IPv%c] unsupported protocol", id.toString().c_str(), (af == AF_INET) ? '4' : '6'); - if (done_callback) - done_callback(false, {}); + if (dcb) + dcb(false, {}); return {}; } @@ -1248,14 +1329,22 @@ Dht::search(const InfoHash& id, sa_family_t af, GetCallback callback, DoneCallba search_id++; } - if (callback) - sr->callbacks.push_back({/*.start=*/scheduler.time(), /*.filter=*/filter, /*.get_cb=*/callback, /*.done_cb=*/done_callback}); - bootstrapSearch(*sr); + if (gcb or qcb) { + auto now = scheduler.time(); + sr->callbacks.insert(std::make_pair<time_point, Get>( + std::move(now), + Get { scheduler.time(), f, std::make_shared<Query>(q), + qcb ? qcb : QueryCallback {}, gcb ? gcb : GetCallback {}, dcb + } + )); + } + bootstrapSearch(*sr); if (sr->nextSearchStep) scheduler.edit(sr->nextSearchStep, sr->getNextStepTime(types, scheduler.time())); else sr->nextSearchStep = scheduler.add(scheduler.time(), std::bind(&Dht::searchStep, this, sr)); + return sr; } @@ -1271,7 +1360,7 @@ Dht::announce(const InfoHash& id, sa_family_t af, std::shared_ptr<Value> value, } auto& srs = af == AF_INET ? searches4 : searches6; auto srp = srs.find(id); - auto sr = srp == srs.end() ? search(id, af, nullptr, nullptr) : srp->second; + auto sr = srp == srs.end() ? search(id, af) : srp->second; if (!sr) { if (callback) callback(false, {}); @@ -1317,7 +1406,7 @@ Dht::announce(const InfoHash& id, sa_family_t af, std::shared_ptr<Value> value, } size_t -Dht::listenTo(const InfoHash& id, sa_family_t af, GetCallback cb, Value::Filter f) +Dht::listenTo(const InfoHash& id, sa_family_t af, GetCallback cb, Value::Filter f, const std::shared_ptr<Query>& q) { const auto& now = scheduler.time(); if (!isRunning(af)) @@ -1327,22 +1416,23 @@ Dht::listenTo(const InfoHash& id, sa_family_t af, GetCallback cb, Value::Filter //DHT_LOG.WARN("listenTo %s", id.toString().c_str()); auto& srs = af == AF_INET ? searches4 : searches6; auto srp = srs.find(id); - std::shared_ptr<Search> sr = (srp == srs.end()) ? search(id, af, nullptr, nullptr) : srp->second; + std::shared_ptr<Search> sr = (srp == srs.end()) ? search(id, af) : srp->second; if (!sr) throw DhtException("Can't create search"); DHT_LOG.ERR("[search %s IPv%c] listen", id.toString().c_str(), (af == AF_INET) ? '4' : '6'); sr->done = false; auto token = ++sr->listener_token; - sr->listeners.emplace(token, LocalListener{f, cb}); + sr->listeners.emplace(token, LocalListener{q, f, cb}); scheduler.edit(sr->nextSearchStep, sr->getNextStepTime(types, now)); return token; } size_t -Dht::listen(const InfoHash& id, GetCallback cb, Value::Filter&& f) +Dht::listen(const InfoHash& id, GetCallback cb, Value::Filter&& f, Where&& where) { scheduler.syncTime(); + Query q {{}, where}; auto vals = std::make_shared<std::map<Value::Id, std::shared_ptr<Value>>>(); auto token = ++listener_token; @@ -1367,6 +1457,8 @@ Dht::listen(const InfoHash& id, GetCallback cb, Value::Filter&& f) return true; }; + auto query = std::make_shared<Query>(q); + auto filter = f.chain(q.where.getFilter()); auto st = findStorage(id); size_t tokenlocal = 0; if (st == store.end() && store.size() < MAX_HASHES) { @@ -1375,7 +1467,7 @@ Dht::listen(const InfoHash& id, GetCallback cb, Value::Filter&& f) } if (st != store.end()) { if (not st->empty()) { - std::vector<std::shared_ptr<Value>> newvals = st->get(f); + std::vector<std::shared_ptr<Value>> newvals = st->get(filter); if (not newvals.empty()) { if (!cb(newvals)) return 0; @@ -1387,11 +1479,11 @@ Dht::listen(const InfoHash& id, GetCallback cb, Value::Filter&& f) } } tokenlocal = ++st->listener_token; - st->local_listeners.emplace(tokenlocal, LocalListener{f, gcb}); + st->local_listeners.emplace(tokenlocal, LocalListener{query, filter, gcb}); } - auto token4 = Dht::listenTo(id, AF_INET, gcb, f); - auto token6 = Dht::listenTo(id, AF_INET6, gcb, f); + auto token4 = Dht::listenTo(id, AF_INET, gcb, filter, query); + auto token6 = Dht::listenTo(id, AF_INET6, gcb, filter, query); DHT_LOG.DEBUG("Added listen : %d -> %d %d %d", token, tokenlocal, token4, token6); listeners.emplace(token, std::make_tuple(tokenlocal, token4, token6)); @@ -1424,9 +1516,11 @@ Dht::cancelListen(const InfoHash& id, size_t token) s->listeners.erase(af_token); if (s->listeners.empty()) { for (auto& sn : s->nodes) { - // also erase requests for all searchnodes. - network_engine.cancelRequest(sn.listenStatus); - sn.listenStatus.reset(); + /* also erase requests for all searchnodes. */ + for (auto& ls : sn.listenStatus) { + network_engine.cancelRequest(ls.second); + } + sn.listenStatus.clear(); } } } @@ -1475,69 +1569,138 @@ Dht::put(const InfoHash& id, std::shared_ptr<Value> val, DoneCallback callback, }, created, permanent); } +template <typename T> struct OpStatus { - bool done {false}; - bool ok {false}; + struct Status { + bool done {false}; + bool ok {false}; + }; + Status status; + Status status4; + Status status6; + std::vector<std::shared_ptr<T>> values; + std::vector<std::shared_ptr<Node>> nodes; +}; + +template <typename T> +void doneCallbackWrapper(DoneCallback dcb, const std::vector<std::shared_ptr<Node>>& nodes, std::shared_ptr<OpStatus<T>> op) { + if (op->status.done) + return; + op->nodes.insert(op->nodes.end(), nodes.begin(), nodes.end()); + if (op->status.ok || (op->status4.done and op->status6.done)) { + bool ok = op->status.ok || op->status4.ok || op->status6.ok; + op->status.done = true; + if (dcb) + dcb(ok, op->nodes); + } +}; + +template <typename T, typename Cb> +bool callbackWrapper(Cb get_cb, + DoneCallback done_cb, + const std::vector<std::shared_ptr<T>>& values, + std::function<std::vector<std::shared_ptr<T>>(const std::vector<std::shared_ptr<T>>&)> add_values, + std::shared_ptr<OpStatus<T>> op) +{ + if (op->status.done) + return false; + auto newvals = add_values(values); + if (not newvals.empty()) { + op->status.ok = !get_cb(newvals); + op->values.insert(op->values.end(), newvals.begin(), newvals.end()); + } + doneCallbackWrapper(done_cb, {}, op); + return !op->status.ok; }; void -Dht::get(const InfoHash& id, GetCallback getcb, DoneCallback donecb, Value::Filter&& filter) +Dht::get(const InfoHash& id, GetCallback getcb, DoneCallback donecb, Value::Filter&& filter, Where&& where) { scheduler.syncTime(); - auto status = std::make_shared<OpStatus>(); - auto status4 = std::make_shared<OpStatus>(); - auto status6 = std::make_shared<OpStatus>(); - auto vals = std::make_shared<std::vector<std::shared_ptr<Value>>>(); - auto all_nodes = std::make_shared<std::vector<std::shared_ptr<Node>>>(); + Query q {{}, where}; + auto op = std::make_shared<OpStatus<Value>>(); - auto done_l = [=](const std::vector<std::shared_ptr<Node>>& nodes) { - if (status->done) - return; - all_nodes->insert(all_nodes->end(), nodes.begin(), nodes.end()); - if (status->ok || (status4->done && status6->done)) { - bool ok = status->ok || status4->ok || status6->ok; - status->done = true; - if (donecb) - donecb(ok, *all_nodes); - } - }; - auto cb = [=](const std::vector<std::shared_ptr<Value>>& values) { - if (status->done) - return false; + auto f = filter.chain(q.where.getFilter()); + auto add_values = [=](const std::vector<std::shared_ptr<Value>>& values) { std::vector<std::shared_ptr<Value>> newvals {}; for (const auto& v : values) { - auto it = std::find_if(vals->cbegin(), vals->cend(), [&](const std::shared_ptr<Value>& sv) { - return sv == v || *sv == *v; + auto it = std::find_if(op->values.cbegin(), op->values.cend(), [&](const std::shared_ptr<Value>& sv) { + return sv == v or *sv == *v; }); - if (it == vals->cend()) { - if (!filter || filter(*v)) - newvals.push_back(v); + if (it == op->values.cend()) { + if (not f or f(*v)) + newvals.push_back(v); } } - if (!newvals.empty()) { - status->ok = !getcb(newvals); - vals->insert(vals->end(), newvals.begin(), newvals.end()); - } - done_l({}); - return !status->ok; + return newvals; }; + auto gcb = std::bind(callbackWrapper<Value, GetCallback>, getcb, donecb, _1, add_values, op); /* Try to answer this search locally. */ - cb(getLocal(id, filter)); + gcb(getLocal(id, f)); - Dht::search(id, AF_INET, cb, [=](bool ok, const std::vector<std::shared_ptr<Node>>& nodes) { + Dht::search(id, AF_INET, gcb, {}, [=](bool ok, const std::vector<std::shared_ptr<Node>>& nodes) { //DHT_LOG.WARN("DHT done IPv4"); - status4->done = true; - status4->ok = ok; - done_l(nodes); - }); - Dht::search(id, AF_INET6, cb, [=](bool ok, const std::vector<std::shared_ptr<Node>>& nodes) { + op->status4.done = true; + op->status4.ok = ok; + doneCallbackWrapper(donecb, nodes, op); + }, f, q); + Dht::search(id, AF_INET6, gcb, {}, [=](bool ok, const std::vector<std::shared_ptr<Node>>& nodes) { //DHT_LOG.WARN("DHT done IPv6"); - status6->done = true; - status6->ok = ok; - done_l(nodes); + op->status6.done = true; + op->status6.ok = ok; + doneCallbackWrapper(donecb, nodes, op); + }, f, q); +} + +void Dht::query(const InfoHash& id, QueryCallback cb, DoneCallback done_cb, Query&& q) +{ + scheduler.syncTime(); + auto op = std::make_shared<OpStatus<FieldValueIndex>>(); + + auto f = q.where.getFilter(); + auto values = getLocal(id, f); + auto add_fields = [=](const std::vector<std::shared_ptr<FieldValueIndex>>& fields) { + std::vector<std::shared_ptr<FieldValueIndex>> newvals {}; + for (const auto& f : fields) { + auto it = std::find_if(op->values.cbegin(), op->values.cend(), + [&](const std::shared_ptr<FieldValueIndex>& sf) { + return sf == f or f->containedIn(*sf); + }); + if (it == op->values.cend()) { + auto lesser = std::find_if(op->values.begin(), op->values.end(), + [&](const std::shared_ptr<FieldValueIndex>& sf) { + return sf->containedIn(*f); + }); + if (lesser != op->values.end()) + op->values.erase(lesser); + newvals.push_back(f); + } + } + return newvals; + }; + std::vector<std::shared_ptr<FieldValueIndex>> local_fields(values.size()); + std::transform(values.begin(), values.end(), local_fields.begin(), [](const std::shared_ptr<Value>& v) { + return std::make_shared<FieldValueIndex>(*v); }); + auto qcb = std::bind(callbackWrapper<FieldValueIndex, QueryCallback>, cb, done_cb, _1, add_fields, op); + + /* Try to answer this search locally. */ + qcb(local_fields); + + Dht::search(id, AF_INET, {}, qcb, [=](bool ok, const std::vector<std::shared_ptr<Node>>& nodes) { + //DHT_LOG.WARN("DHT done IPv4"); + op->status4.done = true; + op->status4.ok = ok; + doneCallbackWrapper(done_cb, nodes, op); + }, f, q); + Dht::search(id, AF_INET6, {}, qcb, [=](bool ok, const std::vector<std::shared_ptr<Node>>& nodes) { + //DHT_LOG.WARN("DHT done IPv6"); + op->status6.done = true; + op->status6.ok = ok; + doneCallbackWrapper(done_cb, nodes, op); + }, f, q); } std::vector<std::shared_ptr<Value>> @@ -1658,10 +1821,14 @@ Dht::storageChanged(Storage& st, ValueStorage& v) for (const auto& l : st.listeners) { DHT_LOG.DEBUG("Storage changed. Sending update to %s.", l.first->toString().c_str()); - std::vector<std::shared_ptr<Value>> vals; + auto f = l.second.query.where.getFilter(); + if (f and not f(*v.data)) + continue; + std::vector<std::shared_ptr<Value>> vals {}; vals.push_back(v.data); Blob ntoken = makeToken((const sockaddr*)&l.first->ss, false); - network_engine.tellListener(l.first, l.second.rid, st.id, 0, ntoken, {}, {}, vals); + network_engine.tellListener(l.first, l.second.rid, st.id, 0, ntoken, {}, {}, + std::move(vals), l.second.query); } } @@ -1724,7 +1891,7 @@ Dht::Storage::clear() } void -Dht::storageAddListener(const InfoHash& id, const std::shared_ptr<Node>& node, size_t rid) +Dht::storageAddListener(const InfoHash& id, const std::shared_ptr<Node>& node, size_t rid, Query&& query) { const auto& now = scheduler.time(); auto st = findStorage(id); @@ -1736,16 +1903,13 @@ Dht::storageAddListener(const InfoHash& id, const std::shared_ptr<Node>& node, s } auto l = st->listeners.find(node); if (l == st->listeners.end()) { - const auto& stvalues = st->getValues(); - if (not stvalues.empty()) { - std::vector<std::shared_ptr<Value>> values(stvalues.size()); - std::transform(stvalues.begin(), stvalues.end(), values.begin(), [=](const ValueStorage& vs) { return vs.data; }); - + auto vals = st->get(query.where.getFilter()); + if (not vals.empty()) { network_engine.tellListener(node, rid, id, WANT4 | WANT6, makeToken((sockaddr*)&node->ss, false), buckets.findClosestNodes(id, now, TARGET_NODES), buckets6.findClosestNodes(id, now, TARGET_NODES), - values); + std::move(vals), query); } - st->listeners.emplace(node, Listener {rid, now}); + st->listeners.emplace(node, Listener {rid, now, std::forward<Query>(query)}); } else l->second.refresh(rid, now); @@ -1813,7 +1977,7 @@ Dht::connectivityChanged() auto stop_listen = [&](std::map<InfoHash, std::shared_ptr<Search>> srs) { for (auto& sp : srs) for (auto& sn : sp.second->nodes) - sn.listenStatus.reset(); + sn.listenStatus.clear(); }; stop_listen(searches4); stop_listen(searches6); @@ -1957,6 +2121,16 @@ Dht::dumpSearch(const Search& sr, std::ostream& out) const } out << std::endl; + /*printing the queries*/ + if (sr.callbacks.size() + sr.listeners.size() > 0) + out << "Queries:" << std::endl; + for (const auto& cb : sr.callbacks) { + out << *cb.second.query << std::endl; + } + for (const auto& l : sr.listeners) { + out << *l.second.query << std::endl; + } + for (const auto& n : sr.announce) { bool announced = sr.isAnnounced(n.value->id, getType(n.value->type), now); out << "Announcement: " << *n.value << (announced ? " [announced]" : "") << std::endl; @@ -1978,18 +2152,18 @@ Dht::dumpSearch(const Search& sr, std::ostream& out) const // Get status { - char g_i = (n.getStatus && n.getStatus->pending()) ? (n.candidate ? 'c' : 'f') : ' '; + char g_i = n.pending(n.getStatus) ? (n.candidate ? 'c' : 'f') : ' '; char s_i = n.isSynced(now) ? (n.last_get_reply > last_get ? 'u' : 's') : '-'; out << " [" << s_i << g_i << "] "; } // Listen status if (not sr.listeners.empty()) { - if (not n.listenStatus) + if (n.listenStatus.empty()) out << " "; else out << "[" - << (n.isListening(now) ? 'l' : (n.listenStatus->pending() ? 'f' : ' ')) << "] "; + << (n.isListening(now) ? 'l' : (n.pending(n.listenStatus) ? 'f' : ' ')) << "] "; } // Announce status @@ -2108,8 +2282,8 @@ Dht::Dht(int s, int s6, Config config) std::bind(&Dht::onReportedAddr, this, _1, _2, _3), std::bind(&Dht::onPing, this, _1), std::bind(&Dht::onFindNode, this, _1, _2, _3), - std::bind(&Dht::onGetValues, this, _1, _2, _3), - std::bind(&Dht::onListen, this, _1, _2, _3, _4), + std::bind(&Dht::onGetValues, this, _1, _2, _3, _4), + std::bind(&Dht::onListen, this, _1, _2, _3, _4, _5), std::bind(&Dht::onAnnounce, this, _1, _2, _3, _4, _5)) { scheduler.syncTime(); @@ -2539,7 +2713,7 @@ Dht::onFindNode(std::shared_ptr<Node> node, InfoHash& target, want_t want) } NetworkEngine::RequestAnswer -Dht::onGetValues(std::shared_ptr<Node> node, InfoHash& hash, want_t) +Dht::onGetValues(std::shared_ptr<Node> node, InfoHash& hash, want_t, const Query& query) { if (hash == zeroes) { DHT_LOG.WARN("[node %s] Eek! Got get_values with no info_hash.", node->toString().c_str()); @@ -2552,11 +2726,7 @@ Dht::onGetValues(std::shared_ptr<Node> node, InfoHash& hash, want_t) answer.nodes4 = buckets.findClosestNodes(hash, now, TARGET_NODES); answer.nodes6 = buckets6.findClosestNodes(hash, now, TARGET_NODES); if (st != store.end() && not st->empty()) { - auto values = st->getValues(); - answer.values.resize(values.size()); - std::transform(values.begin(), values.end(), answer.values.begin(), [](const ValueStorage& vs) { - return vs.data; - }); + answer.values = st->get(query.where.getFilter()); DHT_LOG.DEBUG("[node %s] sending %u values.", node->toString().c_str(), answer.values.size()); } else { DHT_LOG.DEBUG("[node %s] sending nodes.", node->toString().c_str()); @@ -2566,34 +2736,52 @@ Dht::onGetValues(std::shared_ptr<Node> node, InfoHash& hash, want_t) void Dht::onGetValuesDone(const Request& status, - NetworkEngine::RequestAnswer& a, std::shared_ptr<Search> sr) + NetworkEngine::RequestAnswer& a, std::shared_ptr<Search> sr, const std::shared_ptr<Query>& orig_query) { if (not sr) { DHT_LOG.WARN("[search unknown] got reply to 'get'. Ignoring."); return; } - DHT_LOG.DEBUG("[search %s IPv%c] got reply to 'get' from %s with %u nodes", sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6', status.node->toString().c_str(), a.nodes4.size()); + DHT_LOG.DEBUG("[search %s IPv%c] got reply to 'get' from %s with %u nodes", + sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6', status.node->toString().c_str(), a.nodes4.size()); if (not a.ntoken.empty()) { - if (!a.values.empty()) { + if (not a.values.empty() or not a.fields.empty()) { DHT_LOG.DEBUG("[search %s IPv%c] found %u values", sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6', a.values.size()); - for (auto& cb : sr->callbacks) { - if (!cb.get_cb) continue; - std::vector<std::shared_ptr<Value>> tmp; - std::copy_if(a.values.begin(), a.values.end(), std::back_inserter(tmp), - [&](const std::shared_ptr<Value>& v) { - return not static_cast<bool>(cb.filter) or cb.filter(*v); + for (auto& getp : sr->callbacks) { + auto& get = getp.second; + if (not (get.get_cb or get.query_cb) or + (orig_query and get.query and not get.query->isSatisfiedBy(*orig_query))) + continue; + if (get.query_cb) { + if (not a.fields.empty()) { + get.query_cb(a.fields); + } else if (not a.values.empty()) { + std::vector<std::shared_ptr<FieldValueIndex>> fields(a.values.size()); + std::transform(a.values.begin(), a.values.end(), fields.begin(), + [&](const std::shared_ptr<Value>& v) { + return std::make_shared<FieldValueIndex>(*v, orig_query ? orig_query->select : Select {}); + }); + get.query_cb(fields); } - ); - if (not tmp.empty()) - cb.get_cb(tmp); + } else if (get.get_cb) { + std::vector<std::shared_ptr<Value>> tmp; + std::copy_if(a.values.begin(), a.values.end(), std::back_inserter(tmp), + [&](const std::shared_ptr<Value>& v) { + return not static_cast<bool>(get.filter) or get.filter(*v); + } + ); + if (not tmp.empty()) + get.get_cb(tmp); + } } std::vector<std::pair<GetCallback, std::vector<std::shared_ptr<Value>>>> tmp_lists; for (auto& l : sr->listeners) { - if (!l.second.get_cb) continue; + if (!l.second.get_cb or (orig_query and l.second.query and not l.second.query->isSatisfiedBy(*orig_query))) + continue; std::vector<std::shared_ptr<Value>> tmp; std::copy_if(a.values.begin(), a.values.end(), std::back_inserter(tmp), [&](const std::shared_ptr<Value>& v) { @@ -2620,7 +2808,7 @@ Dht::onGetValuesDone(const Request& status, } NetworkEngine::RequestAnswer -Dht::onListen(std::shared_ptr<Node> node, InfoHash& hash, Blob& token, size_t rid) +Dht::onListen(std::shared_ptr<Node> node, InfoHash& hash, Blob& token, size_t rid, Query&& query) { if (hash == zeroes) { DHT_LOG.WARN("Listen with no info_hash."); @@ -2633,18 +2821,18 @@ Dht::onListen(std::shared_ptr<Node> node, InfoHash& hash, Blob& token, size_t ri DHT_LOG.WARN("[node %s] incorrect token %s for 'listen'.", node->toString().c_str(), hash.toString().c_str()); throw DhtProtocolException {DhtProtocolException::UNAUTHORIZED, DhtProtocolException::LISTEN_WRONG_TOKEN}; } - storageAddListener(hash, node, rid); + storageAddListener(hash, node, rid, std::forward<Query>(query)); return {}; } void -Dht::onListenDone(const Request& status, NetworkEngine::RequestAnswer& answer, std::shared_ptr<Search>& sr) +Dht::onListenDone(const Request& status, NetworkEngine::RequestAnswer& answer, std::shared_ptr<Search>& sr, const std::shared_ptr<Query>& orig_query) { DHT_LOG.DEBUG("[search %s] Got reply to listen.", sr->id.toString().c_str()); if (sr) { if (not answer.values.empty()) { /* got new values from listen request */ DHT_LOG.DEBUG("[listen %s] Got new values.", sr->id.toString().c_str()); - onGetValuesDone(status, answer, sr); + onGetValuesDone(status, answer, sr, orig_query); } if (not sr->done) { diff --git a/src/network_engine.cpp b/src/network_engine.cpp index 159dc5c5cd51b5c940e68f542c7d504aff354201..56be1a69d0edf716a20a44eec732801ac0075432 100644 --- a/src/network_engine.cpp +++ b/src/network_engine.cpp @@ -21,6 +21,7 @@ #include "network_engine.h" #include "request.h" +#include "default_types.h" #include <msgpack.hpp> @@ -63,36 +64,40 @@ enum class MessageType { struct ParsedMessage { MessageType type; - InfoHash id; /* the id of the sender */ - NetId network {0}; /* network id */ - InfoHash info_hash; /* hash for which values are requested */ - InfoHash target; /* target id around which to find nodes */ - NetworkEngine::TransId tid; /* transaction id */ - Blob token; /* security token */ - Value::Id value_id; /* the value id */ - time_point created { time_point::max() }; /* time when value was first created */ - Blob nodes4_raw, nodes6_raw; /* IPv4 nodes in response to a 'find' request */ + InfoHash id; /* the id of the sender */ + NetId network {0}; /* network id */ + InfoHash info_hash; /* hash for which values are requested */ + InfoHash target; /* target id around which to find nodes */ + NetworkEngine::TransId tid; /* transaction id */ + Blob token; /* security token */ + Value::Id value_id; /* the value id */ + time_point created { time_point::max() }; /* time when value was first created */ + Blob nodes4_raw, nodes6_raw; /* IPv4 nodes in response to a 'find' request */ std::vector<std::shared_ptr<Node>> nodes4, nodes6; - std::vector<std::shared_ptr<Value>> values; /* values for a 'get' request */ - want_t want; /* states if ipv4 or ipv6 request */ - uint16_t error_code; /* error code in case of error */ + std::vector<std::shared_ptr<Value>> values; /* values for a 'get' request */ + std::vector<std::shared_ptr<FieldValueIndex>> fields; /* index for fields values */ + Query query; /* query describing a filter to apply on values. */ + want_t want; /* states if ipv4 or ipv6 request */ + uint16_t error_code; /* error code in case of error */ std::string ua; - Address addr; /* reported address by the distant node */ + Address addr; /* reported address by the distant node */ void msgpack_unpack(msgpack::object o); }; NetworkEngine::RequestAnswer::RequestAnswer(ParsedMessage&& msg) - : ntoken(std::move(msg.token)), values(std::move(msg.values)), nodes4(std::move(msg.nodes4)), nodes6(std::move(msg.nodes6)) {} + : ntoken(std::move(msg.token)), values(std::move(msg.values)), fields(std::move(msg.fields)), + nodes4(std::move(msg.nodes4)), nodes6(std::move(msg.nodes6)) {} void -NetworkEngine::tellListener(std::shared_ptr<Node> node, uint16_t rid, InfoHash hash, want_t want, - Blob ntoken, std::vector<std::shared_ptr<Node>> nodes, std::vector<std::shared_ptr<Node>> nodes6, - std::vector<std::shared_ptr<Value>> values) +NetworkEngine::tellListener(std::shared_ptr<Node> node, uint16_t rid, const InfoHash& hash, want_t want, + const Blob& ntoken, std::vector<std::shared_ptr<Node>>&& nodes, + std::vector<std::shared_ptr<Node>>&& nodes6, std::vector<std::shared_ptr<Value>>&& values, + const Query& query) { auto nnodes = bufferNodes(node->getFamily(), hash, want, nodes, nodes6); try { sendNodesValues((const sockaddr*)&node->ss, node->sslen, TransId {TransPrefix::GET_VALUES, (uint16_t)rid}, nnodes.first, nnodes.second, - values, ntoken); + values, query, ntoken); } catch (const std::overflow_error& e) { DHT_LOG.ERR("Can't send value: buffer not large enough !"); } @@ -380,7 +385,7 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const sockaddr* requests.erase(reqp); req->reply_time = scheduler.time(); - deserializeNodesValues(msg); + deserializeNodes(msg); req->setDone(std::move(msg)); break; default: @@ -404,16 +409,16 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const sockaddr* ++in_stats.find; RequestAnswer answer = onFindNode(node, msg.target, msg.want); auto nnodes = bufferNodes(from->sa_family, msg.target, msg.want, answer.nodes4, answer.nodes6); - sendNodesValues(from, fromlen, msg.tid, nnodes.first, nnodes.second, {}, answer.ntoken); + sendNodesValues(from, fromlen, msg.tid, nnodes.first, nnodes.second, {}, {}, answer.ntoken); break; } case MessageType::GetValues: { DHT_LOG.DEBUG("[node %s %s] got 'get' request for %s.", msg.id.toString().c_str(), print_addr(from, fromlen).c_str(), msg.info_hash.toString().c_str()); ++in_stats.get; - RequestAnswer answer = onGetValues(node, msg.info_hash, msg.want); + RequestAnswer answer = onGetValues(node, msg.info_hash, msg.want, msg.query); auto nnodes = bufferNodes(from->sa_family, msg.info_hash, msg.want, answer.nodes4, answer.nodes6); - sendNodesValues(from, fromlen, msg.tid, nnodes.first, nnodes.second, answer.values, answer.ntoken); + sendNodesValues(from, fromlen, msg.tid, nnodes.first, nnodes.second, answer.values, msg.query, answer.ntoken); break; } case MessageType::AnnounceValue: { @@ -435,7 +440,7 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const sockaddr* DHT_LOG.DEBUG("[node %s %s] got 'listen' request for %s.", msg.id.toString().c_str(), print_addr(from, fromlen).c_str(), msg.info_hash.toString().c_str()); ++in_stats.listen; - RequestAnswer answer = onListen(node, msg.info_hash, msg.token, msg.tid.getTid()); + RequestAnswer answer = onListen(node, msg.info_hash, msg.token, msg.tid.getTid(), std::move(msg.query)); sendListenConfirmation(from, fromlen, msg.tid); break; } @@ -594,16 +599,19 @@ NetworkEngine::sendFindNode(std::shared_ptr<Node> n, const InfoHash& target, wan std::shared_ptr<Request> -NetworkEngine::sendGetValues(std::shared_ptr<Node> n, const InfoHash& info_hash, want_t want, +NetworkEngine::sendGetValues(std::shared_ptr<Node> n, const InfoHash& info_hash, const Query& query, want_t want, RequestCb on_done, RequestExpiredCb on_expired) { auto tid = TransId {TransPrefix::GET_VALUES, getNewTid()}; msgpack::sbuffer buffer; msgpack::packer<msgpack::sbuffer> pk(&buffer); pk.pack_map(5+(network?1:0)); - pk.pack(std::string("a")); pk.pack_map(2 + (want>0?1:0)); + pk.pack(std::string("a")); pk.pack_map(2 + + (query.where.getFilter() or not query.select.getSelection().empty() ? 1:0) + + (want>0?1:0)); pk.pack(std::string("id")); pk.pack(myid); pk.pack(std::string("h")); pk.pack(info_hash); + pk.pack(std::string("q")); pk.pack(query); if (want > 0) { pk.pack(std::string("w")); pk.pack_array(((want & WANT4)?1:0) + ((want & WANT6)?1:0)); @@ -639,7 +647,7 @@ NetworkEngine::sendGetValues(std::shared_ptr<Node> n, const InfoHash& info_hash, } void -NetworkEngine::deserializeNodesValues(ParsedMessage& msg) { +NetworkEngine::deserializeNodes(ParsedMessage& msg) { if (msg.nodes4_raw.size() % NODE4_INFO_BUF_LEN != 0 || msg.nodes6_raw.size() % NODE6_INFO_BUF_LEN != 0) { throw DhtProtocolException {DhtProtocolException::WRONG_NODE_INFO_BUF_LEN}; } else { @@ -680,7 +688,7 @@ NetworkEngine::deserializeNodesValues(ParsedMessage& msg) { void NetworkEngine::sendNodesValues(const sockaddr* sa, socklen_t salen, TransId tid, const Blob& nodes, const Blob& nodes6, - const std::vector<std::shared_ptr<Value>>& st, const Blob& token) { + const std::vector<std::shared_ptr<Value>>& st, const Query& query, const Blob& token) { msgpack::sbuffer buffer; msgpack::packer<msgpack::sbuffer> pk(&buffer); pk.pack_map(4+(network?1:0)); @@ -702,34 +710,50 @@ NetworkEngine::sendNodesValues(const sockaddr* sa, socklen_t salen, TransId tid, if (not token.empty()) { pk.pack(std::string("token")); packToken(pk, token); } - if (not st.empty()) { - // We treat the storage as a circular list, and serve a randomly - // chosen slice. In order to make sure we fit, - // we limit ourselves to 50 values. - std::uniform_int_distribution<> pos_dis(0, st.size()-1); - std::vector<Blob> subset {}; - subset.reserve(std::min<size_t>(st.size(), 50)); - + if (not st.empty()) { /* pack complete values */ + auto fields = query.select.getSelection(); size_t total_size = 0; - unsigned j0 = pos_dis(rd_device); - unsigned j = j0; - unsigned k = 0; - - do { - subset.emplace_back(packMsg(st[j])); - total_size += subset.back().size(); - ++k; - j = (j + 1) % st.size(); - } while (j != j0 && k < 50 && total_size < MAX_VALUE_SIZE); - - pk.pack(std::string("values")); - pk.pack_array(subset.size()); - for (const auto& b : subset) - buffer.write((const char*)b.data(), b.size()); - DHT_LOG.DEBUG("sending closest nodes (%d+%d nodes.), %lu bytes of values", nodes.size(), nodes6.size(), total_size); + if (fields.empty()) { + // We treat the storage as a circular list, and serve a randomly + // chosen slice. In order to make sure we fit, + // we limit ourselves to 50 values. + std::uniform_int_distribution<> pos_dis(0, st.size()-1); + std::vector<Blob> subset {}; + subset.reserve(std::min<size_t>(st.size(), 50)); + + unsigned j0 = pos_dis(rd_device); + unsigned j = j0; + unsigned k = 0; + + do { + subset.emplace_back(packMsg(st[j])); + total_size += subset.back().size(); + ++k; + j = (j + 1) % st.size(); + } while (j != j0 && k < 50 && total_size < MAX_VALUE_SIZE); + + pk.pack(std::string("values")); + pk.pack_array(subset.size()); + for (const auto& b : subset) + buffer.write((const char*)b.data(), b.size()); + DHT_LOG.DEBUG("sending closest nodes (%d+%d nodes.), %lu bytes of values", + nodes.size(), nodes6.size(), total_size); + } else { /* pack fields */ + pk.pack(std::string("fields")); + pk.pack_map(2); + pk.pack(std::string("f")); pk.pack(fields); + pk.pack(std::string("v")); pk.pack_array(st.size()*fields.size()); + for (const auto& v : st) { + v->msgpack_pack_fields(fields, pk); + } + DHT_LOG.DEBUG("sending closest nodes (%d+%d nodes.), %u value headers containing %u fields", + nodes.size(), nodes6.size(), st.size(), fields.size()); + } } else DHT_LOG.DEBUG("sending closest nodes (%d+%d nodes.)", nodes.size(), nodes6.size()); + DHT_LOG.DEBUG("sending closest nodes (%d+%d nodes.)", nodes.size(), nodes6.size()); + pk.pack(std::string("t")); pk.pack_bin(tid.size()); pk.pack_bin_body((const char*)tid.data(), tid.size()); pk.pack(std::string("y")); pk.pack(std::string("r")); @@ -794,16 +818,18 @@ NetworkEngine::bufferNodes(sa_family_t af, const InfoHash& id, want_t want, } std::shared_ptr<Request> -NetworkEngine::sendListen(std::shared_ptr<Node> n, const InfoHash& infohash, const Blob& token, +NetworkEngine::sendListen(std::shared_ptr<Node> n, const InfoHash& infohash, const Query& query, const Blob& token, RequestCb on_done, RequestExpiredCb on_expired) { auto tid = TransId {TransPrefix::LISTEN, getNewTid()}; msgpack::sbuffer buffer; msgpack::packer<msgpack::sbuffer> pk(&buffer); pk.pack_map(5+(network?1:0)); - pk.pack(std::string("a")); pk.pack_map(3); + pk.pack(std::string("a")); pk.pack_map(3 + + (query.where.getFilter() or not query.select.getSelection().empty() ? 1:0)); pk.pack(std::string("id")); pk.pack(myid); pk.pack(std::string("h")); pk.pack(infohash); + pk.pack(std::string("q")); pk.pack(query); pk.pack(std::string("token")); packToken(pk, token); pk.pack(std::string("q")); pk.pack(std::string("listen")); @@ -961,17 +987,36 @@ void ParsedMessage::msgpack_unpack(msgpack::object msg) { auto y = findMapValue(msg, "y"); - auto a = findMapValue(msg, "a"); auto r = findMapValue(msg, "r"); auto e = findMapValue(msg, "e"); - std::string query; - if (auto q = findMapValue(msg, "q")) { - if (q->type != msgpack::type::STR) + std::string q; + if (auto rq = findMapValue(msg, "q")) { + if (rq->type != msgpack::type::STR) throw msgpack::type_error(); - query = q->as<std::string>(); + q = rq->as<std::string>(); } + if (e) + type = MessageType::Error; + else if (r) + type = MessageType::Reply; + else if (y and y->as<std::string>() != "q") + throw msgpack::type_error(); + else if (q == "ping") + type = MessageType::Ping; + else if (q == "find") + type = MessageType::FindNode; + else if (q == "get") + type = MessageType::GetValues; + else if (q == "listen") + type = MessageType::Listen; + else if (q == "put") + type = MessageType::AnnounceValue; + else + throw msgpack::type_error(); + + auto a = findMapValue(msg, "a"); if (!a && !r && !e) throw msgpack::type_error(); auto& req = a ? *a : (r ? *r : *e); @@ -994,6 +1039,9 @@ ParsedMessage::msgpack_unpack(msgpack::object msg) if (auto rtarget = findMapValue(req, "target")) target = {*rtarget}; + if (auto rquery = findMapValue(req, "q")) + query.msgpack_unpack(*rquery); + if (auto otoken = findMapValue(req, "token")) token = unpackBlob(*otoken); @@ -1040,6 +1088,23 @@ ParsedMessage::msgpack_unpack(msgpack::object msg) } catch (const std::exception& e) { //DHT_LOG.WARN("Error reading value: %s", e.what()); } + } else if (auto raw_fields = findMapValue(req, "fields")) { + if (auto rfields = findMapValue(*raw_fields, "f")) { + auto fields_ = rfields->as<std::set<Value::Field>>(); + if (auto rvalues = findMapValue(*raw_fields, "v")) { + if (rvalues->type != msgpack::type::ARRAY) + throw msgpack::type_error(); + for (size_t i = 0; i < rvalues->via.array.size; ++i) { + try { + auto v = std::make_shared<FieldValueIndex>(); + v->msgpack_unpack_fields(fields_, *rvalues, i*fields.size()); + fields.emplace_back(std::move(v)); + } catch (const std::exception& e) { } + } + } + } else { + throw msgpack::type_error(); + } } if (auto w = findMapValue(req, "w")) { @@ -1066,24 +1131,6 @@ ParsedMessage::msgpack_unpack(msgpack::object msg) if (auto rv = findMapValue(msg, "v")) ua = rv->as<std::string>(); - if (e) - type = MessageType::Error; - else if (r) - type = MessageType::Reply; - else if (y and y->as<std::string>() != "q") - throw msgpack::type_error(); - else if (query == "ping") - type = MessageType::Ping; - else if (query == "find") - type = MessageType::FindNode; - else if (query == "get") - type = MessageType::GetValues; - else if (query == "listen") - type = MessageType::Listen; - else if (query == "put") - type = MessageType::AnnounceValue; - else - throw msgpack::type_error(); } } diff --git a/src/securedht.cpp b/src/securedht.cpp index 2c37ad6e85e3120f7f7672f8590630b0258b242d..8e1ba6f7828db86e5260f0c924124039e2cd2890 100644 --- a/src/securedht.cpp +++ b/src/securedht.cpp @@ -291,15 +291,15 @@ SecureDht::getCallbackFilter(GetCallback cb, Value::Filter&& filter) } void -SecureDht::get(const InfoHash& id, GetCallback cb, DoneCallback donecb, Value::Filter&& f) +SecureDht::get(const InfoHash& id, GetCallback cb, DoneCallback donecb, Value::Filter&& f, Where&& w) { - Dht::get(id, getCallbackFilter(cb, std::forward<Value::Filter>(f)), donecb); + Dht::get(id, getCallbackFilter(cb, std::forward<Value::Filter>(f)), donecb, {}, std::forward<Where>(w)); } size_t -SecureDht::listen(const InfoHash& id, GetCallback cb, Value::Filter&& f) +SecureDht::listen(const InfoHash& id, GetCallback cb, Value::Filter&& f, Where&& w) { - return Dht::listen(id, getCallbackFilter(cb, std::forward<Value::Filter>(f))); + return Dht::listen(id, getCallbackFilter(cb, std::forward<Value::Filter>(f)), {}, std::forward<Where>(w)); } void