Skip to content
Snippets Groups Projects
Commit a4834efd authored by Simon Désaulniers's avatar Simon Désaulniers
Browse files

dht: use queries in dht

parent 07e4435a
Branches
Tags
No related merge requests found
...@@ -70,6 +70,7 @@ static constexpr size_t DEFAULT_STORAGE_LIMIT {1024 * 1024 * 64}; ...@@ -70,6 +70,7 @@ static constexpr size_t DEFAULT_STORAGE_LIMIT {1024 * 1024 * 64};
using ValuesExport = std::pair<InfoHash, Blob>; using ValuesExport = std::pair<InfoHash, Blob>;
using QueryCallback = std::function<bool(const std::vector<std::shared_ptr<FieldValueIndex>>& fields)>;
using GetCallback = std::function<bool(const std::vector<std::shared_ptr<Value>>& values)>; using GetCallback = std::function<bool(const std::vector<std::shared_ptr<Value>>& values)>;
using GetCallbackSimple = std::function<bool(std::shared_ptr<Value> value)>; using GetCallbackSimple = std::function<bool(std::shared_ptr<Value> value)>;
using ShutdownCallback = std::function<void()>; using ShutdownCallback = std::function<void()>;
......
...@@ -136,15 +136,29 @@ public: ...@@ -136,15 +136,29 @@ public:
cb and donecb won't be called again afterward. cb and donecb won't be called again afterward.
* @param f a filter function used to prefilter values. * @param f a filter function used to prefilter values.
*/ */
virtual void get(const InfoHash& key, GetCallback cb, DoneCallback donecb={}, Value::Filter&& f={}); virtual void get(const InfoHash& key, GetCallback cb, DoneCallback donecb={}, Value::Filter&& f={}, Where&& w = {});
virtual void get(const InfoHash& key, GetCallback cb, DoneCallbackSimple donecb={}, Value::Filter&& f={}) { virtual void get(const InfoHash& key, GetCallback cb, DoneCallbackSimple donecb={}, Value::Filter&& f={}, Where&& w = {}) {
get(key, cb, bindDoneCb(donecb), std::forward<Value::Filter>(f)); get(key, cb, bindDoneCb(donecb), std::forward<Value::Filter>(f), std::forward<Where>(w));
} }
virtual void get(const InfoHash& key, GetCallbackSimple cb, DoneCallback donecb={}, Value::Filter&& f={}) { virtual void get(const InfoHash& key, GetCallbackSimple cb, DoneCallback donecb={}, Value::Filter&& f={}, Where&& w = {}) {
get(key, bindGetCb(cb), donecb, std::forward<Value::Filter>(f)); get(key, bindGetCb(cb), donecb, std::forward<Value::Filter>(f), std::forward<Where>(w));
} }
virtual void get(const InfoHash& key, GetCallbackSimple cb, DoneCallbackSimple donecb, Value::Filter&& f={}) { virtual void get(const InfoHash& key, GetCallbackSimple cb, DoneCallbackSimple donecb, Value::Filter&& f={}, Where&& w = {}) {
get(key, bindGetCb(cb), bindDoneCb(donecb), std::forward<Value::Filter>(f)); get(key, bindGetCb(cb), bindDoneCb(donecb), std::forward<Value::Filter>(f), std::forward<Where>(w));
}
/**
* Similar to Dht::get, but sends a Query to filter data remotely.
* @param key the key for which to query data for.
* @param cb a function called when new values are found on the network.
* It should return false to stop the operation.
* @param donecb a function called when the operation is complete.
cb and donecb won't be called again afterward.
* @param q a query used to filter values on the remotes before they send a
* response.
*/
virtual void query(const InfoHash& key, QueryCallback cb, DoneCallback done_cb = {}, Query&& q = {});
virtual void query(const InfoHash& key, QueryCallback cb, DoneCallbackSimple done_cb = {}, Query&& q = {}) {
query(key, cb, bindDoneCb(done_cb), std::forward<Query>(q));
} }
/** /**
...@@ -158,14 +172,10 @@ public: ...@@ -158,14 +172,10 @@ public:
std::shared_ptr<Value> getLocalById(const InfoHash& key, Value::Id vid) const; std::shared_ptr<Value> getLocalById(const InfoHash& key, Value::Id vid) const;
/** /**
* Announce a value on all available protocols (IPv4, IPv6), and * Announce a value on all available protocols (IPv4, IPv6).
* automatically re-announce when it's about to expire. *
* The operation will start as soon as the node is connected to the network. * The operation will start as soon as the node is connected to the network.
* The done callback will be called once, when the first announce succeeds, or fails. * The done callback will be called once, when the first announce succeeds, or fails.
*
* A "put" operation will never end by itself because the value will need to be
* reannounced on a regular basis.
* User can call #cancelPut(InfoHash, Value::Id) to cancel a put operation.
*/ */
void put(const InfoHash& key, void put(const InfoHash& key,
std::shared_ptr<Value>, std::shared_ptr<Value>,
...@@ -221,9 +231,9 @@ public: ...@@ -221,9 +231,9 @@ public:
* *
* @return a token to cancel the listener later. * @return a token to cancel the listener later.
*/ */
virtual size_t listen(const InfoHash&, GetCallback, Value::Filter&&={}); virtual size_t listen(const InfoHash&, GetCallback, Value::Filter&&={}, Where&& w = {});
virtual size_t listen(const InfoHash& key, GetCallbackSimple cb, Value::Filter f={}) { virtual size_t listen(const InfoHash& key, GetCallbackSimple cb, Value::Filter f={}, Where w = {}) {
return listen(key, bindGetCb(cb), std::forward<Value::Filter>(f)); return listen(key, bindGetCb(cb), std::forward<Value::Filter>(f), std::forward<Where>(w));
} }
virtual bool cancelListen(const InfoHash&, size_t token); virtual bool cancelListen(const InfoHash&, size_t token);
...@@ -317,6 +327,8 @@ private: ...@@ -317,6 +327,8 @@ private:
struct Get { struct Get {
time_point start; time_point start;
Value::Filter filter; Value::Filter filter;
std::shared_ptr<Query> query;
QueryCallback query_cb;
GetCallback get_cb; GetCallback get_cb;
DoneCallback done_cb; DoneCallback done_cb;
}; };
...@@ -335,6 +347,7 @@ private: ...@@ -335,6 +347,7 @@ private:
* A single "listen" operation data * A single "listen" operation data
*/ */
struct LocalListener { struct LocalListener {
std::shared_ptr<Query> query;
Value::Filter filter; Value::Filter filter;
GetCallback get_cb; GetCallback get_cb;
}; };
...@@ -359,8 +372,9 @@ private: ...@@ -359,8 +372,9 @@ private:
struct Listener { struct Listener {
size_t rid {}; size_t rid {};
time_point time {}; time_point time {};
Query query {};
/*constexpr*/ Listener(size_t rid, time_point t) : rid(rid), time(t) {} /*constexpr*/ Listener(size_t rid, time_point t, Query&& q) : rid(rid), time(t), query(q) {}
void refresh(size_t tid, time_point t) { void refresh(size_t tid, time_point t) {
rid = tid; rid = tid;
...@@ -427,7 +441,7 @@ private: ...@@ -427,7 +441,7 @@ private:
decltype(Dht::store)::iterator findStorage(const InfoHash& id); decltype(Dht::store)::iterator findStorage(const InfoHash& id);
decltype(Dht::store)::const_iterator findStorage(const InfoHash& id) const; decltype(Dht::store)::const_iterator findStorage(const InfoHash& id) const;
void storageAddListener(const InfoHash& id, const std::shared_ptr<Node>& node, size_t tid); void storageAddListener(const InfoHash& id, const std::shared_ptr<Node>& node, size_t tid, Query&& = {});
bool storageStore(const InfoHash& id, const std::shared_ptr<Value>& value, time_point created); bool storageStore(const InfoHash& id, const std::shared_ptr<Value>& value, time_point created);
void expireStorage(); void expireStorage();
void storageChanged(Storage& st, ValueStorage&); void storageChanged(Storage& st, ValueStorage&);
...@@ -472,13 +486,13 @@ private: ...@@ -472,13 +486,13 @@ private:
// Searches // Searches
/** /**
* Low-level method that will perform a search on the DHT for the * Low-level method that will perform a search on the DHT for the specified
* specified infohash (id), using the specified IP version (IPv4 or IPv6). * infohash (id), using the specified IP version (IPv4 or IPv6).
* The values can be filtered by an arbitrary provided filter.
*/ */
std::shared_ptr<Search> search(const InfoHash& id, sa_family_t af, GetCallback = nullptr, DoneCallback = nullptr, Value::Filter = Value::AllFilter()); std::shared_ptr<Search> search(const InfoHash& id, sa_family_t af, GetCallback = {}, QueryCallback = {}, DoneCallback = {}, Value::Filter = {}, Query q = {});
void announce(const InfoHash& id, sa_family_t af, std::shared_ptr<Value> value, DoneCallback callback, time_point created=time_point::max(), bool permanent = false); void announce(const InfoHash& id, sa_family_t af, std::shared_ptr<Value> value, DoneCallback callback, time_point created=time_point::max(), bool permanent = false);
size_t listenTo(const InfoHash& id, sa_family_t af, GetCallback cb, Value::Filter f = Value::AllFilter()); size_t listenTo(const InfoHash& id, sa_family_t af, GetCallback cb, Value::Filter f = Value::AllFilter(), const std::shared_ptr<Query>& q = {});
void bootstrapSearch(Search& sr); void bootstrapSearch(Search& sr);
Search *findSearch(unsigned short tid, sa_family_t af); Search *findSearch(unsigned short tid, sa_family_t af);
...@@ -508,11 +522,14 @@ private: ...@@ -508,11 +522,14 @@ private:
NetworkEngine::RequestAnswer onFindNode(std::shared_ptr<Node> node, InfoHash& hash, want_t want); NetworkEngine::RequestAnswer onFindNode(std::shared_ptr<Node> node, InfoHash& hash, want_t want);
void onFindNodeDone(const Request& status, NetworkEngine::RequestAnswer& a, std::shared_ptr<Search> sr); void onFindNodeDone(const Request& status, NetworkEngine::RequestAnswer& a, std::shared_ptr<Search> sr);
/* when we receive a "get values" request */ /* when we receive a "get values" request */
NetworkEngine::RequestAnswer onGetValues(std::shared_ptr<Node> node, InfoHash& hash, want_t want); NetworkEngine::RequestAnswer onGetValues(std::shared_ptr<Node> node, InfoHash& hash, want_t want, const Query& q);
void onGetValuesDone(const Request& status, NetworkEngine::RequestAnswer& a, std::shared_ptr<Search> sr); void onGetValuesDone(const Request& status, NetworkEngine::RequestAnswer& a, std::shared_ptr<Search> sr,
const std::shared_ptr<Query>& orig_query);
/* when we receive a listen request */ /* when we receive a listen request */
NetworkEngine::RequestAnswer onListen(std::shared_ptr<Node> node, InfoHash& hash, Blob& token, size_t rid); NetworkEngine::RequestAnswer onListen(std::shared_ptr<Node> node, InfoHash& hash, Blob& token, size_t rid,
void onListenDone(const Request& status, NetworkEngine::RequestAnswer& a, std::shared_ptr<Search>& sr); Query&& query);
void onListenDone(const Request& status, NetworkEngine::RequestAnswer& a,
std::shared_ptr<Search>& sr, const std::shared_ptr<Query>& orig_query);
/* when we receive an announce request */ /* when we receive an announce request */
NetworkEngine::RequestAnswer onAnnounce(std::shared_ptr<Node> node, NetworkEngine::RequestAnswer onAnnounce(std::shared_ptr<Node> node,
InfoHash& hash, Blob& token, std::vector<std::shared_ptr<Value>> v, time_point created); InfoHash& hash, Blob& token, std::vector<std::shared_ptr<Value>> v, time_point created);
......
...@@ -153,6 +153,7 @@ public: ...@@ -153,6 +153,7 @@ public:
Blob ntoken {}; Blob ntoken {};
Value::Id vid {}; Value::Id vid {};
std::vector<std::shared_ptr<Value>> values {}; std::vector<std::shared_ptr<Value>> values {};
std::vector<std::shared_ptr<FieldValueIndex>> fields {};
std::vector<std::shared_ptr<Node>> nodes4 {}; std::vector<std::shared_ptr<Node>> nodes4 {};
std::vector<std::shared_ptr<Node>> nodes6 {}; std::vector<std::shared_ptr<Node>> nodes6 {};
RequestAnswer() {} RequestAnswer() {}
...@@ -223,7 +224,8 @@ private: ...@@ -223,7 +224,8 @@ private:
*/ */
std::function<RequestAnswer(std::shared_ptr<Node>, std::function<RequestAnswer(std::shared_ptr<Node>,
InfoHash&, InfoHash&,
want_t)> onGetValues {}; want_t,
Query)> onGetValues {};
/** /**
* @brief on listen request callback. * @brief on listen request callback.
* *
...@@ -235,7 +237,8 @@ private: ...@@ -235,7 +237,8 @@ private:
std::function<RequestAnswer(std::shared_ptr<Node>, std::function<RequestAnswer(std::shared_ptr<Node>,
InfoHash&, InfoHash&,
Blob&, Blob&,
uint16_t)> onListen {}; uint16_t,
Query)> onListen {};
/** /**
* @brief on announce request callback. * @brief on announce request callback.
* *
...@@ -290,9 +293,9 @@ public: ...@@ -290,9 +293,9 @@ public:
* @param nodes6 The ipv6 closest nodes. * @param nodes6 The ipv6 closest nodes.
* @param values The values to send. * @param values The values to send.
*/ */
void tellListener(std::shared_ptr<Node> n, uint16_t rid, InfoHash hash, want_t want, Blob ntoken, void tellListener(std::shared_ptr<Node> n, uint16_t rid, const InfoHash& hash, want_t want, const Blob& ntoken,
std::vector<std::shared_ptr<Node>> nodes, std::vector<std::shared_ptr<Node>> nodes6, std::vector<std::shared_ptr<Node>>&& nodes, std::vector<std::shared_ptr<Node>>&& nodes6,
std::vector<std::shared_ptr<Value>> values); std::vector<std::shared_ptr<Value>>&& values, const Query& q);
bool isRunning(sa_family_t af) const; bool isRunning(sa_family_t af) const;
inline want_t want () const { return dht_socket >= 0 && dht_socket6 >= 0 ? (WANT4 | WANT6) : -1; } inline want_t want () const { return dht_socket >= 0 && dht_socket6 >= 0 ? (WANT4 | WANT6) : -1; }
...@@ -314,13 +317,15 @@ public: ...@@ -314,13 +317,15 @@ public:
RequestExpiredCb on_expired); RequestExpiredCb on_expired);
std::shared_ptr<Request> std::shared_ptr<Request>
sendGetValues(std::shared_ptr<Node> n, sendGetValues(std::shared_ptr<Node> n,
const InfoHash& target, const InfoHash& info_hash,
const Query& query,
want_t want, want_t want,
RequestCb on_done, RequestCb on_done,
RequestExpiredCb on_expired); RequestExpiredCb on_expired);
std::shared_ptr<Request> std::shared_ptr<Request>
sendListen(std::shared_ptr<Node> n, sendListen(std::shared_ptr<Node> n,
const InfoHash& infohash, const InfoHash& infohash,
const Query& query,
const Blob& token, const Blob& token,
RequestCb on_done, RequestCb on_done,
RequestExpiredCb on_expired); RequestExpiredCb on_expired);
...@@ -432,6 +437,7 @@ private: ...@@ -432,6 +437,7 @@ private:
const Blob& nodes, const Blob& nodes,
const Blob& nodes6, const Blob& nodes6,
const std::vector<std::shared_ptr<Value>>& st, const std::vector<std::shared_ptr<Value>>& st,
const Query& query,
const Blob& token); const Blob& token);
Blob bufferNodes(sa_family_t af, const InfoHash& id, std::vector<std::shared_ptr<Node>>& nodes); Blob bufferNodes(sa_family_t af, const InfoHash& id, std::vector<std::shared_ptr<Node>>& nodes);
...@@ -452,7 +458,7 @@ private: ...@@ -452,7 +458,7 @@ private:
const std::string& message, const std::string& message,
bool include_id=false); bool include_id=false);
void deserializeNodesValues(ParsedMessage& msg); void deserializeNodes(ParsedMessage& msg);
std::queue<time_point> rate_limit_time {}; std::queue<time_point> rate_limit_time {};
static std::mt19937 rd_device; static std::mt19937 rd_device;
......
...@@ -85,18 +85,18 @@ public: ...@@ -85,18 +85,18 @@ public:
* If the signature can't be checked, or if the data can't be decrypted, it is not returned. * If the signature can't be checked, or if the data can't be decrypted, it is not returned.
* Public, non-signed & non-encrypted data is retransmitted as-is. * Public, non-signed & non-encrypted data is retransmitted as-is.
*/ */
virtual void get(const InfoHash& id, GetCallback cb, DoneCallback donecb={}, Value::Filter&& = {}) override; virtual void get(const InfoHash& id, GetCallback cb, DoneCallback donecb={}, Value::Filter&& = {}, Where&& w = {}) override;
virtual void get(const InfoHash& id, GetCallback cb, DoneCallbackSimple donecb={}, Value::Filter&& f = {}) override { virtual void get(const InfoHash& id, GetCallback cb, DoneCallbackSimple donecb={}, Value::Filter&& f = {}, Where&& w = {}) override {
get(id, cb, bindDoneCb(donecb), std::forward<Value::Filter>(f)); get(id, cb, bindDoneCb(donecb), std::forward<Value::Filter>(f), std::forward<Where>(w));
} }
virtual void get(const InfoHash& key, GetCallbackSimple cb, DoneCallback donecb={}, Value::Filter&& f={}) override { virtual void get(const InfoHash& key, GetCallbackSimple cb, DoneCallback donecb={}, Value::Filter&& f={}, Where&& w = {}) override {
get(key, bindGetCb(cb), donecb, std::forward<Value::Filter>(f)); get(key, bindGetCb(cb), donecb, std::forward<Value::Filter>(f), std::forward<Where>(w));
} }
virtual void get(const InfoHash& key, GetCallbackSimple cb, DoneCallbackSimple donecb, Value::Filter&& f={}) override { virtual void get(const InfoHash& key, GetCallbackSimple cb, DoneCallbackSimple donecb, Value::Filter&& f={}, Where&& w = {}) override {
get(key, bindGetCb(cb), bindDoneCb(donecb), std::forward<Value::Filter>(f)); get(key, bindGetCb(cb), bindDoneCb(donecb), std::forward<Value::Filter>(f), std::forward<Where>(w));
} }
virtual size_t listen(const InfoHash& id, GetCallback cb, Value::Filter&& = {}) override; virtual size_t listen(const InfoHash& id, GetCallback cb, Value::Filter&& = {}, Where&& w = {}) override;
/** /**
* Will take ownership of the value, sign it using our private key and put it in the DHT. * Will take ownership of the value, sign it using our private key and put it in the DHT.
......
This diff is collapsed.
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "network_engine.h" #include "network_engine.h"
#include "request.h" #include "request.h"
#include "default_types.h"
#include <msgpack.hpp> #include <msgpack.hpp>
...@@ -74,6 +75,8 @@ struct ParsedMessage { ...@@ -74,6 +75,8 @@ struct ParsedMessage {
Blob nodes4_raw, nodes6_raw; /* IPv4 nodes in response to a 'find' request */ Blob nodes4_raw, nodes6_raw; /* IPv4 nodes in response to a 'find' request */
std::vector<std::shared_ptr<Node>> nodes4, nodes6; std::vector<std::shared_ptr<Node>> nodes4, nodes6;
std::vector<std::shared_ptr<Value>> values; /* values for a 'get' request */ std::vector<std::shared_ptr<Value>> values; /* values for a 'get' request */
std::vector<std::shared_ptr<FieldValueIndex>> fields; /* index for fields values */
Query query; /* query describing a filter to apply on values. */
want_t want; /* states if ipv4 or ipv6 request */ want_t want; /* states if ipv4 or ipv6 request */
uint16_t error_code; /* error code in case of error */ uint16_t error_code; /* error code in case of error */
std::string ua; std::string ua;
...@@ -82,17 +85,19 @@ struct ParsedMessage { ...@@ -82,17 +85,19 @@ struct ParsedMessage {
}; };
NetworkEngine::RequestAnswer::RequestAnswer(ParsedMessage&& msg) NetworkEngine::RequestAnswer::RequestAnswer(ParsedMessage&& msg)
: ntoken(std::move(msg.token)), values(std::move(msg.values)), nodes4(std::move(msg.nodes4)), nodes6(std::move(msg.nodes6)) {} : ntoken(std::move(msg.token)), values(std::move(msg.values)), fields(std::move(msg.fields)),
nodes4(std::move(msg.nodes4)), nodes6(std::move(msg.nodes6)) {}
void void
NetworkEngine::tellListener(std::shared_ptr<Node> node, uint16_t rid, InfoHash hash, want_t want, NetworkEngine::tellListener(std::shared_ptr<Node> node, uint16_t rid, const InfoHash& hash, want_t want,
Blob ntoken, std::vector<std::shared_ptr<Node>> nodes, std::vector<std::shared_ptr<Node>> nodes6, const Blob& ntoken, std::vector<std::shared_ptr<Node>>&& nodes,
std::vector<std::shared_ptr<Value>> values) std::vector<std::shared_ptr<Node>>&& nodes6, std::vector<std::shared_ptr<Value>>&& values,
const Query& query)
{ {
auto nnodes = bufferNodes(node->getFamily(), hash, want, nodes, nodes6); auto nnodes = bufferNodes(node->getFamily(), hash, want, nodes, nodes6);
try { try {
sendNodesValues((const sockaddr*)&node->ss, node->sslen, TransId {TransPrefix::GET_VALUES, (uint16_t)rid}, nnodes.first, nnodes.second, sendNodesValues((const sockaddr*)&node->ss, node->sslen, TransId {TransPrefix::GET_VALUES, (uint16_t)rid}, nnodes.first, nnodes.second,
values, ntoken); values, query, ntoken);
} catch (const std::overflow_error& e) { } catch (const std::overflow_error& e) {
DHT_LOG.ERR("Can't send value: buffer not large enough !"); DHT_LOG.ERR("Can't send value: buffer not large enough !");
} }
...@@ -380,7 +385,7 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const sockaddr* ...@@ -380,7 +385,7 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const sockaddr*
requests.erase(reqp); requests.erase(reqp);
req->reply_time = scheduler.time(); req->reply_time = scheduler.time();
deserializeNodesValues(msg); deserializeNodes(msg);
req->setDone(std::move(msg)); req->setDone(std::move(msg));
break; break;
default: default:
...@@ -404,16 +409,16 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const sockaddr* ...@@ -404,16 +409,16 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const sockaddr*
++in_stats.find; ++in_stats.find;
RequestAnswer answer = onFindNode(node, msg.target, msg.want); RequestAnswer answer = onFindNode(node, msg.target, msg.want);
auto nnodes = bufferNodes(from->sa_family, msg.target, msg.want, answer.nodes4, answer.nodes6); auto nnodes = bufferNodes(from->sa_family, msg.target, msg.want, answer.nodes4, answer.nodes6);
sendNodesValues(from, fromlen, msg.tid, nnodes.first, nnodes.second, {}, answer.ntoken); sendNodesValues(from, fromlen, msg.tid, nnodes.first, nnodes.second, {}, {}, answer.ntoken);
break; break;
} }
case MessageType::GetValues: { case MessageType::GetValues: {
DHT_LOG.DEBUG("[node %s %s] got 'get' request for %s.", DHT_LOG.DEBUG("[node %s %s] got 'get' request for %s.",
msg.id.toString().c_str(), print_addr(from, fromlen).c_str(), msg.info_hash.toString().c_str()); msg.id.toString().c_str(), print_addr(from, fromlen).c_str(), msg.info_hash.toString().c_str());
++in_stats.get; ++in_stats.get;
RequestAnswer answer = onGetValues(node, msg.info_hash, msg.want); RequestAnswer answer = onGetValues(node, msg.info_hash, msg.want, msg.query);
auto nnodes = bufferNodes(from->sa_family, msg.info_hash, msg.want, answer.nodes4, answer.nodes6); auto nnodes = bufferNodes(from->sa_family, msg.info_hash, msg.want, answer.nodes4, answer.nodes6);
sendNodesValues(from, fromlen, msg.tid, nnodes.first, nnodes.second, answer.values, answer.ntoken); sendNodesValues(from, fromlen, msg.tid, nnodes.first, nnodes.second, answer.values, msg.query, answer.ntoken);
break; break;
} }
case MessageType::AnnounceValue: { case MessageType::AnnounceValue: {
...@@ -435,7 +440,7 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const sockaddr* ...@@ -435,7 +440,7 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const sockaddr*
DHT_LOG.DEBUG("[node %s %s] got 'listen' request for %s.", DHT_LOG.DEBUG("[node %s %s] got 'listen' request for %s.",
msg.id.toString().c_str(), print_addr(from, fromlen).c_str(), msg.info_hash.toString().c_str()); msg.id.toString().c_str(), print_addr(from, fromlen).c_str(), msg.info_hash.toString().c_str());
++in_stats.listen; ++in_stats.listen;
RequestAnswer answer = onListen(node, msg.info_hash, msg.token, msg.tid.getTid()); RequestAnswer answer = onListen(node, msg.info_hash, msg.token, msg.tid.getTid(), std::move(msg.query));
sendListenConfirmation(from, fromlen, msg.tid); sendListenConfirmation(from, fromlen, msg.tid);
break; break;
} }
...@@ -594,16 +599,19 @@ NetworkEngine::sendFindNode(std::shared_ptr<Node> n, const InfoHash& target, wan ...@@ -594,16 +599,19 @@ NetworkEngine::sendFindNode(std::shared_ptr<Node> n, const InfoHash& target, wan
std::shared_ptr<Request> std::shared_ptr<Request>
NetworkEngine::sendGetValues(std::shared_ptr<Node> n, const InfoHash& info_hash, want_t want, NetworkEngine::sendGetValues(std::shared_ptr<Node> n, const InfoHash& info_hash, const Query& query, want_t want,
RequestCb on_done, RequestExpiredCb on_expired) { RequestCb on_done, RequestExpiredCb on_expired) {
auto tid = TransId {TransPrefix::GET_VALUES, getNewTid()}; auto tid = TransId {TransPrefix::GET_VALUES, getNewTid()};
msgpack::sbuffer buffer; msgpack::sbuffer buffer;
msgpack::packer<msgpack::sbuffer> pk(&buffer); msgpack::packer<msgpack::sbuffer> pk(&buffer);
pk.pack_map(5+(network?1:0)); pk.pack_map(5+(network?1:0));
pk.pack(std::string("a")); pk.pack_map(2 + (want>0?1:0)); pk.pack(std::string("a")); pk.pack_map(2 +
(query.where.getFilter() or not query.select.getSelection().empty() ? 1:0) +
(want>0?1:0));
pk.pack(std::string("id")); pk.pack(myid); pk.pack(std::string("id")); pk.pack(myid);
pk.pack(std::string("h")); pk.pack(info_hash); pk.pack(std::string("h")); pk.pack(info_hash);
pk.pack(std::string("q")); pk.pack(query);
if (want > 0) { if (want > 0) {
pk.pack(std::string("w")); pk.pack(std::string("w"));
pk.pack_array(((want & WANT4)?1:0) + ((want & WANT6)?1:0)); pk.pack_array(((want & WANT4)?1:0) + ((want & WANT6)?1:0));
...@@ -639,7 +647,7 @@ NetworkEngine::sendGetValues(std::shared_ptr<Node> n, const InfoHash& info_hash, ...@@ -639,7 +647,7 @@ NetworkEngine::sendGetValues(std::shared_ptr<Node> n, const InfoHash& info_hash,
} }
void void
NetworkEngine::deserializeNodesValues(ParsedMessage& msg) { NetworkEngine::deserializeNodes(ParsedMessage& msg) {
if (msg.nodes4_raw.size() % NODE4_INFO_BUF_LEN != 0 || msg.nodes6_raw.size() % NODE6_INFO_BUF_LEN != 0) { if (msg.nodes4_raw.size() % NODE4_INFO_BUF_LEN != 0 || msg.nodes6_raw.size() % NODE6_INFO_BUF_LEN != 0) {
throw DhtProtocolException {DhtProtocolException::WRONG_NODE_INFO_BUF_LEN}; throw DhtProtocolException {DhtProtocolException::WRONG_NODE_INFO_BUF_LEN};
} else { } else {
...@@ -680,7 +688,7 @@ NetworkEngine::deserializeNodesValues(ParsedMessage& msg) { ...@@ -680,7 +688,7 @@ NetworkEngine::deserializeNodesValues(ParsedMessage& msg) {
void void
NetworkEngine::sendNodesValues(const sockaddr* sa, socklen_t salen, TransId tid, const Blob& nodes, const Blob& nodes6, NetworkEngine::sendNodesValues(const sockaddr* sa, socklen_t salen, TransId tid, const Blob& nodes, const Blob& nodes6,
const std::vector<std::shared_ptr<Value>>& st, const Blob& token) { const std::vector<std::shared_ptr<Value>>& st, const Query& query, const Blob& token) {
msgpack::sbuffer buffer; msgpack::sbuffer buffer;
msgpack::packer<msgpack::sbuffer> pk(&buffer); msgpack::packer<msgpack::sbuffer> pk(&buffer);
pk.pack_map(4+(network?1:0)); pk.pack_map(4+(network?1:0));
...@@ -702,7 +710,10 @@ NetworkEngine::sendNodesValues(const sockaddr* sa, socklen_t salen, TransId tid, ...@@ -702,7 +710,10 @@ NetworkEngine::sendNodesValues(const sockaddr* sa, socklen_t salen, TransId tid,
if (not token.empty()) { if (not token.empty()) {
pk.pack(std::string("token")); packToken(pk, token); pk.pack(std::string("token")); packToken(pk, token);
} }
if (not st.empty()) { if (not st.empty()) { /* pack complete values */
auto fields = query.select.getSelection();
size_t total_size = 0;
if (fields.empty()) {
// We treat the storage as a circular list, and serve a randomly // We treat the storage as a circular list, and serve a randomly
// chosen slice. In order to make sure we fit, // chosen slice. In order to make sure we fit,
// we limit ourselves to 50 values. // we limit ourselves to 50 values.
...@@ -710,7 +721,6 @@ NetworkEngine::sendNodesValues(const sockaddr* sa, socklen_t salen, TransId tid, ...@@ -710,7 +721,6 @@ NetworkEngine::sendNodesValues(const sockaddr* sa, socklen_t salen, TransId tid,
std::vector<Blob> subset {}; std::vector<Blob> subset {};
subset.reserve(std::min<size_t>(st.size(), 50)); subset.reserve(std::min<size_t>(st.size(), 50));
size_t total_size = 0;
unsigned j0 = pos_dis(rd_device); unsigned j0 = pos_dis(rd_device);
unsigned j = j0; unsigned j = j0;
unsigned k = 0; unsigned k = 0;
...@@ -726,10 +736,24 @@ NetworkEngine::sendNodesValues(const sockaddr* sa, socklen_t salen, TransId tid, ...@@ -726,10 +736,24 @@ NetworkEngine::sendNodesValues(const sockaddr* sa, socklen_t salen, TransId tid,
pk.pack_array(subset.size()); pk.pack_array(subset.size());
for (const auto& b : subset) for (const auto& b : subset)
buffer.write((const char*)b.data(), b.size()); buffer.write((const char*)b.data(), b.size());
DHT_LOG.DEBUG("sending closest nodes (%d+%d nodes.), %lu bytes of values", nodes.size(), nodes6.size(), total_size); DHT_LOG.DEBUG("sending closest nodes (%d+%d nodes.), %lu bytes of values",
nodes.size(), nodes6.size(), total_size);
} else { /* pack fields */
pk.pack(std::string("fields"));
pk.pack_map(2);
pk.pack(std::string("f")); pk.pack(fields);
pk.pack(std::string("v")); pk.pack_array(st.size()*fields.size());
for (const auto& v : st) {
v->msgpack_pack_fields(fields, pk);
}
DHT_LOG.DEBUG("sending closest nodes (%d+%d nodes.), %u value headers containing %u fields",
nodes.size(), nodes6.size(), st.size(), fields.size());
}
} else } else
DHT_LOG.DEBUG("sending closest nodes (%d+%d nodes.)", nodes.size(), nodes6.size()); DHT_LOG.DEBUG("sending closest nodes (%d+%d nodes.)", nodes.size(), nodes6.size());
DHT_LOG.DEBUG("sending closest nodes (%d+%d nodes.)", nodes.size(), nodes6.size());
pk.pack(std::string("t")); pk.pack_bin(tid.size()); pk.pack(std::string("t")); pk.pack_bin(tid.size());
pk.pack_bin_body((const char*)tid.data(), tid.size()); pk.pack_bin_body((const char*)tid.data(), tid.size());
pk.pack(std::string("y")); pk.pack(std::string("r")); pk.pack(std::string("y")); pk.pack(std::string("r"));
...@@ -794,16 +818,18 @@ NetworkEngine::bufferNodes(sa_family_t af, const InfoHash& id, want_t want, ...@@ -794,16 +818,18 @@ NetworkEngine::bufferNodes(sa_family_t af, const InfoHash& id, want_t want,
} }
std::shared_ptr<Request> std::shared_ptr<Request>
NetworkEngine::sendListen(std::shared_ptr<Node> n, const InfoHash& infohash, const Blob& token, NetworkEngine::sendListen(std::shared_ptr<Node> n, const InfoHash& infohash, const Query& query, const Blob& token,
RequestCb on_done, RequestExpiredCb on_expired) { RequestCb on_done, RequestExpiredCb on_expired) {
auto tid = TransId {TransPrefix::LISTEN, getNewTid()}; auto tid = TransId {TransPrefix::LISTEN, getNewTid()};
msgpack::sbuffer buffer; msgpack::sbuffer buffer;
msgpack::packer<msgpack::sbuffer> pk(&buffer); msgpack::packer<msgpack::sbuffer> pk(&buffer);
pk.pack_map(5+(network?1:0)); pk.pack_map(5+(network?1:0));
pk.pack(std::string("a")); pk.pack_map(3); pk.pack(std::string("a")); pk.pack_map(3 +
(query.where.getFilter() or not query.select.getSelection().empty() ? 1:0));
pk.pack(std::string("id")); pk.pack(myid); pk.pack(std::string("id")); pk.pack(myid);
pk.pack(std::string("h")); pk.pack(infohash); pk.pack(std::string("h")); pk.pack(infohash);
pk.pack(std::string("q")); pk.pack(query);
pk.pack(std::string("token")); packToken(pk, token); pk.pack(std::string("token")); packToken(pk, token);
pk.pack(std::string("q")); pk.pack(std::string("listen")); pk.pack(std::string("q")); pk.pack(std::string("listen"));
...@@ -961,17 +987,36 @@ void ...@@ -961,17 +987,36 @@ void
ParsedMessage::msgpack_unpack(msgpack::object msg) ParsedMessage::msgpack_unpack(msgpack::object msg)
{ {
auto y = findMapValue(msg, "y"); auto y = findMapValue(msg, "y");
auto a = findMapValue(msg, "a");
auto r = findMapValue(msg, "r"); auto r = findMapValue(msg, "r");
auto e = findMapValue(msg, "e"); auto e = findMapValue(msg, "e");
std::string query; std::string q;
if (auto q = findMapValue(msg, "q")) { if (auto rq = findMapValue(msg, "q")) {
if (q->type != msgpack::type::STR) if (rq->type != msgpack::type::STR)
throw msgpack::type_error(); throw msgpack::type_error();
query = q->as<std::string>(); q = rq->as<std::string>();
} }
if (e)
type = MessageType::Error;
else if (r)
type = MessageType::Reply;
else if (y and y->as<std::string>() != "q")
throw msgpack::type_error();
else if (q == "ping")
type = MessageType::Ping;
else if (q == "find")
type = MessageType::FindNode;
else if (q == "get")
type = MessageType::GetValues;
else if (q == "listen")
type = MessageType::Listen;
else if (q == "put")
type = MessageType::AnnounceValue;
else
throw msgpack::type_error();
auto a = findMapValue(msg, "a");
if (!a && !r && !e) if (!a && !r && !e)
throw msgpack::type_error(); throw msgpack::type_error();
auto& req = a ? *a : (r ? *r : *e); auto& req = a ? *a : (r ? *r : *e);
...@@ -994,6 +1039,9 @@ ParsedMessage::msgpack_unpack(msgpack::object msg) ...@@ -994,6 +1039,9 @@ ParsedMessage::msgpack_unpack(msgpack::object msg)
if (auto rtarget = findMapValue(req, "target")) if (auto rtarget = findMapValue(req, "target"))
target = {*rtarget}; target = {*rtarget};
if (auto rquery = findMapValue(req, "q"))
query.msgpack_unpack(*rquery);
if (auto otoken = findMapValue(req, "token")) if (auto otoken = findMapValue(req, "token"))
token = unpackBlob(*otoken); token = unpackBlob(*otoken);
...@@ -1040,6 +1088,23 @@ ParsedMessage::msgpack_unpack(msgpack::object msg) ...@@ -1040,6 +1088,23 @@ ParsedMessage::msgpack_unpack(msgpack::object msg)
} catch (const std::exception& e) { } catch (const std::exception& e) {
//DHT_LOG.WARN("Error reading value: %s", e.what()); //DHT_LOG.WARN("Error reading value: %s", e.what());
} }
} else if (auto raw_fields = findMapValue(req, "fields")) {
if (auto rfields = findMapValue(*raw_fields, "f")) {
auto fields_ = rfields->as<std::set<Value::Field>>();
if (auto rvalues = findMapValue(*raw_fields, "v")) {
if (rvalues->type != msgpack::type::ARRAY)
throw msgpack::type_error();
for (size_t i = 0; i < rvalues->via.array.size; ++i) {
try {
auto v = std::make_shared<FieldValueIndex>();
v->msgpack_unpack_fields(fields_, *rvalues, i*fields.size());
fields.emplace_back(std::move(v));
} catch (const std::exception& e) { }
}
}
} else {
throw msgpack::type_error();
}
} }
if (auto w = findMapValue(req, "w")) { if (auto w = findMapValue(req, "w")) {
...@@ -1066,24 +1131,6 @@ ParsedMessage::msgpack_unpack(msgpack::object msg) ...@@ -1066,24 +1131,6 @@ ParsedMessage::msgpack_unpack(msgpack::object msg)
if (auto rv = findMapValue(msg, "v")) if (auto rv = findMapValue(msg, "v"))
ua = rv->as<std::string>(); ua = rv->as<std::string>();
if (e)
type = MessageType::Error;
else if (r)
type = MessageType::Reply;
else if (y and y->as<std::string>() != "q")
throw msgpack::type_error();
else if (query == "ping")
type = MessageType::Ping;
else if (query == "find")
type = MessageType::FindNode;
else if (query == "get")
type = MessageType::GetValues;
else if (query == "listen")
type = MessageType::Listen;
else if (query == "put")
type = MessageType::AnnounceValue;
else
throw msgpack::type_error();
} }
} }
...@@ -291,15 +291,15 @@ SecureDht::getCallbackFilter(GetCallback cb, Value::Filter&& filter) ...@@ -291,15 +291,15 @@ SecureDht::getCallbackFilter(GetCallback cb, Value::Filter&& filter)
} }
void void
SecureDht::get(const InfoHash& id, GetCallback cb, DoneCallback donecb, Value::Filter&& f) SecureDht::get(const InfoHash& id, GetCallback cb, DoneCallback donecb, Value::Filter&& f, Where&& w)
{ {
Dht::get(id, getCallbackFilter(cb, std::forward<Value::Filter>(f)), donecb); Dht::get(id, getCallbackFilter(cb, std::forward<Value::Filter>(f)), donecb, {}, std::forward<Where>(w));
} }
size_t size_t
SecureDht::listen(const InfoHash& id, GetCallback cb, Value::Filter&& f) SecureDht::listen(const InfoHash& id, GetCallback cb, Value::Filter&& f, Where&& w)
{ {
return Dht::listen(id, getCallbackFilter(cb, std::forward<Value::Filter>(f))); return Dht::listen(id, getCallbackFilter(cb, std::forward<Value::Filter>(f)), {}, std::forward<Where>(w));
} }
void void
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment