diff --git a/include/opendht/callbacks.h b/include/opendht/callbacks.h index 5318c08c336385f584d15e484c5cc32a5c96ccca..ed75c0d6e830926007243087e9673f70c95ecee0 100644 --- a/include/opendht/callbacks.h +++ b/include/opendht/callbacks.h @@ -70,6 +70,7 @@ static constexpr size_t DEFAULT_STORAGE_LIMIT {1024 * 1024 * 64}; using ValuesExport = std::pair<InfoHash, Blob>; +using QueryCallback = std::function<bool(const std::vector<std::shared_ptr<FieldValueIndex>>& fields)>; using GetCallback = std::function<bool(const std::vector<std::shared_ptr<Value>>& values)>; using GetCallbackSimple = std::function<bool(std::shared_ptr<Value> value)>; using ShutdownCallback = std::function<void()>; diff --git a/include/opendht/default_types.h b/include/opendht/default_types.h index ea0f6a155c641218f0167612beab7dd5b3242ff6..67578f77e97c1a3f08dd92a51911a16bfde69794 100644 --- a/include/opendht/default_types.h +++ b/include/opendht/default_types.h @@ -21,6 +21,8 @@ #include "value.h" +MSGPACK_ADD_ENUM(dht::Value::Field); + namespace dht { enum class ImStatus : uint8_t { NONE = 0, @@ -175,7 +177,7 @@ public: pk.pack_bin_body((const char*)ice_data.data(), ice_data.size()); #else // hack for backward compatibility with old opendht compiled with msgpack 1.0 - // remove when enough people have moved to new versions + // remove when enough people have moved to new versions pk.pack_array(ice_data.size()); for (uint8_t b : ice_data) pk.pack(b); diff --git a/include/opendht/dht.h b/include/opendht/dht.h index a5409e367b715deea7412f61ced4dd92fb2a7e9e..d0a89eb261a28adb6423feb42449b01a173a5e40 100644 --- a/include/opendht/dht.h +++ b/include/opendht/dht.h @@ -133,18 +133,32 @@ public: * @param cb a function called when new values are found on the network. * It should return false to stop the operation. * @param donecb a function called when the operation is complete. - cb and donecb won't be called again afterward. + cb and donecb won't be called again afterward. * @param f a filter function used to prefilter values. */ - virtual void get(const InfoHash& key, GetCallback cb, DoneCallback donecb={}, Value::Filter&& f={}); - virtual void get(const InfoHash& key, GetCallback cb, DoneCallbackSimple donecb={}, Value::Filter&& f={}) { - get(key, cb, bindDoneCb(donecb), std::forward<Value::Filter>(f)); + virtual void get(const InfoHash& key, GetCallback cb, DoneCallback donecb={}, Value::Filter&& f={}, Where&& w = {}); + virtual void get(const InfoHash& key, GetCallback cb, DoneCallbackSimple donecb={}, Value::Filter&& f={}, Where&& w = {}) { + get(key, cb, bindDoneCb(donecb), std::forward<Value::Filter>(f), std::forward<Where>(w)); } - virtual void get(const InfoHash& key, GetCallbackSimple cb, DoneCallback donecb={}, Value::Filter&& f={}) { - get(key, bindGetCb(cb), donecb, std::forward<Value::Filter>(f)); + virtual void get(const InfoHash& key, GetCallbackSimple cb, DoneCallback donecb={}, Value::Filter&& f={}, Where&& w = {}) { + get(key, bindGetCb(cb), donecb, std::forward<Value::Filter>(f), std::forward<Where>(w)); } - virtual void get(const InfoHash& key, GetCallbackSimple cb, DoneCallbackSimple donecb, Value::Filter&& f={}) { - get(key, bindGetCb(cb), bindDoneCb(donecb), std::forward<Value::Filter>(f)); + virtual void get(const InfoHash& key, GetCallbackSimple cb, DoneCallbackSimple donecb, Value::Filter&& f={}, Where&& w = {}) { + get(key, bindGetCb(cb), bindDoneCb(donecb), std::forward<Value::Filter>(f), std::forward<Where>(w)); + } + /** + * Similar to Dht::get, but sends a Query to filter data remotely. + * @param key the key for which to query data for. + * @param cb a function called when new values are found on the network. + * It should return false to stop the operation. + * @param donecb a function called when the operation is complete. + cb and donecb won't be called again afterward. + * @param q a query used to filter values on the remotes before they send a + * response. + */ + virtual void query(const InfoHash& key, QueryCallback cb, DoneCallback done_cb = {}, Query&& q = {}); + virtual void query(const InfoHash& key, QueryCallback cb, DoneCallbackSimple done_cb = {}, Query&& q = {}) { + query(key, cb, bindDoneCb(done_cb), std::forward<Query>(q)); } /** @@ -158,14 +172,10 @@ public: std::shared_ptr<Value> getLocalById(const InfoHash& key, Value::Id vid) const; /** - * Announce a value on all available protocols (IPv4, IPv6), and - * automatically re-announce when it's about to expire. + * Announce a value on all available protocols (IPv4, IPv6). + * * The operation will start as soon as the node is connected to the network. * The done callback will be called once, when the first announce succeeds, or fails. - * - * A "put" operation will never end by itself because the value will need to be - * reannounced on a regular basis. - * User can call #cancelPut(InfoHash, Value::Id) to cancel a put operation. */ void put(const InfoHash& key, std::shared_ptr<Value>, @@ -221,9 +231,9 @@ public: * * @return a token to cancel the listener later. */ - virtual size_t listen(const InfoHash&, GetCallback, Value::Filter&&={}); - virtual size_t listen(const InfoHash& key, GetCallbackSimple cb, Value::Filter f={}) { - return listen(key, bindGetCb(cb), std::forward<Value::Filter>(f)); + virtual size_t listen(const InfoHash&, GetCallback, Value::Filter&&={}, Where&& w = {}); + virtual size_t listen(const InfoHash& key, GetCallbackSimple cb, Value::Filter f={}, Where w = {}) { + return listen(key, bindGetCb(cb), std::forward<Value::Filter>(f), std::forward<Where>(w)); } virtual bool cancelListen(const InfoHash&, size_t token); @@ -317,6 +327,8 @@ private: struct Get { time_point start; Value::Filter filter; + std::shared_ptr<Query> query; + QueryCallback query_cb; GetCallback get_cb; DoneCallback done_cb; }; @@ -335,6 +347,7 @@ private: * A single "listen" operation data */ struct LocalListener { + std::shared_ptr<Query> query; Value::Filter filter; GetCallback get_cb; }; @@ -359,8 +372,9 @@ private: struct Listener { size_t rid {}; time_point time {}; + Query query {}; - /*constexpr*/ Listener(size_t rid, time_point t) : rid(rid), time(t) {} + /*constexpr*/ Listener(size_t rid, time_point t, Query&& q) : rid(rid), time(t), query(q) {} void refresh(size_t tid, time_point t) { rid = tid; @@ -427,7 +441,7 @@ private: decltype(Dht::store)::iterator findStorage(const InfoHash& id); decltype(Dht::store)::const_iterator findStorage(const InfoHash& id) const; - void storageAddListener(const InfoHash& id, const std::shared_ptr<Node>& node, size_t tid); + void storageAddListener(const InfoHash& id, const std::shared_ptr<Node>& node, size_t tid, Query&& = {}); bool storageStore(const InfoHash& id, const std::shared_ptr<Value>& value, time_point created); void expireStorage(); void storageChanged(Storage& st, ValueStorage&); @@ -472,13 +486,13 @@ private: // Searches /** - * Low-level method that will perform a search on the DHT for the - * specified infohash (id), using the specified IP version (IPv4 or IPv6). - * The values can be filtered by an arbitrary provided filter. + * Low-level method that will perform a search on the DHT for the specified + * infohash (id), using the specified IP version (IPv4 or IPv6). */ - std::shared_ptr<Search> search(const InfoHash& id, sa_family_t af, GetCallback = nullptr, DoneCallback = nullptr, Value::Filter = Value::AllFilter()); + std::shared_ptr<Search> search(const InfoHash& id, sa_family_t af, GetCallback = {}, QueryCallback = {}, DoneCallback = {}, Value::Filter = {}, Query q = {}); + void announce(const InfoHash& id, sa_family_t af, std::shared_ptr<Value> value, DoneCallback callback, time_point created=time_point::max(), bool permanent = false); - size_t listenTo(const InfoHash& id, sa_family_t af, GetCallback cb, Value::Filter f = Value::AllFilter()); + size_t listenTo(const InfoHash& id, sa_family_t af, GetCallback cb, Value::Filter f = Value::AllFilter(), const std::shared_ptr<Query>& q = {}); void bootstrapSearch(Search& sr); Search *findSearch(unsigned short tid, sa_family_t af); @@ -508,11 +522,14 @@ private: NetworkEngine::RequestAnswer onFindNode(std::shared_ptr<Node> node, InfoHash& hash, want_t want); void onFindNodeDone(const Request& status, NetworkEngine::RequestAnswer& a, std::shared_ptr<Search> sr); /* when we receive a "get values" request */ - NetworkEngine::RequestAnswer onGetValues(std::shared_ptr<Node> node, InfoHash& hash, want_t want); - void onGetValuesDone(const Request& status, NetworkEngine::RequestAnswer& a, std::shared_ptr<Search> sr); + NetworkEngine::RequestAnswer onGetValues(std::shared_ptr<Node> node, InfoHash& hash, want_t want, const Query& q); + void onGetValuesDone(const Request& status, NetworkEngine::RequestAnswer& a, std::shared_ptr<Search> sr, + const std::shared_ptr<Query>& orig_query); /* when we receive a listen request */ - NetworkEngine::RequestAnswer onListen(std::shared_ptr<Node> node, InfoHash& hash, Blob& token, size_t rid); - void onListenDone(const Request& status, NetworkEngine::RequestAnswer& a, std::shared_ptr<Search>& sr); + NetworkEngine::RequestAnswer onListen(std::shared_ptr<Node> node, InfoHash& hash, Blob& token, size_t rid, + Query&& query); + void onListenDone(const Request& status, NetworkEngine::RequestAnswer& a, + std::shared_ptr<Search>& sr, const std::shared_ptr<Query>& orig_query); /* when we receive an announce request */ NetworkEngine::RequestAnswer onAnnounce(std::shared_ptr<Node> node, InfoHash& hash, Blob& token, std::vector<std::shared_ptr<Value>> v, time_point created); diff --git a/include/opendht/dhtrunner.h b/include/opendht/dhtrunner.h index 6ee06ed0e224ead41ef776e60c824860e9d27d97..47f970704b37577475c3492c0ab31d3ed3385198 100644 --- a/include/opendht/dhtrunner.h +++ b/include/opendht/dhtrunner.h @@ -54,20 +54,20 @@ public: DhtRunner(); virtual ~DhtRunner(); - void get(InfoHash id, GetCallbackSimple cb, DoneCallback donecb={}, Value::Filter f = Value::AllFilter()) { - get(id, bindGetCb(cb), donecb, f); + void get(InfoHash id, GetCallbackSimple cb, DoneCallback donecb={}, Value::Filter f = Value::AllFilter(), Where w = {}) { + get(id, bindGetCb(cb), donecb, f, w); } - void get(InfoHash id, GetCallbackSimple cb, DoneCallbackSimple donecb={}, Value::Filter f = Value::AllFilter()) { - get(id, bindGetCb(cb), donecb, f); + void get(InfoHash id, GetCallbackSimple cb, DoneCallbackSimple donecb={}, Value::Filter f = Value::AllFilter(), Where w = {}) { + get(id, bindGetCb(cb), donecb, f, w); } - void get(InfoHash hash, GetCallback vcb, DoneCallback dcb, Value::Filter f={}); + void get(InfoHash hash, GetCallback vcb, DoneCallback dcb, Value::Filter f={}, Where w = {}); - void get(InfoHash id, GetCallback cb, DoneCallbackSimple donecb={}, Value::Filter f = Value::AllFilter()) { - get(id, cb, bindDoneCb(donecb), f); + void get(InfoHash id, GetCallback cb, DoneCallbackSimple donecb={}, Value::Filter f = Value::AllFilter(), Where w = {}) { + get(id, cb, bindDoneCb(donecb), f, w); } - void get(const std::string& key, GetCallback vcb, DoneCallbackSimple dcb={}, Value::Filter f = Value::AllFilter()); + void get(const std::string& key, GetCallback vcb, DoneCallbackSimple dcb={}, Value::Filter f = Value::AllFilter(), Where w = {}); template <class T> void get(InfoHash hash, std::function<bool(std::vector<T>&&)> cb, DoneCallbackSimple dcb={}) @@ -96,7 +96,7 @@ public: getFilterSet<T>()); } - std::future<std::vector<std::shared_ptr<dht::Value>>> get(InfoHash key, Value::Filter f = Value::AllFilter()) { + std::future<std::vector<std::shared_ptr<dht::Value>>> get(InfoHash key, Value::Filter f = Value::AllFilter(), Where w = {}) { auto p = std::make_shared<std::promise<std::vector<std::shared_ptr< dht::Value >>>>(); auto values = std::make_shared<std::vector<std::shared_ptr< dht::Value >>>(); get(key, [=](const std::vector<std::shared_ptr<dht::Value>>& vlist) { @@ -105,7 +105,7 @@ public: }, [=](bool) { p->set_value(std::move(*values)); }, - f); + f, w); return p->get_future(); } @@ -122,10 +122,15 @@ public: return p->get_future(); } - std::future<size_t> listen(InfoHash key, GetCallback vcb, Value::Filter f = Value::AllFilter()); - std::future<size_t> listen(const std::string& key, GetCallback vcb, Value::Filter f = Value::AllFilter()); - std::future<size_t> listen(InfoHash key, GetCallbackSimple cb, Value::Filter f = Value::AllFilter()) { - return listen(key, bindGetCb(cb), f); + void query(const InfoHash& hash, QueryCallback cb, DoneCallback done_cb = {}, Query q = {}); + void query(const InfoHash& hash, QueryCallback cb, DoneCallbackSimple done_cb = {}, Query q = {}) { + query(hash, cb, bindDoneCb(done_cb), q); + } + + std::future<size_t> listen(InfoHash key, GetCallback vcb, Value::Filter f = Value::AllFilter(), Where w = {}); + std::future<size_t> listen(const std::string& key, GetCallback vcb, Value::Filter f = Value::AllFilter(), Where w = {}); + std::future<size_t> listen(InfoHash key, GetCallbackSimple cb, Value::Filter f = Value::AllFilter(), Where w = {}) { + return listen(key, bindGetCb(cb), f, w); } template <class T> @@ -137,7 +142,7 @@ public: getFilterSet<T>()); } template <typename T> - std::future<size_t> listen(InfoHash hash, std::function<bool(T&&)> cb, Value::Filter f = Value::AllFilter()) + std::future<size_t> listen(InfoHash hash, std::function<bool(T&&)> cb, Value::Filter f = Value::AllFilter(), Where w = {}) { return listen(hash, [=](const std::vector<std::shared_ptr<Value>>& vals) { for (const auto& v : vals) { @@ -150,7 +155,7 @@ public: } return true; }, - getFilterSet<T>(f)); + getFilterSet<T>(f), w); } void cancelListen(InfoHash h, size_t token); diff --git a/include/opendht/network_engine.h b/include/opendht/network_engine.h index 3d336c354e3e04d76ce9e03d33d0e352eba16bf1..04afdecf392739656c67571f895935a2aa7bf66b 100644 --- a/include/opendht/network_engine.h +++ b/include/opendht/network_engine.h @@ -153,6 +153,7 @@ public: Blob ntoken {}; Value::Id vid {}; std::vector<std::shared_ptr<Value>> values {}; + std::vector<std::shared_ptr<FieldValueIndex>> fields {}; std::vector<std::shared_ptr<Node>> nodes4 {}; std::vector<std::shared_ptr<Node>> nodes6 {}; RequestAnswer() {} @@ -223,7 +224,8 @@ private: */ std::function<RequestAnswer(std::shared_ptr<Node>, InfoHash&, - want_t)> onGetValues {}; + want_t, + Query)> onGetValues {}; /** * @brief on listen request callback. * @@ -235,7 +237,8 @@ private: std::function<RequestAnswer(std::shared_ptr<Node>, InfoHash&, Blob&, - uint16_t)> onListen {}; + uint16_t, + Query)> onListen {}; /** * @brief on announce request callback. * @@ -290,9 +293,9 @@ public: * @param nodes6 The ipv6 closest nodes. * @param values The values to send. */ - void tellListener(std::shared_ptr<Node> n, uint16_t rid, InfoHash hash, want_t want, Blob ntoken, - std::vector<std::shared_ptr<Node>> nodes, std::vector<std::shared_ptr<Node>> nodes6, - std::vector<std::shared_ptr<Value>> values); + void tellListener(std::shared_ptr<Node> n, uint16_t rid, const InfoHash& hash, want_t want, const Blob& ntoken, + std::vector<std::shared_ptr<Node>>&& nodes, std::vector<std::shared_ptr<Node>>&& nodes6, + std::vector<std::shared_ptr<Value>>&& values, const Query& q); bool isRunning(sa_family_t af) const; inline want_t want () const { return dht_socket >= 0 && dht_socket6 >= 0 ? (WANT4 | WANT6) : -1; } @@ -314,13 +317,15 @@ public: RequestExpiredCb on_expired); std::shared_ptr<Request> sendGetValues(std::shared_ptr<Node> n, - const InfoHash& target, + const InfoHash& info_hash, + const Query& query, want_t want, RequestCb on_done, RequestExpiredCb on_expired); std::shared_ptr<Request> sendListen(std::shared_ptr<Node> n, const InfoHash& infohash, + const Query& query, const Blob& token, RequestCb on_done, RequestExpiredCb on_expired); @@ -432,6 +437,7 @@ private: const Blob& nodes, const Blob& nodes6, const std::vector<std::shared_ptr<Value>>& st, + const Query& query, const Blob& token); Blob bufferNodes(sa_family_t af, const InfoHash& id, std::vector<std::shared_ptr<Node>>& nodes); @@ -452,7 +458,7 @@ private: const std::string& message, bool include_id=false); - void deserializeNodesValues(ParsedMessage& msg); + void deserializeNodes(ParsedMessage& msg); std::queue<time_point> rate_limit_time {}; static std::mt19937 rd_device; diff --git a/include/opendht/securedht.h b/include/opendht/securedht.h index d5b5e6d040fd8c6f9d545cf067f761d8aa716bc0..11b92fb9d196354139ae19703a664b78eb7d5dac 100644 --- a/include/opendht/securedht.h +++ b/include/opendht/securedht.h @@ -85,18 +85,18 @@ public: * If the signature can't be checked, or if the data can't be decrypted, it is not returned. * Public, non-signed & non-encrypted data is retransmitted as-is. */ - virtual void get(const InfoHash& id, GetCallback cb, DoneCallback donecb={}, Value::Filter&& = {}) override; - virtual void get(const InfoHash& id, GetCallback cb, DoneCallbackSimple donecb={}, Value::Filter&& f = {}) override { - get(id, cb, bindDoneCb(donecb), std::forward<Value::Filter>(f)); + virtual void get(const InfoHash& id, GetCallback cb, DoneCallback donecb={}, Value::Filter&& = {}, Where&& w = {}) override; + virtual void get(const InfoHash& id, GetCallback cb, DoneCallbackSimple donecb={}, Value::Filter&& f = {}, Where&& w = {}) override { + get(id, cb, bindDoneCb(donecb), std::forward<Value::Filter>(f), std::forward<Where>(w)); } - virtual void get(const InfoHash& key, GetCallbackSimple cb, DoneCallback donecb={}, Value::Filter&& f={}) override { - get(key, bindGetCb(cb), donecb, std::forward<Value::Filter>(f)); + virtual void get(const InfoHash& key, GetCallbackSimple cb, DoneCallback donecb={}, Value::Filter&& f={}, Where&& w = {}) override { + get(key, bindGetCb(cb), donecb, std::forward<Value::Filter>(f), std::forward<Where>(w)); } - virtual void get(const InfoHash& key, GetCallbackSimple cb, DoneCallbackSimple donecb, Value::Filter&& f={}) override { - get(key, bindGetCb(cb), bindDoneCb(donecb), std::forward<Value::Filter>(f)); + virtual void get(const InfoHash& key, GetCallbackSimple cb, DoneCallbackSimple donecb, Value::Filter&& f={}, Where&& w = {}) override { + get(key, bindGetCb(cb), bindDoneCb(donecb), std::forward<Value::Filter>(f), std::forward<Where>(w)); } - virtual size_t listen(const InfoHash& id, GetCallback cb, Value::Filter&& = {}) override; + virtual size_t listen(const InfoHash& id, GetCallback cb, Value::Filter&& = {}, Where&& w = {}) override; /** * Will take ownership of the value, sign it using our private key and put it in the DHT. diff --git a/include/opendht/utils.h b/include/opendht/utils.h index 0b892fd1539e2895310573f119b61406ca4fb00a..9dcb1660be57ba763beedfa52b879935ec8267d6 100644 --- a/include/opendht/utils.h +++ b/include/opendht/utils.h @@ -168,4 +168,6 @@ unpackMsg(Blob b) { msgpack::unpacked unpackMsg(Blob b); +msgpack::object* findMapValue(msgpack::object& map, const std::string& key); + } // namespace dht diff --git a/include/opendht/value.h b/include/opendht/value.h index 58351e70ced099808cdd4b55ba5906fe11e4ebf7..5c7c2ce90ed5beceb61840d84b402a7bc70aead9 100644 --- a/include/opendht/value.h +++ b/include/opendht/value.h @@ -35,10 +35,12 @@ #include <functional> #include <memory> #include <chrono> +#include <set> namespace dht { struct Value; +struct Query; /** * A storage policy is applied once to every incoming value storage requests. @@ -114,6 +116,14 @@ struct ValueType { */ struct Value { + enum class Field { + None = 0, + Id, + ValueType, + OwnerPk, + UserType, + }; + typedef uint64_t Id; static const Id INVALID_ID {0}; @@ -135,14 +145,18 @@ struct Value return f1(v) and f2(v); }; } - static Filter chain(std::initializer_list<Filter> l) { - const std::vector<Filter> list(l.begin(), l.end()); - return [list](const Value& v){ - for (const auto& f : list) + template <typename T> + static Filter chainAll(T&& set) { + using namespace std::placeholders; + return std::bind([](const Value& v, T& s) { + for (const auto& f : s) if (f and not f(v)) return false; return true; - }; + }, _1, std::move(set)); + } + static Filter chain(std::initializer_list<Filter> l) { + return chainAll(std::move(l)); } static Filter chainOr(Filter&& f1, Filter&& f2) { if (not f1 or not f2) return AllFilter(); @@ -162,6 +176,11 @@ struct Value return v.type == tid; }; } + static Filter TypeFilter(const ValueType::Id& tid) { + return [tid](const Value& v) { + return v.type == tid; + }; + } static Filter IdFilter(const Id id) { return [id](const Value& v) { @@ -175,6 +194,23 @@ struct Value }; } + static Filter ownerFilter(const crypto::PublicKey& pk) { + return ownerFilter(pk.getId()); + } + + static Filter ownerFilter(const InfoHash& pkh) { + return [pkh](const Value& v) { + return v.owner and v.owner->getId() == pkh; + }; + } + + static Filter userTypeFilter(const std::string& ut) + { + return [ut](const Value& v) { + return v.user_type == ut; + }; + } + class SerializableBase { public: @@ -415,6 +451,31 @@ struct Value pk.pack(std::string("dat")); msgpack_pack_to_encrypt(pk); } + template <typename Packer> + void msgpack_pack_fields(const std::set<Value::Field>& fields, Packer& pk) const + { + for (const auto& field : fields) + switch (field) { + case Value::Field::Id: + pk.pack(id); + break; + case Value::Field::ValueType: + pk.pack(type); + break; + case Value::Field::OwnerPk: + if (owner) + owner->msgpack_pack(pk); + else + InfoHash().msgpack_pack(pk); + break; + case Value::Field::UserType: + pk.pack(user_type); + break; + default: + break; + } + } + void msgpack_unpack(msgpack::object o); void msgpack_unpack_body(const msgpack::object& o); Blob getPacked() const { @@ -424,6 +485,8 @@ struct Value return {buffer.data(), buffer.data()+buffer.size()}; } + void msgpack_unpack_fields(const std::set<Value::Field>& fields, const msgpack::object& o, unsigned offset); + Id id {INVALID_ID}; /** @@ -467,6 +530,338 @@ struct Value using ValuesExport = std::pair<InfoHash, Blob>; +/** + * @class FieldValue + * @brief Describes a value filter. + * @details + * This structure holds the value for a specified field. It's type can either be + * uint64_t, InfoHash or Blob. + */ +struct FieldValue +{ + FieldValue() {} + FieldValue(Value::Field f, uint64_t int_value) : field(f), intValue(int_value) {} + FieldValue(Value::Field f, InfoHash hash_value) : field(f), hashValue(hash_value) {} + FieldValue(Value::Field f, Blob blob_value) : field(f), blobValue(blob_value) {} + + bool operator==(const FieldValue& fd) const; + + // accessors + Value::Field getField() const { return field; } + uint64_t getInt() const { return intValue; } + InfoHash getHash() const { return hashValue; } + Blob getBlob() const { return blobValue; } + + template <typename Packer> + void msgpack_pack(Packer& p) const { + p.pack_map(2); + p.pack(std::string("f")); p.pack(static_cast<uint8_t>(field)); + + p.pack(std::string("v")); + switch (field) { + case Value::Field::Id: + case Value::Field::ValueType: + p.pack(intValue); + break; + case Value::Field::OwnerPk: + p.pack(hashValue); + break; + case Value::Field::UserType: + p.pack_bin(blobValue.size()); + p.pack_bin_body((const char*)blobValue.data(), blobValue.size()); + break; + default: + throw msgpack::type_error(); + } + } + + void msgpack_unpack(msgpack::object msg) { + hashValue = {}; + blobValue.clear(); + + if (auto f = findMapValue(msg, "f")) + field = (Value::Field)f->as<unsigned>(); + else + throw msgpack::type_error(); + + auto v = findMapValue(msg, "v"); + if (not v) + throw msgpack::type_error(); + else + switch (field) { + case Value::Field::Id: + case Value::Field::ValueType: + intValue = v->as<decltype(intValue)>(); + break; + case Value::Field::OwnerPk: + hashValue = v->as<decltype(hashValue)>(); + break; + case Value::Field::UserType: + blobValue = unpackBlob(*v); + break; + default: + throw msgpack::type_error(); + } + } + + Value::Filter getLocalFilter() const; + +private: + Value::Field field {Value::Field::None}; + // three possible value types + uint64_t intValue {}; + InfoHash hashValue {}; + Blob blobValue {}; +}; + + +/** + * @struct FieldSelectorDescription + * @brief Describes a selection. + * @details + * This is meant to narrow data to a set of specified fields. This structure is + * used to construct a Select structure. + */ +struct FieldSelectorDescription +{ + FieldSelectorDescription() {} + FieldSelectorDescription(Value::Field f) : field(f) {} + + Value::Field getField() const { return field; } + + bool operator==(const FieldSelectorDescription& fd) const { return field == fd.field; } + + template <typename Packer> + void msgpack_pack(Packer& p) const { p.pack(static_cast<uint8_t>(field)); } + void msgpack_unpack(msgpack::object msg) { field = static_cast<Value::Field>(msg.as<int>()); } +private: + Value::Field field {Value::Field::None}; +}; + +/** + * @class Select + * @brief Serializable Value field selection. + * @details + * This is a container for a list of FieldSelectorDescription instances. It + * describes a complete SELECT query for dht::Value. + */ +struct Select +{ + Select() { } + Select(const std::string& q_str); + + bool isSatisfiedBy(const Select& os) const; + + /** + * Selects a field of type Value::Field. + * + * @param field the field to require. + * + * @return the resulting Select instance. + */ + Select& field(Value::Field field) { + fieldSelection_.emplace_back(field); + return *this; + } + + /** + * Computes the set of selected fields based on previous require* calls. + * + * @return the set of fields. + */ + std::set<Value::Field> getSelection() const { + std::set<Value::Field> fields {}; + for (const auto& f : fieldSelection_) { + fields.insert(f.getField()); + } + return fields; + } + + template <typename Packer> + void msgpack_pack(Packer& pk) const { pk.pack(fieldSelection_); } + void msgpack_unpack(const msgpack::object& o) { + fieldSelection_.clear(); + fieldSelection_ = o.as<decltype(fieldSelection_)>(); + } + + friend std::ostream& operator<<(std::ostream& s, const dht::Select& q); +private: + std::vector<FieldSelectorDescription> fieldSelection_ {}; +}; + +/** + * @class Where + * @brief Serializable dht::Value filter. + * @details + * This is container for a list of FieldValue instances. It describes a + * complete WHERE query for dht::Value. + */ +struct Where +{ + Where() { } + Where(const std::string& q_str); + + bool isSatisfiedBy(const Where& where) const; + + /** + * Adds restriction on Value::Id based on the id argument. + * + * @param id the id. + * + * @return the resulting Where instance. + */ + Where& id(Value::Id id) { + filters_.emplace_back(Value::Field::Id, id); + return *this; + } + + /** + * Adds restriction on Value::ValueType based on the type argument. + * + * @param type the value type. + * + * @return the resulting Where instance. + */ + Where& valueType(ValueType::Id type) { + filters_.emplace_back(Value::Field::ValueType, type); + return *this; + } + + /** + * Adds restriction on Value::OwnerPk based on the owner_pk_hash argument. + * + * @param owner_pk_hash the owner public key fingerprint. + * + * @return the resulting Where instance. + */ + Where& owner(InfoHash owner_pk_hash) { + filters_.emplace_back(Value::Field::OwnerPk, owner_pk_hash); + return *this; + } + + /** + * Adds restriction on Value::UserType based on the user_type argument. + * + * @param user_type the user type. + * + * @return the resulting Where instance. + */ + Where& userType(std::string user_type) { + filters_.emplace_back(Value::Field::UserType, Blob {user_type.begin(), user_type.end()}); + return *this; + } + + /** + * Computes the Value::Filter based on the list of field value set. + * + * @return the resulting Value::Filter. + */ + Value::Filter getFilter() const { + std::vector<Value::Filter> fset(filters_.size()); + std::transform(filters_.begin(), filters_.end(), fset.begin(), [](const FieldValue& f) { + return f.getLocalFilter(); + }); + return Value::Filter::chainAll(std::move(fset)); + } + + template <typename Packer> + void msgpack_pack(Packer& pk) const { pk.pack(filters_); } + void msgpack_unpack(const msgpack::object& o) { + filters_.clear(); + filters_ = o.as<decltype(filters_)>(); + } + + friend std::ostream& operator<<(std::ostream& s, const dht::Where& q); + +private: + std::vector<FieldValue> filters_; +}; + +/** + * @class Query + * @brief Describes a query destined to another peer. + * @details + * This class describes the list of filters on field values and the field + * itselves to include in the peer response to a GET operation. See + * FieldValue. + */ +struct Query +{ + static const std::string QUERY_PARSE_ERROR; + + Query(Select s = {}, Where w = {}) : select(s), where(w) { }; + + /** + * Initializes a query based on a SQL-ish formatted string. The abstract + * form of such a string is the following: + * + * [SELECT <$field$> [WHERE <$field$=$value$>]] + * + * where + * + * - $field$ = *|id|value_type|owner_pk|user_type + * - $value$ = $string$|$integer$ + * - $string$: a simple string WITHOUT SPACES. + * - $integer$: a simple integer. + */ + Query(std::string q_str) { + auto pos_W = q_str.find("WHERE"); + auto pos_w = q_str.find("where"); + auto pos = std::min(pos_W != std::string::npos ? pos_W : q_str.size(), + pos_w != std::string::npos ? pos_w : q_str.size()); + select = q_str.substr(0, pos); + where = q_str.substr(pos, q_str.size()-pos); + } + + /** + * Tell if the query is satisfied by another query. + */ + bool isSatisfiedBy(const Query& q) const; + + template <typename Packer> + void msgpack_pack(Packer& pk) const { + pk.pack_map(2); + pk.pack(std::string("s")); pk.pack(select); /* packing field selectors */ + pk.pack(std::string("w")); pk.pack(where); /* packing filters */ + } + + void msgpack_unpack(const msgpack::object& o); + + friend std::ostream& operator<<(std::ostream& s, const dht::Query& q) { + s << "Query[" << q.select << " " << q.where << "]"; + } + + Select select {}; + Where where {}; +}; + +/*! + * @class FieldValueIndex + * @brief An index for field values. + * @details + * This structures is meant to manipulate a subset of fields normally contained + * in Value. + */ +struct FieldValueIndex { + FieldValueIndex() {} + FieldValueIndex(const Value& v, Select s = {}); + /** + * Tells if all the fields of this are contained in the other + * FieldValueIndex with the same value. + * + * @param other The other FieldValueIndex instance. + */ + bool containedIn(const FieldValueIndex& other) const; + + friend std::ostream& operator<<(std::ostream& os, const FieldValueIndex& fvi); + + void msgpack_unpack_fields(const std::set<Value::Field>& fields, + const msgpack::object& o, + unsigned offset); + + std::map<Value::Field, FieldValue> index {}; +}; + template <typename T, typename std::enable_if<std::is_base_of<Value::SerializableBase, T>::value, T>::type* = nullptr> Value::Filter diff --git a/src/dht.cpp b/src/dht.cpp index 590bea804180f610b54b64d50782e3af0490963a..9ebc4f81c1a06f295ce7bda487b35f1c10cf94bd 100644 --- a/src/dht.cpp +++ b/src/dht.cpp @@ -176,6 +176,7 @@ struct Dht::SearchNode { SearchNode(std::shared_ptr<Node> node) : node(node) {} using AnnounceStatusMap = std::map<Value::Id, std::shared_ptr<Request>>; + using SyncStatusMap = std::map<std::shared_ptr<Query>, std::shared_ptr<Request>>; /** * Can we use this node to listen/announce now ? @@ -189,10 +190,31 @@ struct Dht::SearchNode { * Could a "get" request be sent to this node now ? * update: time of the last "get" op for the search. */ - bool canGet(time_point now, time_point update) const { + bool canGet(time_point now, time_point update, std::shared_ptr<Query> q = {}) const { + const auto& get_status = q ? getStatus.find(q) : getStatus.end(); + const auto& sq_status = std::find_if(getStatus.begin(), getStatus.end(), + [q](const SyncStatusMap::value_type& s) { + return q and s.first and q->isSatisfiedBy(*s.first); + } + ); return not node->isExpired() and (now > last_get_reply + Node::NODE_EXPIRE_TIME or update > last_get_reply) - and (not getStatus or not getStatus->pending()); + and ((get_status == getStatus.end() or not get_status->second or not get_status->second->pending()) and + (sq_status == getStatus.end() or not sq_status->second or not sq_status->second->pending())); + } + + bool expired(const SyncStatusMap& status) const { + return std::find_if(status.begin(), status.end(), + [](const SyncStatusMap::value_type& r){ + return r.second and r.second->expired(); + }) != status.end(); + } + + bool pending(const SyncStatusMap& status) const { + return std::find_if(status.begin(), status.end(), + [](const SyncStatusMap::value_type& r){ + return r.second and r.second->pending(); + }) != status.end(); } bool isAnnounced(Value::Id vid, const ValueType& type, time_point now) const { @@ -202,11 +224,27 @@ struct Dht::SearchNode { } return ack->second->reply_time + type.expiration > now; } + bool isListening(time_point now) const { - if (not listenStatus) + auto ls = listenStatus.begin(); + for ( ; ls != listenStatus.end() ; ++ls) { + if (isListening(now, ls)) { + break; + } + } + return ls != listenStatus.end(); + } + bool isListening(time_point now, const std::shared_ptr<Query>& q) const { + const auto& ls = listenStatus.find(q); + if (ls == listenStatus.end()) return false; - - return listenStatus->reply_time + LISTEN_EXPIRE_TIME > now; + else + return isListening(now, ls); + } + bool isListening(time_point now, SyncStatusMap::const_iterator listen_status) const { + if (listen_status == listenStatus.end()) + return false; + return listen_status->second->reply_time + LISTEN_EXPIRE_TIME > now; } /** @@ -226,10 +264,20 @@ struct Dht::SearchNode { * Assumng the node is synced, should a "listen" request be sent to this node now ? */ time_point getListenTime() const { - if (not listenStatus) + time_point t {time_point::max()}; + for (auto ls = listenStatus.begin(); ls != listenStatus.end() ; ++ls) { + t = std::min(t, getListenTime(ls)); + } + return t; + } + time_point getListenTime(const std::shared_ptr<Query>& q) const { + return getListenTime(listenStatus.find(q)); + } + time_point getListenTime(SyncStatusMap::const_iterator listen_status) const { + if (listen_status == listenStatus.end()) return time_point::min(); - - return listenStatus->pending() ? time_point::max() : listenStatus->reply_time + LISTEN_EXPIRE_TIME - REANNOUNCE_MARGIN; + return listen_status->second->pending() ? time_point::max() : + listen_status->second->reply_time + LISTEN_EXPIRE_TIME - REANNOUNCE_MARGIN; } /** @@ -241,10 +289,10 @@ struct Dht::SearchNode { std::shared_ptr<Node> node {}; - time_point last_get_reply {time_point::min()}; /* last time received valid token */ - std::shared_ptr<Request> getStatus {}; /* get/sync status */ - std::shared_ptr<Request> listenStatus {}; - AnnounceStatusMap acked {}; /* announcement status for a given value id */ + time_point last_get_reply {time_point::min()}; /* last time received valid token */ + SyncStatusMap getStatus {}; /* get/sync status */ + SyncStatusMap listenStatus {}; /* listen status */ + AnnounceStatusMap acked {}; /* announcement status for a given value id */ Blob token {}; @@ -272,7 +320,7 @@ struct Dht::Search { std::vector<Announce> announce {}; /* pending gets */ - std::vector<Get> callbacks {}; + std::multimap<time_point, Get> callbacks {}; /* listeners */ std::map<size_t, LocalListener> listeners {}; @@ -295,7 +343,7 @@ struct Dht::Search { unsigned currentGetRequests() const { unsigned count = 0; for (const auto& n : nodes) - if (not n.isBad() and n.getStatus and n.getStatus->pending()) + if (not n.isBad() and n.pending(n.getStatus)) count++; return count; } @@ -724,48 +772,64 @@ Dht::searchSendGetValues(std::shared_ptr<Search> sr, SearchNode* pn, bool update const auto& now = scheduler.time(); const time_point up = update ? sr->getLastGetTime() : time_point::min(); + SearchNode* n = nullptr; - if (pn) { - if (not pn->canGet(now, up)) - return nullptr; - n = pn; - } else { - for (auto& sn : sr->nodes) { - if (sn.canGet(now, up)) { - n = &sn; - break; + auto cb = sr->callbacks.begin(); + do { /* for all queries to send */ + + /* cases v 'get' v 'find_node' */ + auto query = cb != sr->callbacks.end() ? cb->second.query : std::make_shared<Query>(); + if (pn) { + if (not pn->canGet(now, up, query)) + return nullptr; + n = pn; + } else { + for (auto& sn : sr->nodes) { + if (sn.canGet(now, up, query)) { + n = &sn; + break; + } } + if (not n) + return nullptr; } - if (not n) - return nullptr; - } - /*DHT_LOG.DEBUG("[search %s IPv%c] [node %s] sending 'get'", - sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6', - n->node->toString().c_str());*/ + std::weak_ptr<Search> ws = sr; + auto onDone = + [this,ws,query](const Request& status, NetworkEngine::RequestAnswer&& answer) mutable { + if (auto sr = ws.lock()) { + sr->insertNode(status.node, scheduler.time(), answer.ntoken); + onGetValuesDone(status, answer, sr, query); + } + }; + auto onExpired = + [this,ws,query](const Request& status, bool over) mutable { + if (auto sr = ws.lock()) { + if (auto srn = sr->getNode(status.node)) { + srn->candidate = not over; + if (over) + srn->getStatus.erase(query); + } + scheduler.edit(sr->nextSearchStep, scheduler.time()); + } + }; + std::shared_ptr<Request> rstatus; + if (sr->callbacks.empty() and sr->listeners.empty()) { + DHT_LOG.WARN("[search %s IPv%c] [node %s] sending 'find_node'", + sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6', + n->node->toString().c_str()); + rstatus = network_engine.sendFindNode(n->node, sr->id, -1, onDone, onExpired); + } else { + DHT_LOG.WARN("[search %s IPv%c] [node %s] sending 'get'", + sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6', + n->node->toString().c_str()); + rstatus = network_engine.sendGetValues(n->node, sr->id, query ? *query : Query {}, -1, onDone, onExpired); + } + n->getStatus[query] = rstatus; - std::weak_ptr<Search> ws = sr; - auto onDone = - [this,ws](const Request& status, NetworkEngine::RequestAnswer&& answer) mutable { - if (auto sr = ws.lock()) { - sr->insertNode(status.node, scheduler.time(), answer.ntoken); - onGetValuesDone(status, answer, sr); - } - }; - auto onExpired = - [this,ws](const Request& status, bool over) mutable { - if (auto sr = ws.lock()) { - if (auto srn = sr->getNode(status.node)) - srn->candidate = not over; - scheduler.edit(sr->nextSearchStep, scheduler.time()); - } - }; - std::shared_ptr<Request> rstatus; - if (sr->callbacks.empty() and sr->listeners.empty()) - rstatus = network_engine.sendFindNode(n->node, sr->id, -1, onDone, onExpired); - else - rstatus = network_engine.sendGetValues(n->node, sr->id, -1, onDone, onExpired); - n->getStatus = rstatus; + if (not sr->isSynced(now) or cb == sr->callbacks.end()) + break; /* only trying to find nodes, only send the oldest query */ + } while (++cb != sr->callbacks.end()); return n; } @@ -777,13 +841,15 @@ Dht::searchStep(std::shared_ptr<Search> sr) if (not sr or sr->expired or sr->done) return; const auto& now = scheduler.time(); - DHT_LOG.DEBUG("[search %s IPv%c] step (%d requests)", sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6', sr->currentGetRequests()); + DHT_LOG.DEBUG("[search %s IPv%c] step (%d requests)", + sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6', sr->currentGetRequests()); sr->step_time = now; if (sr->refill_time + Node::NODE_EXPIRE_TIME < now and sr->nodes.size()-sr->getNumberOfBadNodes() < SEARCH_NODES) { if (auto added = sr->refill(sr->af == AF_INET ? buckets : buckets6, now)) { sr->refill_time = now; - DHT_LOG.DEBUG("[search %s IPv%c] refilled with %u nodes", sr->id.toString().c_str(), (sr->af == AF_INET) ? '4' : '6', added); + DHT_LOG.DEBUG("[search %s IPv%c] refilled with %u nodes", + sr->id.toString().c_str(), (sr->af == AF_INET) ? '4' : '6', added); } } @@ -793,9 +859,11 @@ Dht::searchStep(std::shared_ptr<Search> sr) // search is synced but some (newer) get operations are not complete // Call callbacks when done for (auto b = sr->callbacks.begin(); b != sr->callbacks.end();) { - if (sr->isDone(*b, now)) { - if (b->done_cb) - b->done_cb(true, sr->getNodes()); + if (sr->isDone(b->second, now)) { + if (b->second.done_cb) + b->second.done_cb(true, sr->getNodes()); + for (auto& n : sr->nodes) + n.getStatus.erase(b->second.query); b = sr->callbacks.erase(b); } else @@ -808,43 +876,50 @@ Dht::searchStep(std::shared_ptr<Search> sr) // true if this node is part of the target nodes cluter. bool in = sr->id.xorCmp(myid, sr->nodes.back().node->id) < 0; - DHT_LOG.DEBUG("[search %s IPv%c] synced%s", sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6', in ? ", in" : ""); + DHT_LOG.DEBUG("[search %s IPv%c] synced%s", + sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6', in ? ", in" : ""); if (not sr->listeners.empty()) { unsigned i = 0; for (auto& n : sr->nodes) { if (not n.isSynced(now)) continue; - if (n.getListenTime() <= now) { - DHT_LOG.WARN("[search %s IPv%c] [node %s] sending 'listen'", - sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6', - n.node->toString().c_str()); - //std::cout << "Sending listen to " << n.node->id << " " << print_addr(n.node->ss, n.node->sslen) << std::endl; - - //network_engine.cancelRequest(n.listenStatus); - auto ls = n.listenStatus; - - std::weak_ptr<Search> ws = sr; - n.listenStatus = network_engine.sendListen(n.node, sr->id, n.token, - [this,ws,ls](const Request& status, - NetworkEngine::RequestAnswer&& answer) mutable - { /* on done */ - // cancel previous request - network_engine.cancelRequest(ls); - if (auto sr = ws.lock()) { - onListenDone(status, answer, sr); - searchStep(sr); + for (const auto& l : sr->listeners) { + const auto& query = l.second.query; + if (n.getListenTime(query) <= now) { + DHT_LOG.WARN("[search %s IPv%c] [node %s] sending 'listen'", + sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6', + n.node->toString().c_str()); + //std::cout << "Sending listen to " << n.node->id << " " << print_addr(n.node->ss, n.node->sslen) << std::endl; + + const auto& r = n.listenStatus.find(query); + auto last_req = r != n.listenStatus.end() ? r->second : std::shared_ptr<Request> {}; + + std::weak_ptr<Search> ws = sr; + n.listenStatus[query] = network_engine.sendListen(n.node, sr->id, *query, n.token, + [this,ws,last_req,query](const Request& req, + NetworkEngine::RequestAnswer&& answer) mutable + { /* on done */ + network_engine.cancelRequest(last_req); + if (auto sr = ws.lock()) { + onListenDone(req, answer, sr, query); + searchStep(sr); + if (auto sn = sr->getNode(req.node)) + sn->listenStatus.erase(query); + } + }, + [this,ws,last_req,query](const Request& req, bool over) mutable + { /* on expired */ + network_engine.cancelRequest(last_req); + if (auto sr = ws.lock()) { + searchStep(sr); + if (over) + if (auto sn = sr->getNode(req.node)) + sn->listenStatus.erase(query); + } } - }, - [this,ws,ls](const Request&, bool over) mutable - { /* on expired */ - if (over) { - network_engine.cancelRequest(ls); - if (auto sr = ws.lock()) - scheduler.edit(sr->nextSearchStep, scheduler.time()); - } - } - ); + ); + } } if (not n.candidate and ++i == LISTEN_NODES) break; @@ -928,8 +1003,8 @@ Dht::searchStep(std::shared_ptr<Search> sr) { auto get_cbs = std::move(sr->callbacks); for (const auto& g : get_cbs) { - if (g.done_cb) - g.done_cb(false, {}); + if (g.second.done_cb) + g.second.done_cb(false, {}); } } { @@ -994,7 +1069,7 @@ Dht::Search::getLastGetTime() const { time_point last = time_point::min(); for (const auto& g : callbacks) - last = std::max(last, g.start); + last = std::max(last, g.second.start); return last; } @@ -1004,10 +1079,11 @@ Dht::Search::isDone(const Get& get, time_point now) const unsigned i = 0; const auto limit = std::max(get.start, now - Node::NODE_EXPIRE_TIME); for (const auto& sn : nodes) { + const auto& gs = sn.getStatus.find(get.query); if (sn.isBad()) continue; - if (sn.last_get_reply < limit) - return false; + if (gs == sn.getStatus.end() or not gs->second or gs->second->reply_time < limit) + return false; if (++i == TARGET_NODES) break; } @@ -1024,7 +1100,7 @@ Dht::Search::getUpdateTime(time_point now) const for (const auto& sn : nodes) { if (sn.node->isExpired() or (sn.candidate and t >= TARGET_NODES)) continue; - bool pending = sn.getStatus and sn.getStatus->pending(); + auto pending = sn.pending(sn.getStatus); if (sn.last_get_reply < std::max(now - Node::NODE_EXPIRE_TIME, last_get) or pending) { // not isSynced if (not pending and reqs < SEARCH_REQUESTS) @@ -1072,7 +1148,12 @@ Dht::Search::isListening(time_point now) const for (const auto& n : nodes) { if (n.isBad()) continue; - if (!n.isListening(now)) + SearchNode::SyncStatusMap::const_iterator ls {}; + for (ls = n.listenStatus.begin(); ls != n.listenStatus.end() ; ++ls) { + if (n.isListening(now, ls)) + break; + } + if (ls == n.listenStatus.end()) return false; if (++i == LISTEN_NODES) break; @@ -1203,12 +1284,12 @@ Dht::Search::refill(const RoutingTable& r, time_point now) { /* Start a search. */ std::shared_ptr<Dht::Search> -Dht::search(const InfoHash& id, sa_family_t af, GetCallback callback, DoneCallback done_callback, Value::Filter filter) +Dht::search(const InfoHash& id, sa_family_t af, GetCallback gcb, QueryCallback qcb, DoneCallback dcb, Value::Filter f, Query q) { if (!isRunning(af)) { DHT_LOG.ERR("[search %s IPv%c] unsupported protocol", id.toString().c_str(), (af == AF_INET) ? '4' : '6'); - if (done_callback) - done_callback(false, {}); + if (dcb) + dcb(false, {}); return {}; } @@ -1248,14 +1329,22 @@ Dht::search(const InfoHash& id, sa_family_t af, GetCallback callback, DoneCallba search_id++; } - if (callback) - sr->callbacks.push_back({/*.start=*/scheduler.time(), /*.filter=*/filter, /*.get_cb=*/callback, /*.done_cb=*/done_callback}); - bootstrapSearch(*sr); + if (gcb or qcb) { + auto now = scheduler.time(); + sr->callbacks.insert(std::make_pair<time_point, Get>( + std::move(now), + Get { scheduler.time(), f, std::make_shared<Query>(q), + qcb ? qcb : QueryCallback {}, gcb ? gcb : GetCallback {}, dcb + } + )); + } + bootstrapSearch(*sr); if (sr->nextSearchStep) scheduler.edit(sr->nextSearchStep, sr->getNextStepTime(types, scheduler.time())); else sr->nextSearchStep = scheduler.add(scheduler.time(), std::bind(&Dht::searchStep, this, sr)); + return sr; } @@ -1271,7 +1360,7 @@ Dht::announce(const InfoHash& id, sa_family_t af, std::shared_ptr<Value> value, } auto& srs = af == AF_INET ? searches4 : searches6; auto srp = srs.find(id); - auto sr = srp == srs.end() ? search(id, af, nullptr, nullptr) : srp->second; + auto sr = srp == srs.end() ? search(id, af) : srp->second; if (!sr) { if (callback) callback(false, {}); @@ -1317,7 +1406,7 @@ Dht::announce(const InfoHash& id, sa_family_t af, std::shared_ptr<Value> value, } size_t -Dht::listenTo(const InfoHash& id, sa_family_t af, GetCallback cb, Value::Filter f) +Dht::listenTo(const InfoHash& id, sa_family_t af, GetCallback cb, Value::Filter f, const std::shared_ptr<Query>& q) { const auto& now = scheduler.time(); if (!isRunning(af)) @@ -1327,22 +1416,23 @@ Dht::listenTo(const InfoHash& id, sa_family_t af, GetCallback cb, Value::Filter //DHT_LOG.WARN("listenTo %s", id.toString().c_str()); auto& srs = af == AF_INET ? searches4 : searches6; auto srp = srs.find(id); - std::shared_ptr<Search> sr = (srp == srs.end()) ? search(id, af, nullptr, nullptr) : srp->second; + std::shared_ptr<Search> sr = (srp == srs.end()) ? search(id, af) : srp->second; if (!sr) throw DhtException("Can't create search"); DHT_LOG.ERR("[search %s IPv%c] listen", id.toString().c_str(), (af == AF_INET) ? '4' : '6'); sr->done = false; auto token = ++sr->listener_token; - sr->listeners.emplace(token, LocalListener{f, cb}); + sr->listeners.emplace(token, LocalListener{q, f, cb}); scheduler.edit(sr->nextSearchStep, sr->getNextStepTime(types, now)); return token; } size_t -Dht::listen(const InfoHash& id, GetCallback cb, Value::Filter&& f) +Dht::listen(const InfoHash& id, GetCallback cb, Value::Filter&& f, Where&& where) { scheduler.syncTime(); + Query q {{}, where}; auto vals = std::make_shared<std::map<Value::Id, std::shared_ptr<Value>>>(); auto token = ++listener_token; @@ -1367,6 +1457,8 @@ Dht::listen(const InfoHash& id, GetCallback cb, Value::Filter&& f) return true; }; + auto query = std::make_shared<Query>(q); + auto filter = f.chain(q.where.getFilter()); auto st = findStorage(id); size_t tokenlocal = 0; if (st == store.end() && store.size() < MAX_HASHES) { @@ -1375,7 +1467,7 @@ Dht::listen(const InfoHash& id, GetCallback cb, Value::Filter&& f) } if (st != store.end()) { if (not st->empty()) { - std::vector<std::shared_ptr<Value>> newvals = st->get(f); + std::vector<std::shared_ptr<Value>> newvals = st->get(filter); if (not newvals.empty()) { if (!cb(newvals)) return 0; @@ -1387,11 +1479,11 @@ Dht::listen(const InfoHash& id, GetCallback cb, Value::Filter&& f) } } tokenlocal = ++st->listener_token; - st->local_listeners.emplace(tokenlocal, LocalListener{f, gcb}); + st->local_listeners.emplace(tokenlocal, LocalListener{query, filter, gcb}); } - auto token4 = Dht::listenTo(id, AF_INET, gcb, f); - auto token6 = Dht::listenTo(id, AF_INET6, gcb, f); + auto token4 = Dht::listenTo(id, AF_INET, gcb, filter, query); + auto token6 = Dht::listenTo(id, AF_INET6, gcb, filter, query); DHT_LOG.DEBUG("Added listen : %d -> %d %d %d", token, tokenlocal, token4, token6); listeners.emplace(token, std::make_tuple(tokenlocal, token4, token6)); @@ -1424,9 +1516,11 @@ Dht::cancelListen(const InfoHash& id, size_t token) s->listeners.erase(af_token); if (s->listeners.empty()) { for (auto& sn : s->nodes) { - // also erase requests for all searchnodes. - network_engine.cancelRequest(sn.listenStatus); - sn.listenStatus.reset(); + /* also erase requests for all searchnodes. */ + for (auto& ls : sn.listenStatus) { + network_engine.cancelRequest(ls.second); + } + sn.listenStatus.clear(); } } } @@ -1475,69 +1569,138 @@ Dht::put(const InfoHash& id, std::shared_ptr<Value> val, DoneCallback callback, }, created, permanent); } +template <typename T> struct OpStatus { - bool done {false}; - bool ok {false}; + struct Status { + bool done {false}; + bool ok {false}; + }; + Status status; + Status status4; + Status status6; + std::vector<std::shared_ptr<T>> values; + std::vector<std::shared_ptr<Node>> nodes; +}; + +template <typename T> +void doneCallbackWrapper(DoneCallback dcb, const std::vector<std::shared_ptr<Node>>& nodes, std::shared_ptr<OpStatus<T>> op) { + if (op->status.done) + return; + op->nodes.insert(op->nodes.end(), nodes.begin(), nodes.end()); + if (op->status.ok || (op->status4.done and op->status6.done)) { + bool ok = op->status.ok || op->status4.ok || op->status6.ok; + op->status.done = true; + if (dcb) + dcb(ok, op->nodes); + } +}; + +template <typename T, typename Cb> +bool callbackWrapper(Cb get_cb, + DoneCallback done_cb, + const std::vector<std::shared_ptr<T>>& values, + std::function<std::vector<std::shared_ptr<T>>(const std::vector<std::shared_ptr<T>>&)> add_values, + std::shared_ptr<OpStatus<T>> op) +{ + if (op->status.done) + return false; + auto newvals = add_values(values); + if (not newvals.empty()) { + op->status.ok = !get_cb(newvals); + op->values.insert(op->values.end(), newvals.begin(), newvals.end()); + } + doneCallbackWrapper(done_cb, {}, op); + return !op->status.ok; }; void -Dht::get(const InfoHash& id, GetCallback getcb, DoneCallback donecb, Value::Filter&& filter) +Dht::get(const InfoHash& id, GetCallback getcb, DoneCallback donecb, Value::Filter&& filter, Where&& where) { scheduler.syncTime(); - auto status = std::make_shared<OpStatus>(); - auto status4 = std::make_shared<OpStatus>(); - auto status6 = std::make_shared<OpStatus>(); - auto vals = std::make_shared<std::vector<std::shared_ptr<Value>>>(); - auto all_nodes = std::make_shared<std::vector<std::shared_ptr<Node>>>(); + Query q {{}, where}; + auto op = std::make_shared<OpStatus<Value>>(); - auto done_l = [=](const std::vector<std::shared_ptr<Node>>& nodes) { - if (status->done) - return; - all_nodes->insert(all_nodes->end(), nodes.begin(), nodes.end()); - if (status->ok || (status4->done && status6->done)) { - bool ok = status->ok || status4->ok || status6->ok; - status->done = true; - if (donecb) - donecb(ok, *all_nodes); - } - }; - auto cb = [=](const std::vector<std::shared_ptr<Value>>& values) { - if (status->done) - return false; + auto f = filter.chain(q.where.getFilter()); + auto add_values = [=](const std::vector<std::shared_ptr<Value>>& values) { std::vector<std::shared_ptr<Value>> newvals {}; for (const auto& v : values) { - auto it = std::find_if(vals->cbegin(), vals->cend(), [&](const std::shared_ptr<Value>& sv) { - return sv == v || *sv == *v; + auto it = std::find_if(op->values.cbegin(), op->values.cend(), [&](const std::shared_ptr<Value>& sv) { + return sv == v or *sv == *v; }); - if (it == vals->cend()) { - if (!filter || filter(*v)) - newvals.push_back(v); + if (it == op->values.cend()) { + if (not f or f(*v)) + newvals.push_back(v); } } - if (!newvals.empty()) { - status->ok = !getcb(newvals); - vals->insert(vals->end(), newvals.begin(), newvals.end()); - } - done_l({}); - return !status->ok; + return newvals; }; + auto gcb = std::bind(callbackWrapper<Value, GetCallback>, getcb, donecb, _1, add_values, op); /* Try to answer this search locally. */ - cb(getLocal(id, filter)); + gcb(getLocal(id, f)); - Dht::search(id, AF_INET, cb, [=](bool ok, const std::vector<std::shared_ptr<Node>>& nodes) { + Dht::search(id, AF_INET, gcb, {}, [=](bool ok, const std::vector<std::shared_ptr<Node>>& nodes) { //DHT_LOG.WARN("DHT done IPv4"); - status4->done = true; - status4->ok = ok; - done_l(nodes); - }); - Dht::search(id, AF_INET6, cb, [=](bool ok, const std::vector<std::shared_ptr<Node>>& nodes) { + op->status4.done = true; + op->status4.ok = ok; + doneCallbackWrapper(donecb, nodes, op); + }, f, q); + Dht::search(id, AF_INET6, gcb, {}, [=](bool ok, const std::vector<std::shared_ptr<Node>>& nodes) { //DHT_LOG.WARN("DHT done IPv6"); - status6->done = true; - status6->ok = ok; - done_l(nodes); + op->status6.done = true; + op->status6.ok = ok; + doneCallbackWrapper(donecb, nodes, op); + }, f, q); +} + +void Dht::query(const InfoHash& id, QueryCallback cb, DoneCallback done_cb, Query&& q) +{ + scheduler.syncTime(); + auto op = std::make_shared<OpStatus<FieldValueIndex>>(); + + auto f = q.where.getFilter(); + auto values = getLocal(id, f); + auto add_fields = [=](const std::vector<std::shared_ptr<FieldValueIndex>>& fields) { + std::vector<std::shared_ptr<FieldValueIndex>> newvals {}; + for (const auto& f : fields) { + auto it = std::find_if(op->values.cbegin(), op->values.cend(), + [&](const std::shared_ptr<FieldValueIndex>& sf) { + return sf == f or f->containedIn(*sf); + }); + if (it == op->values.cend()) { + auto lesser = std::find_if(op->values.begin(), op->values.end(), + [&](const std::shared_ptr<FieldValueIndex>& sf) { + return sf->containedIn(*f); + }); + if (lesser != op->values.end()) + op->values.erase(lesser); + newvals.push_back(f); + } + } + return newvals; + }; + std::vector<std::shared_ptr<FieldValueIndex>> local_fields(values.size()); + std::transform(values.begin(), values.end(), local_fields.begin(), [](const std::shared_ptr<Value>& v) { + return std::make_shared<FieldValueIndex>(*v); }); + auto qcb = std::bind(callbackWrapper<FieldValueIndex, QueryCallback>, cb, done_cb, _1, add_fields, op); + + /* Try to answer this search locally. */ + qcb(local_fields); + + Dht::search(id, AF_INET, {}, qcb, [=](bool ok, const std::vector<std::shared_ptr<Node>>& nodes) { + //DHT_LOG.WARN("DHT done IPv4"); + op->status4.done = true; + op->status4.ok = ok; + doneCallbackWrapper(done_cb, nodes, op); + }, f, q); + Dht::search(id, AF_INET6, {}, qcb, [=](bool ok, const std::vector<std::shared_ptr<Node>>& nodes) { + //DHT_LOG.WARN("DHT done IPv6"); + op->status6.done = true; + op->status6.ok = ok; + doneCallbackWrapper(done_cb, nodes, op); + }, f, q); } std::vector<std::shared_ptr<Value>> @@ -1658,10 +1821,14 @@ Dht::storageChanged(Storage& st, ValueStorage& v) for (const auto& l : st.listeners) { DHT_LOG.DEBUG("Storage changed. Sending update to %s.", l.first->toString().c_str()); - std::vector<std::shared_ptr<Value>> vals; + auto f = l.second.query.where.getFilter(); + if (f and not f(*v.data)) + continue; + std::vector<std::shared_ptr<Value>> vals {}; vals.push_back(v.data); Blob ntoken = makeToken((const sockaddr*)&l.first->ss, false); - network_engine.tellListener(l.first, l.second.rid, st.id, 0, ntoken, {}, {}, vals); + network_engine.tellListener(l.first, l.second.rid, st.id, 0, ntoken, {}, {}, + std::move(vals), l.second.query); } } @@ -1724,7 +1891,7 @@ Dht::Storage::clear() } void -Dht::storageAddListener(const InfoHash& id, const std::shared_ptr<Node>& node, size_t rid) +Dht::storageAddListener(const InfoHash& id, const std::shared_ptr<Node>& node, size_t rid, Query&& query) { const auto& now = scheduler.time(); auto st = findStorage(id); @@ -1736,16 +1903,13 @@ Dht::storageAddListener(const InfoHash& id, const std::shared_ptr<Node>& node, s } auto l = st->listeners.find(node); if (l == st->listeners.end()) { - const auto& stvalues = st->getValues(); - if (not stvalues.empty()) { - std::vector<std::shared_ptr<Value>> values(stvalues.size()); - std::transform(stvalues.begin(), stvalues.end(), values.begin(), [=](const ValueStorage& vs) { return vs.data; }); - + auto vals = st->get(query.where.getFilter()); + if (not vals.empty()) { network_engine.tellListener(node, rid, id, WANT4 | WANT6, makeToken((sockaddr*)&node->ss, false), buckets.findClosestNodes(id, now, TARGET_NODES), buckets6.findClosestNodes(id, now, TARGET_NODES), - values); + std::move(vals), query); } - st->listeners.emplace(node, Listener {rid, now}); + st->listeners.emplace(node, Listener {rid, now, std::forward<Query>(query)}); } else l->second.refresh(rid, now); @@ -1813,7 +1977,7 @@ Dht::connectivityChanged() auto stop_listen = [&](std::map<InfoHash, std::shared_ptr<Search>> srs) { for (auto& sp : srs) for (auto& sn : sp.second->nodes) - sn.listenStatus.reset(); + sn.listenStatus.clear(); }; stop_listen(searches4); stop_listen(searches6); @@ -1957,6 +2121,16 @@ Dht::dumpSearch(const Search& sr, std::ostream& out) const } out << std::endl; + /*printing the queries*/ + if (sr.callbacks.size() + sr.listeners.size() > 0) + out << "Queries:" << std::endl; + for (const auto& cb : sr.callbacks) { + out << *cb.second.query << std::endl; + } + for (const auto& l : sr.listeners) { + out << *l.second.query << std::endl; + } + for (const auto& n : sr.announce) { bool announced = sr.isAnnounced(n.value->id, getType(n.value->type), now); out << "Announcement: " << *n.value << (announced ? " [announced]" : "") << std::endl; @@ -1978,18 +2152,18 @@ Dht::dumpSearch(const Search& sr, std::ostream& out) const // Get status { - char g_i = (n.getStatus && n.getStatus->pending()) ? (n.candidate ? 'c' : 'f') : ' '; + char g_i = n.pending(n.getStatus) ? (n.candidate ? 'c' : 'f') : ' '; char s_i = n.isSynced(now) ? (n.last_get_reply > last_get ? 'u' : 's') : '-'; out << " [" << s_i << g_i << "] "; } // Listen status if (not sr.listeners.empty()) { - if (not n.listenStatus) + if (n.listenStatus.empty()) out << " "; else out << "[" - << (n.isListening(now) ? 'l' : (n.listenStatus->pending() ? 'f' : ' ')) << "] "; + << (n.isListening(now) ? 'l' : (n.pending(n.listenStatus) ? 'f' : ' ')) << "] "; } // Announce status @@ -2108,8 +2282,8 @@ Dht::Dht(int s, int s6, Config config) std::bind(&Dht::onReportedAddr, this, _1, _2, _3), std::bind(&Dht::onPing, this, _1), std::bind(&Dht::onFindNode, this, _1, _2, _3), - std::bind(&Dht::onGetValues, this, _1, _2, _3), - std::bind(&Dht::onListen, this, _1, _2, _3, _4), + std::bind(&Dht::onGetValues, this, _1, _2, _3, _4), + std::bind(&Dht::onListen, this, _1, _2, _3, _4, _5), std::bind(&Dht::onAnnounce, this, _1, _2, _3, _4, _5)) { scheduler.syncTime(); @@ -2539,7 +2713,7 @@ Dht::onFindNode(std::shared_ptr<Node> node, InfoHash& target, want_t want) } NetworkEngine::RequestAnswer -Dht::onGetValues(std::shared_ptr<Node> node, InfoHash& hash, want_t) +Dht::onGetValues(std::shared_ptr<Node> node, InfoHash& hash, want_t, const Query& query) { if (hash == zeroes) { DHT_LOG.WARN("[node %s] Eek! Got get_values with no info_hash.", node->toString().c_str()); @@ -2552,11 +2726,7 @@ Dht::onGetValues(std::shared_ptr<Node> node, InfoHash& hash, want_t) answer.nodes4 = buckets.findClosestNodes(hash, now, TARGET_NODES); answer.nodes6 = buckets6.findClosestNodes(hash, now, TARGET_NODES); if (st != store.end() && not st->empty()) { - auto values = st->getValues(); - answer.values.resize(values.size()); - std::transform(values.begin(), values.end(), answer.values.begin(), [](const ValueStorage& vs) { - return vs.data; - }); + answer.values = st->get(query.where.getFilter()); DHT_LOG.DEBUG("[node %s] sending %u values.", node->toString().c_str(), answer.values.size()); } else { DHT_LOG.DEBUG("[node %s] sending nodes.", node->toString().c_str()); @@ -2566,34 +2736,52 @@ Dht::onGetValues(std::shared_ptr<Node> node, InfoHash& hash, want_t) void Dht::onGetValuesDone(const Request& status, - NetworkEngine::RequestAnswer& a, std::shared_ptr<Search> sr) + NetworkEngine::RequestAnswer& a, std::shared_ptr<Search> sr, const std::shared_ptr<Query>& orig_query) { if (not sr) { DHT_LOG.WARN("[search unknown] got reply to 'get'. Ignoring."); return; } - DHT_LOG.DEBUG("[search %s IPv%c] got reply to 'get' from %s with %u nodes", sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6', status.node->toString().c_str(), a.nodes4.size()); + DHT_LOG.DEBUG("[search %s IPv%c] got reply to 'get' from %s with %u nodes", + sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6', status.node->toString().c_str(), a.nodes4.size()); if (not a.ntoken.empty()) { - if (!a.values.empty()) { + if (not a.values.empty() or not a.fields.empty()) { DHT_LOG.DEBUG("[search %s IPv%c] found %u values", sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6', a.values.size()); - for (auto& cb : sr->callbacks) { - if (!cb.get_cb) continue; - std::vector<std::shared_ptr<Value>> tmp; - std::copy_if(a.values.begin(), a.values.end(), std::back_inserter(tmp), - [&](const std::shared_ptr<Value>& v) { - return not static_cast<bool>(cb.filter) or cb.filter(*v); + for (auto& getp : sr->callbacks) { + auto& get = getp.second; + if (not (get.get_cb or get.query_cb) or + (orig_query and get.query and not get.query->isSatisfiedBy(*orig_query))) + continue; + if (get.query_cb) { + if (not a.fields.empty()) { + get.query_cb(a.fields); + } else if (not a.values.empty()) { + std::vector<std::shared_ptr<FieldValueIndex>> fields(a.values.size()); + std::transform(a.values.begin(), a.values.end(), fields.begin(), + [&](const std::shared_ptr<Value>& v) { + return std::make_shared<FieldValueIndex>(*v, orig_query ? orig_query->select : Select {}); + }); + get.query_cb(fields); } - ); - if (not tmp.empty()) - cb.get_cb(tmp); + } else if (get.get_cb) { + std::vector<std::shared_ptr<Value>> tmp; + std::copy_if(a.values.begin(), a.values.end(), std::back_inserter(tmp), + [&](const std::shared_ptr<Value>& v) { + return not static_cast<bool>(get.filter) or get.filter(*v); + } + ); + if (not tmp.empty()) + get.get_cb(tmp); + } } std::vector<std::pair<GetCallback, std::vector<std::shared_ptr<Value>>>> tmp_lists; for (auto& l : sr->listeners) { - if (!l.second.get_cb) continue; + if (!l.second.get_cb or (orig_query and l.second.query and not l.second.query->isSatisfiedBy(*orig_query))) + continue; std::vector<std::shared_ptr<Value>> tmp; std::copy_if(a.values.begin(), a.values.end(), std::back_inserter(tmp), [&](const std::shared_ptr<Value>& v) { @@ -2620,7 +2808,7 @@ Dht::onGetValuesDone(const Request& status, } NetworkEngine::RequestAnswer -Dht::onListen(std::shared_ptr<Node> node, InfoHash& hash, Blob& token, size_t rid) +Dht::onListen(std::shared_ptr<Node> node, InfoHash& hash, Blob& token, size_t rid, Query&& query) { if (hash == zeroes) { DHT_LOG.WARN("Listen with no info_hash."); @@ -2633,18 +2821,18 @@ Dht::onListen(std::shared_ptr<Node> node, InfoHash& hash, Blob& token, size_t ri DHT_LOG.WARN("[node %s] incorrect token %s for 'listen'.", node->toString().c_str(), hash.toString().c_str()); throw DhtProtocolException {DhtProtocolException::UNAUTHORIZED, DhtProtocolException::LISTEN_WRONG_TOKEN}; } - storageAddListener(hash, node, rid); + storageAddListener(hash, node, rid, std::forward<Query>(query)); return {}; } void -Dht::onListenDone(const Request& status, NetworkEngine::RequestAnswer& answer, std::shared_ptr<Search>& sr) +Dht::onListenDone(const Request& status, NetworkEngine::RequestAnswer& answer, std::shared_ptr<Search>& sr, const std::shared_ptr<Query>& orig_query) { DHT_LOG.DEBUG("[search %s] Got reply to listen.", sr->id.toString().c_str()); if (sr) { if (not answer.values.empty()) { /* got new values from listen request */ DHT_LOG.DEBUG("[listen %s] Got new values.", sr->id.toString().c_str()); - onGetValuesDone(status, answer, sr); + onGetValuesDone(status, answer, sr, orig_query); } if (not sr->done) { diff --git a/src/dhtrunner.cpp b/src/dhtrunner.cpp index ed24473d877a296661aebbfdb04a3b463d003559..8a8cc01b69ca868a6766f520497aebd8c92433e1 100644 --- a/src/dhtrunner.cpp +++ b/src/dhtrunner.cpp @@ -429,37 +429,44 @@ DhtRunner::doRun(const sockaddr_in* sin4, const sockaddr_in6* sin6, SecureDht::C } void -DhtRunner::get(InfoHash hash, GetCallback vcb, DoneCallback dcb, Value::Filter f) +DhtRunner::get(InfoHash hash, GetCallback vcb, DoneCallback dcb, Value::Filter f, Where w) { std::lock_guard<std::mutex> lck(storage_mtx); pending_ops.emplace([=](SecureDht& dht) mutable { - dht.get(hash, vcb, dcb, std::move(f)); + dht.get(hash, vcb, dcb, std::move(f), std::move(w)); }); cv.notify_all(); } void -DhtRunner::get(const std::string& key, GetCallback vcb, DoneCallbackSimple dcb, Value::Filter f) +DhtRunner::get(const std::string& key, GetCallback vcb, DoneCallbackSimple dcb, Value::Filter f, Where w) { - get(InfoHash::get(key), vcb, dcb, f); + get(InfoHash::get(key), vcb, dcb, f, w); +} +void DhtRunner::query(const InfoHash& hash, QueryCallback cb, DoneCallback done_cb, Query q) { + std::lock_guard<std::mutex> lck(storage_mtx); + pending_ops.emplace([=](SecureDht& dht) mutable { + dht.query(hash, cb, done_cb, std::move(q)); + }); + cv.notify_all(); } std::future<size_t> -DhtRunner::listen(InfoHash hash, GetCallback vcb, Value::Filter f) +DhtRunner::listen(InfoHash hash, GetCallback vcb, Value::Filter f, Where w) { std::lock_guard<std::mutex> lck(storage_mtx); auto ret_token = std::make_shared<std::promise<size_t>>(); pending_ops.emplace([=](SecureDht& dht) mutable { - ret_token->set_value(dht.listen(hash, vcb, std::move(f))); + ret_token->set_value(dht.listen(hash, vcb, std::move(f), std::move(w))); }); cv.notify_all(); return ret_token->get_future(); } std::future<size_t> -DhtRunner::listen(const std::string& key, GetCallback vcb, Value::Filter f) +DhtRunner::listen(const std::string& key, GetCallback vcb, Value::Filter f, Where w) { - return listen(InfoHash::get(key), vcb, f); + return listen(InfoHash::get(key), vcb, f, w); } void diff --git a/src/network_engine.cpp b/src/network_engine.cpp index 15f2431c958d7dfbd6c2a95dc0f27aca89adef5d..56be1a69d0edf716a20a44eec732801ac0075432 100644 --- a/src/network_engine.cpp +++ b/src/network_engine.cpp @@ -21,6 +21,7 @@ #include "network_engine.h" #include "request.h" +#include "default_types.h" #include <msgpack.hpp> @@ -63,36 +64,40 @@ enum class MessageType { struct ParsedMessage { MessageType type; - InfoHash id; /* the id of the sender */ - NetId network {0}; /* network id */ - InfoHash info_hash; /* hash for which values are requested */ - InfoHash target; /* target id around which to find nodes */ - NetworkEngine::TransId tid; /* transaction id */ - Blob token; /* security token */ - Value::Id value_id; /* the value id */ - time_point created { time_point::max() }; /* time when value was first created */ - Blob nodes4_raw, nodes6_raw; /* IPv4 nodes in response to a 'find' request */ + InfoHash id; /* the id of the sender */ + NetId network {0}; /* network id */ + InfoHash info_hash; /* hash for which values are requested */ + InfoHash target; /* target id around which to find nodes */ + NetworkEngine::TransId tid; /* transaction id */ + Blob token; /* security token */ + Value::Id value_id; /* the value id */ + time_point created { time_point::max() }; /* time when value was first created */ + Blob nodes4_raw, nodes6_raw; /* IPv4 nodes in response to a 'find' request */ std::vector<std::shared_ptr<Node>> nodes4, nodes6; - std::vector<std::shared_ptr<Value>> values; /* values for a 'get' request */ - want_t want; /* states if ipv4 or ipv6 request */ - uint16_t error_code; /* error code in case of error */ + std::vector<std::shared_ptr<Value>> values; /* values for a 'get' request */ + std::vector<std::shared_ptr<FieldValueIndex>> fields; /* index for fields values */ + Query query; /* query describing a filter to apply on values. */ + want_t want; /* states if ipv4 or ipv6 request */ + uint16_t error_code; /* error code in case of error */ std::string ua; - Address addr; /* reported address by the distant node */ + Address addr; /* reported address by the distant node */ void msgpack_unpack(msgpack::object o); }; NetworkEngine::RequestAnswer::RequestAnswer(ParsedMessage&& msg) - : ntoken(std::move(msg.token)), values(std::move(msg.values)), nodes4(std::move(msg.nodes4)), nodes6(std::move(msg.nodes6)) {} + : ntoken(std::move(msg.token)), values(std::move(msg.values)), fields(std::move(msg.fields)), + nodes4(std::move(msg.nodes4)), nodes6(std::move(msg.nodes6)) {} void -NetworkEngine::tellListener(std::shared_ptr<Node> node, uint16_t rid, InfoHash hash, want_t want, - Blob ntoken, std::vector<std::shared_ptr<Node>> nodes, std::vector<std::shared_ptr<Node>> nodes6, - std::vector<std::shared_ptr<Value>> values) +NetworkEngine::tellListener(std::shared_ptr<Node> node, uint16_t rid, const InfoHash& hash, want_t want, + const Blob& ntoken, std::vector<std::shared_ptr<Node>>&& nodes, + std::vector<std::shared_ptr<Node>>&& nodes6, std::vector<std::shared_ptr<Value>>&& values, + const Query& query) { auto nnodes = bufferNodes(node->getFamily(), hash, want, nodes, nodes6); try { sendNodesValues((const sockaddr*)&node->ss, node->sslen, TransId {TransPrefix::GET_VALUES, (uint16_t)rid}, nnodes.first, nnodes.second, - values, ntoken); + values, query, ntoken); } catch (const std::overflow_error& e) { DHT_LOG.ERR("Can't send value: buffer not large enough !"); } @@ -380,7 +385,7 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const sockaddr* requests.erase(reqp); req->reply_time = scheduler.time(); - deserializeNodesValues(msg); + deserializeNodes(msg); req->setDone(std::move(msg)); break; default: @@ -404,16 +409,16 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const sockaddr* ++in_stats.find; RequestAnswer answer = onFindNode(node, msg.target, msg.want); auto nnodes = bufferNodes(from->sa_family, msg.target, msg.want, answer.nodes4, answer.nodes6); - sendNodesValues(from, fromlen, msg.tid, nnodes.first, nnodes.second, {}, answer.ntoken); + sendNodesValues(from, fromlen, msg.tid, nnodes.first, nnodes.second, {}, {}, answer.ntoken); break; } case MessageType::GetValues: { DHT_LOG.DEBUG("[node %s %s] got 'get' request for %s.", msg.id.toString().c_str(), print_addr(from, fromlen).c_str(), msg.info_hash.toString().c_str()); ++in_stats.get; - RequestAnswer answer = onGetValues(node, msg.info_hash, msg.want); + RequestAnswer answer = onGetValues(node, msg.info_hash, msg.want, msg.query); auto nnodes = bufferNodes(from->sa_family, msg.info_hash, msg.want, answer.nodes4, answer.nodes6); - sendNodesValues(from, fromlen, msg.tid, nnodes.first, nnodes.second, answer.values, answer.ntoken); + sendNodesValues(from, fromlen, msg.tid, nnodes.first, nnodes.second, answer.values, msg.query, answer.ntoken); break; } case MessageType::AnnounceValue: { @@ -435,7 +440,7 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const sockaddr* DHT_LOG.DEBUG("[node %s %s] got 'listen' request for %s.", msg.id.toString().c_str(), print_addr(from, fromlen).c_str(), msg.info_hash.toString().c_str()); ++in_stats.listen; - RequestAnswer answer = onListen(node, msg.info_hash, msg.token, msg.tid.getTid()); + RequestAnswer answer = onListen(node, msg.info_hash, msg.token, msg.tid.getTid(), std::move(msg.query)); sendListenConfirmation(from, fromlen, msg.tid); break; } @@ -594,16 +599,19 @@ NetworkEngine::sendFindNode(std::shared_ptr<Node> n, const InfoHash& target, wan std::shared_ptr<Request> -NetworkEngine::sendGetValues(std::shared_ptr<Node> n, const InfoHash& info_hash, want_t want, +NetworkEngine::sendGetValues(std::shared_ptr<Node> n, const InfoHash& info_hash, const Query& query, want_t want, RequestCb on_done, RequestExpiredCb on_expired) { auto tid = TransId {TransPrefix::GET_VALUES, getNewTid()}; msgpack::sbuffer buffer; msgpack::packer<msgpack::sbuffer> pk(&buffer); pk.pack_map(5+(network?1:0)); - pk.pack(std::string("a")); pk.pack_map(2 + (want>0?1:0)); + pk.pack(std::string("a")); pk.pack_map(2 + + (query.where.getFilter() or not query.select.getSelection().empty() ? 1:0) + + (want>0?1:0)); pk.pack(std::string("id")); pk.pack(myid); pk.pack(std::string("h")); pk.pack(info_hash); + pk.pack(std::string("q")); pk.pack(query); if (want > 0) { pk.pack(std::string("w")); pk.pack_array(((want & WANT4)?1:0) + ((want & WANT6)?1:0)); @@ -639,7 +647,7 @@ NetworkEngine::sendGetValues(std::shared_ptr<Node> n, const InfoHash& info_hash, } void -NetworkEngine::deserializeNodesValues(ParsedMessage& msg) { +NetworkEngine::deserializeNodes(ParsedMessage& msg) { if (msg.nodes4_raw.size() % NODE4_INFO_BUF_LEN != 0 || msg.nodes6_raw.size() % NODE6_INFO_BUF_LEN != 0) { throw DhtProtocolException {DhtProtocolException::WRONG_NODE_INFO_BUF_LEN}; } else { @@ -680,7 +688,7 @@ NetworkEngine::deserializeNodesValues(ParsedMessage& msg) { void NetworkEngine::sendNodesValues(const sockaddr* sa, socklen_t salen, TransId tid, const Blob& nodes, const Blob& nodes6, - const std::vector<std::shared_ptr<Value>>& st, const Blob& token) { + const std::vector<std::shared_ptr<Value>>& st, const Query& query, const Blob& token) { msgpack::sbuffer buffer; msgpack::packer<msgpack::sbuffer> pk(&buffer); pk.pack_map(4+(network?1:0)); @@ -702,34 +710,50 @@ NetworkEngine::sendNodesValues(const sockaddr* sa, socklen_t salen, TransId tid, if (not token.empty()) { pk.pack(std::string("token")); packToken(pk, token); } - if (not st.empty()) { - // We treat the storage as a circular list, and serve a randomly - // chosen slice. In order to make sure we fit, - // we limit ourselves to 50 values. - std::uniform_int_distribution<> pos_dis(0, st.size()-1); - std::vector<Blob> subset {}; - subset.reserve(std::min<size_t>(st.size(), 50)); - + if (not st.empty()) { /* pack complete values */ + auto fields = query.select.getSelection(); size_t total_size = 0; - unsigned j0 = pos_dis(rd_device); - unsigned j = j0; - unsigned k = 0; - - do { - subset.emplace_back(packMsg(st[j])); - total_size += subset.back().size(); - ++k; - j = (j + 1) % st.size(); - } while (j != j0 && k < 50 && total_size < MAX_VALUE_SIZE); - - pk.pack(std::string("values")); - pk.pack_array(subset.size()); - for (const auto& b : subset) - buffer.write((const char*)b.data(), b.size()); - DHT_LOG.DEBUG("sending closest nodes (%d+%d nodes.), %lu bytes of values", nodes.size(), nodes6.size(), total_size); + if (fields.empty()) { + // We treat the storage as a circular list, and serve a randomly + // chosen slice. In order to make sure we fit, + // we limit ourselves to 50 values. + std::uniform_int_distribution<> pos_dis(0, st.size()-1); + std::vector<Blob> subset {}; + subset.reserve(std::min<size_t>(st.size(), 50)); + + unsigned j0 = pos_dis(rd_device); + unsigned j = j0; + unsigned k = 0; + + do { + subset.emplace_back(packMsg(st[j])); + total_size += subset.back().size(); + ++k; + j = (j + 1) % st.size(); + } while (j != j0 && k < 50 && total_size < MAX_VALUE_SIZE); + + pk.pack(std::string("values")); + pk.pack_array(subset.size()); + for (const auto& b : subset) + buffer.write((const char*)b.data(), b.size()); + DHT_LOG.DEBUG("sending closest nodes (%d+%d nodes.), %lu bytes of values", + nodes.size(), nodes6.size(), total_size); + } else { /* pack fields */ + pk.pack(std::string("fields")); + pk.pack_map(2); + pk.pack(std::string("f")); pk.pack(fields); + pk.pack(std::string("v")); pk.pack_array(st.size()*fields.size()); + for (const auto& v : st) { + v->msgpack_pack_fields(fields, pk); + } + DHT_LOG.DEBUG("sending closest nodes (%d+%d nodes.), %u value headers containing %u fields", + nodes.size(), nodes6.size(), st.size(), fields.size()); + } } else DHT_LOG.DEBUG("sending closest nodes (%d+%d nodes.)", nodes.size(), nodes6.size()); + DHT_LOG.DEBUG("sending closest nodes (%d+%d nodes.)", nodes.size(), nodes6.size()); + pk.pack(std::string("t")); pk.pack_bin(tid.size()); pk.pack_bin_body((const char*)tid.data(), tid.size()); pk.pack(std::string("y")); pk.pack(std::string("r")); @@ -794,16 +818,18 @@ NetworkEngine::bufferNodes(sa_family_t af, const InfoHash& id, want_t want, } std::shared_ptr<Request> -NetworkEngine::sendListen(std::shared_ptr<Node> n, const InfoHash& infohash, const Blob& token, +NetworkEngine::sendListen(std::shared_ptr<Node> n, const InfoHash& infohash, const Query& query, const Blob& token, RequestCb on_done, RequestExpiredCb on_expired) { auto tid = TransId {TransPrefix::LISTEN, getNewTid()}; msgpack::sbuffer buffer; msgpack::packer<msgpack::sbuffer> pk(&buffer); pk.pack_map(5+(network?1:0)); - pk.pack(std::string("a")); pk.pack_map(3); + pk.pack(std::string("a")); pk.pack_map(3 + + (query.where.getFilter() or not query.select.getSelection().empty() ? 1:0)); pk.pack(std::string("id")); pk.pack(myid); pk.pack(std::string("h")); pk.pack(infohash); + pk.pack(std::string("q")); pk.pack(query); pk.pack(std::string("token")); packToken(pk, token); pk.pack(std::string("q")); pk.pack(std::string("listen")); @@ -957,35 +983,40 @@ NetworkEngine::sendError(const sockaddr* sa, send(buffer.data(), buffer.size(), 0, sa, salen); } -msgpack::object* -findMapValue(msgpack::object& map, const std::string& key) { - if (map.type != msgpack::type::MAP) throw msgpack::type_error(); - for (unsigned i = 0; i < map.via.map.size; i++) { - auto& o = map.via.map.ptr[i]; - if(o.key.type != msgpack::type::STR) - continue; - if (o.key.as<std::string>() == key) { - return &o.val; - } - } - return nullptr; -} - void ParsedMessage::msgpack_unpack(msgpack::object msg) { auto y = findMapValue(msg, "y"); - auto a = findMapValue(msg, "a"); auto r = findMapValue(msg, "r"); auto e = findMapValue(msg, "e"); - std::string query; - if (auto q = findMapValue(msg, "q")) { - if (q->type != msgpack::type::STR) + std::string q; + if (auto rq = findMapValue(msg, "q")) { + if (rq->type != msgpack::type::STR) throw msgpack::type_error(); - query = q->as<std::string>(); + q = rq->as<std::string>(); } + if (e) + type = MessageType::Error; + else if (r) + type = MessageType::Reply; + else if (y and y->as<std::string>() != "q") + throw msgpack::type_error(); + else if (q == "ping") + type = MessageType::Ping; + else if (q == "find") + type = MessageType::FindNode; + else if (q == "get") + type = MessageType::GetValues; + else if (q == "listen") + type = MessageType::Listen; + else if (q == "put") + type = MessageType::AnnounceValue; + else + throw msgpack::type_error(); + + auto a = findMapValue(msg, "a"); if (!a && !r && !e) throw msgpack::type_error(); auto& req = a ? *a : (r ? *r : *e); @@ -1008,6 +1039,9 @@ ParsedMessage::msgpack_unpack(msgpack::object msg) if (auto rtarget = findMapValue(req, "target")) target = {*rtarget}; + if (auto rquery = findMapValue(req, "q")) + query.msgpack_unpack(*rquery); + if (auto otoken = findMapValue(req, "token")) token = unpackBlob(*otoken); @@ -1054,6 +1088,23 @@ ParsedMessage::msgpack_unpack(msgpack::object msg) } catch (const std::exception& e) { //DHT_LOG.WARN("Error reading value: %s", e.what()); } + } else if (auto raw_fields = findMapValue(req, "fields")) { + if (auto rfields = findMapValue(*raw_fields, "f")) { + auto fields_ = rfields->as<std::set<Value::Field>>(); + if (auto rvalues = findMapValue(*raw_fields, "v")) { + if (rvalues->type != msgpack::type::ARRAY) + throw msgpack::type_error(); + for (size_t i = 0; i < rvalues->via.array.size; ++i) { + try { + auto v = std::make_shared<FieldValueIndex>(); + v->msgpack_unpack_fields(fields_, *rvalues, i*fields.size()); + fields.emplace_back(std::move(v)); + } catch (const std::exception& e) { } + } + } + } else { + throw msgpack::type_error(); + } } if (auto w = findMapValue(req, "w")) { @@ -1080,24 +1131,6 @@ ParsedMessage::msgpack_unpack(msgpack::object msg) if (auto rv = findMapValue(msg, "v")) ua = rv->as<std::string>(); - if (e) - type = MessageType::Error; - else if (r) - type = MessageType::Reply; - else if (y and y->as<std::string>() != "q") - throw msgpack::type_error(); - else if (query == "ping") - type = MessageType::Ping; - else if (query == "find") - type = MessageType::FindNode; - else if (query == "get") - type = MessageType::GetValues; - else if (query == "listen") - type = MessageType::Listen; - else if (query == "put") - type = MessageType::AnnounceValue; - else - throw msgpack::type_error(); } } diff --git a/src/securedht.cpp b/src/securedht.cpp index 2c37ad6e85e3120f7f7672f8590630b0258b242d..8e1ba6f7828db86e5260f0c924124039e2cd2890 100644 --- a/src/securedht.cpp +++ b/src/securedht.cpp @@ -291,15 +291,15 @@ SecureDht::getCallbackFilter(GetCallback cb, Value::Filter&& filter) } void -SecureDht::get(const InfoHash& id, GetCallback cb, DoneCallback donecb, Value::Filter&& f) +SecureDht::get(const InfoHash& id, GetCallback cb, DoneCallback donecb, Value::Filter&& f, Where&& w) { - Dht::get(id, getCallbackFilter(cb, std::forward<Value::Filter>(f)), donecb); + Dht::get(id, getCallbackFilter(cb, std::forward<Value::Filter>(f)), donecb, {}, std::forward<Where>(w)); } size_t -SecureDht::listen(const InfoHash& id, GetCallback cb, Value::Filter&& f) +SecureDht::listen(const InfoHash& id, GetCallback cb, Value::Filter&& f, Where&& w) { - return Dht::listen(id, getCallbackFilter(cb, std::forward<Value::Filter>(f))); + return Dht::listen(id, getCallbackFilter(cb, std::forward<Value::Filter>(f)), {}, std::forward<Where>(w)); } void diff --git a/src/utils.cpp b/src/utils.cpp index 33ee40635805cf9f376719ae50ba7666847d6c43..1abf0fe0471048c55662c68ecbe714379d2ff503 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -85,4 +85,15 @@ unpackMsg(Blob b) { return msgpack::unpack((const char*)b.data(), b.size()); } +msgpack::object* +findMapValue(msgpack::object& map, const std::string& key) { + if (map.type != msgpack::type::MAP) throw msgpack::type_error(); + for (unsigned i = 0; i < map.via.map.size; i++) { + auto& o = map.via.map.ptr[i]; + if (o.key.type == msgpack::type::STR && o.key.as<std::string>() == key) + return &o.val; + } + return nullptr; +} + } diff --git a/src/value.cpp b/src/value.cpp index 58a86556d2908ebc9334210c6c17e623f1f2ae2a..d11a7ba3fcafa1e678d0ae943a15337f42c72425 100644 --- a/src/value.cpp +++ b/src/value.cpp @@ -25,6 +25,8 @@ namespace dht { +const std::string Query::QUERY_PARSE_ERROR {"Error parsing query."}; + std::ostream& operator<< (std::ostream& s, const Value& v) { s << "Value[id:" << std::hex << v.id << std::dec << " "; @@ -157,4 +159,306 @@ Value::msgpack_unpack_body(const msgpack::object& o) } } +bool +FieldValue::operator==(const FieldValue& vfd) const +{ + if (field != vfd.field) + return false; + switch (field) { + case Value::Field::Id: + case Value::Field::ValueType: + return intValue == vfd.intValue; + case Value::Field::OwnerPk: + return hashValue == vfd.hashValue; + case Value::Field::UserType: + return blobValue == vfd.blobValue; + case Value::Field::None: + return true; + default: + return false; + } +} + +Value::Filter +FieldValue::getLocalFilter() const +{ + switch (field) { + case Value::Field::Id: + return Value::IdFilter(intValue); + case Value::Field::ValueType: + return Value::TypeFilter(intValue); + case Value::Field::OwnerPk: + return Value::ownerFilter(hashValue); + case Value::Field::UserType: + return Value::userTypeFilter(std::string {blobValue.begin(), blobValue.end()}); + default: + return Value::AllFilter(); + } +} + +FieldValueIndex::FieldValueIndex(const Value& v, Select s) +{ + auto selection = s.getSelection(); + if (not selection.empty()) { + std::transform(selection.begin(), selection.end(), std::inserter(index, index.end()), + [](const std::set<Value::Field>::value_type& f) { + return std::make_pair(f, FieldValue {}); + }); + } else { + index.clear(); + for (size_t f = 1 ; f < 5 ; ++f) + index[static_cast<Value::Field>(f)] = {}; + } + for (const auto& fvp : index) { + const auto& f = fvp.first; + switch (f) { + case Value::Field::Id: + index[f] = {f, v.id}; + break; + case Value::Field::ValueType: + index[f] = {f, v.type}; + break; + case Value::Field::OwnerPk: + index[f] = {f, v.owner ? v.owner->getId() : InfoHash() }; + break; + case Value::Field::UserType: + index[f] = {f, Blob {v.user_type.begin(), v.user_type.end()}}; + break; + default: + break; + } + } +} + +bool FieldValueIndex::containedIn(const FieldValueIndex& other) const { + if (index.size() > other.index.size()) + return false; + for (const auto& field : index) { + auto other_field = other.index.find(field.first); + if (other_field == other.index.end()) + return false; + } + return true; +} + +std::ostream& operator<<(std::ostream& os, const FieldValueIndex& fvi) { + os << "Index["; + for (auto v = fvi.index.begin(); v != fvi.index.end(); ++v) { + switch (v->first) { + case Value::Field::Id: + os << "Id:" << std::hex << v->second.getInt(); + break; + case Value::Field::ValueType: + os << "ValueType:" << v->second.getInt(); + break; + case Value::Field::OwnerPk: + os << "Owner:" << v->second.getHash().toString(); + break; + case Value::Field::UserType: { + auto ut = v->second.getBlob(); + os << "UserType:" << std::string(ut.begin(), ut.end()); + break; + } + default: + break; + } + os << (std::next(v) != fvi.index.end() ? "," : ""); + } + return os << "]"; +} + +void +FieldValueIndex::msgpack_unpack_fields(const std::set<Value::Field>& fields, const msgpack::object& o, unsigned offset) +{ + index.clear(); + + unsigned j = 0; + for (const auto& field : fields) { + auto& field_value = o.via.array.ptr[offset+(j++)]; + switch (field) { + case Value::Field::Id: + case Value::Field::ValueType: + index[field] = FieldValue(field, field_value.as<uint64_t>()); + break; + case Value::Field::OwnerPk: + index[field] = FieldValue(field, field_value.as<InfoHash>()); + break; + case Value::Field::UserType: + index[field] = FieldValue(field, field_value.as<Blob>()); + break; + default: + throw msgpack::type_error(); + } + } +} + +void trim_str(std::string& str) { + auto first = std::min(str.size(), str.find_first_not_of(" ")); + auto last = std::min(str.size(), str.find_last_not_of(" ")); + str = str.substr(first, last - first + 1); } + +Select::Select(const std::string& q_str) { + std::istringstream q_iss {q_str}; + std::string token {}; + q_iss >> token; + + if (token == "SELECT" or token == "select") { + q_iss >> token; + std::istringstream fields {token}; + + while (std::getline(fields, token, ',')) { + trim_str(token); + if (token == "id") + field(Value::Field::Id); + else if (token == "value_type") + field(Value::Field::ValueType); + else if (token == "owner_pk") + field(Value::Field::OwnerPk); + else if (token == "user_type") + field(Value::Field::UserType); + } + } +} + +Where::Where(const std::string& q_str) { + std::istringstream q_iss {q_str}; + std::string token {}; + q_iss >> token; + if (token == "WHERE" or token == "where") { + std::getline(q_iss, token); + std::istringstream restrictions {token}; + while (std::getline(restrictions, token, ',')) { + trim_str(token); + std::istringstream eq_ss {token}; + std::string field_str, value_str; + std::getline(eq_ss, field_str, '='); + trim_str(field_str); + std::getline(eq_ss, value_str, '='); + trim_str(value_str); + + if (not value_str.empty()) { + uint64_t v = 0; + std::string s {}; + std::istringstream convert {value_str}; + convert >> v; + if (convert.failbit and value_str.size() > 1 and value_str[0] == '\"' and value_str[value_str.size()-1] == '\"') + s = value_str.substr(1, value_str.size()-2); + else + s = value_str; + if (field_str == "id") + id(v); + else if (field_str == "value_type") + valueType(v); + else if (field_str == "owner_pk") + owner(InfoHash(s)); + else if (field_str == "user_type") + userType(s); + else + throw std::invalid_argument(Query::QUERY_PARSE_ERROR + " (WHERE) wrong token near: " + field_str); + } + } + } +} + +void +Query::msgpack_unpack(const msgpack::object& o) +{ + if (o.type != msgpack::type::MAP) + throw msgpack::type_error(); + + auto rfilters = findMapValue(o, "w"); /* unpacking filters */ + if (rfilters) + where.msgpack_unpack(*rfilters); + else + throw msgpack::type_error(); + + auto rfield_selector = findMapValue(o, "s"); /* unpacking field selectors */ + if (rfield_selector) + select.msgpack_unpack(*rfield_selector); + else + throw msgpack::type_error(); +} + +template <typename T> +bool subset(std::vector<T> fds, std::vector<T> qfds) +{ + for (auto& fd : fds) { + auto correspondance = std::find_if(qfds.begin(), qfds.end(), [&fd](T& _vfd) { return fd == _vfd; }); + if (correspondance == qfds.end()) + return false; + } + return true; +}; + +bool Select::isSatisfiedBy(const Select& os) const { + /* empty, means all values are selected. */ + if (fieldSelection_.empty() and not os.fieldSelection_.empty()) + return false; + else + return subset(fieldSelection_, os.fieldSelection_); +} + +bool Where::isSatisfiedBy(const Where& ow) const { + return subset(ow.filters_, filters_); +} + +bool Query::isSatisfiedBy(const Query& q) const { + return where.isSatisfiedBy(q.where) and select.isSatisfiedBy(q.select); +} + +std::ostream& operator<<(std::ostream& s, const dht::Select& select) { + s << "SELECT " << (select.fieldSelection_.empty() ? "*" : ""); + for (auto fs = select.fieldSelection_.begin() ; fs != select.fieldSelection_.end() ; ++fs) { + switch (fs->getField()) { + case Value::Field::Id: + s << "id"; + break; + case Value::Field::ValueType: + s << "value_type"; + break; + case Value::Field::UserType: + s << "user_type"; + break; + case Value::Field::OwnerPk: + s << "owner_public_key"; + break; + default: + break; + } + s << (std::next(fs) != select.fieldSelection_.end() ? "," : ""); + } + return s; +} + +std::ostream& operator<<(std::ostream& s, const dht::Where& where) { + if (not where.filters_.empty()) { + s << "WHERE "; + for (auto f = where.filters_.begin() ; f != where.filters_.end() ; ++f) { + switch (f->getField()) { + case Value::Field::Id: + s << "id=" << f->getInt(); + break; + case Value::Field::ValueType: + s << "value_type=" << f->getInt(); + break; + case Value::Field::OwnerPk: + s << "owner_pk_hash=" << f->getHash().toString(); + break; + case Value::Field::UserType: { + auto b = f->getBlob(); + s << "user_type=" << std::string {b.begin(), b.end()}; + break; + } + default: + break; + } + s << (std::next(f) != where.filters_.end() ? "," : ""); + } + } + return s; +} + + +} + diff --git a/tools/dhtnode.cpp b/tools/dhtnode.cpp index 5112fbaa2644701b95b5e50409fa5947af7b0482..cbaed12318ae81831033caf7dc310a35d3d7b02d 100644 --- a/tools/dhtnode.cpp +++ b/tools/dhtnode.cpp @@ -50,7 +50,7 @@ void print_help() { std::cout << "OpenDht command line interface (CLI)" << std::endl; std::cout << "Possible commands:" << std::endl << " h, help Print this help message." << std::endl - << " q, quit Quit the program." << std::endl + << " x, quit Quit the program." << std::endl << " log Start/stop printing DHT logs." << std::endl; std::cout << std::endl << "Node information:" << std::endl @@ -61,8 +61,9 @@ void print_help() { std::cout << std::endl << "Operations on the DHT:" << std::endl << " b ip:port Ping potential node at given IP address/port." << std::endl - << " g [key] Get values at [key]." << std::endl - << " l [key] Listen for value changes at [key]." << std::endl + << " g [key] [where] Get values at [key]. [where] is the 'where' part of an SQL-ish string." << std::endl + << " q [key] [query] Query field values at [key]. [query] is an SQL-ish string." << std::endl + << " l [key] [where] Listen for value changes at [key]. [where] is the 'where' part of an SQL-ish string." << std::endl << " p [key] [str] Put string value at [key]." << std::endl << " s [key] [str] Put string value at [key], signed with our generated private key." << std::endl << " e [key] [dest] [str] Put string value at [key], encrypted for [dest] with its public key (if found)." << std::endl @@ -88,7 +89,7 @@ void cmd_loop(DhtRunner& dht, dht_params& params) std::string op, idstr, value; iss >> op >> idstr; - if (op == "x" || op == "q" || op == "exit" || op == "quit") { + if (op == "x" || op == "exit" || op == "quit") { break; } else if (op == "h" || op == "help") { print_help(); @@ -147,7 +148,7 @@ void cmd_loop(DhtRunner& dht, dht_params& params) continue; dht::InfoHash id {idstr}; - static const std::set<std::string> VALID_OPS {"g", "l", "p", "s", "e", "a"}; + static const std::set<std::string> VALID_OPS {"g", "q", "l", "p", "s", "e", "a"}; if (VALID_OPS.find(op) == VALID_OPS.cend()) { std::cout << "Unknown command: " << op << std::endl; std::cout << " (type 'h' or 'help' for a list of possible commands)" << std::endl; @@ -161,6 +162,11 @@ void cmd_loop(DhtRunner& dht, dht_params& params) auto start = std::chrono::high_resolution_clock::now(); if (op == "g") { + std::string rem; + std::getline(iss, rem); + dht::Where w {std::move(rem)}; + dht::Query q {{}, w}; + std::cout << q << std::endl; dht.get(id, [start](std::shared_ptr<Value> value) { auto now = std::chrono::high_resolution_clock::now(); std::cout << "Get: found value (after " << print_dt(now-start) << "s)" << std::endl; @@ -169,14 +175,36 @@ void cmd_loop(DhtRunner& dht, dht_params& params) }, [start](bool ok) { auto end = std::chrono::high_resolution_clock::now(); std::cout << "Get: " << (ok ? "completed" : "failure") << " (took " << print_dt(end-start) << "s)" << std::endl; - }); + }, {}, std::move(w)); + } + else if (op == "q") { + std::string rem; + std::getline(iss, rem); + dht::Query q {std::move(rem)}; + std::cout << q << std::endl; + dht.query(id, [start](const std::vector<std::shared_ptr<FieldValueIndex>>& field_value_indexes) { + auto now = std::chrono::high_resolution_clock::now(); + for (auto& index : field_value_indexes) { + std::cout << "Query: found field value index (after " << print_dt(now-start) << "s)" << std::endl; + std::cout << "\t" << *index << std::endl; + } + return true; + }, [start](bool ok) { + auto end = std::chrono::high_resolution_clock::now(); + std::cout << "Query: " << (ok ? "completed" : "failure") << " (took " << print_dt(end-start) << "s)" << std::endl; + }, std::move(q)); } else if (op == "l") { + std::string rem; + std::getline(iss, rem); + dht::Where w {std::move(rem)}; + dht::Query q {{}, w}; + std::cout << q << std::endl; dht.listen(id, [](std::shared_ptr<Value> value) { std::cout << "Listen: found value:" << std::endl; std::cout << "\t" << *value << std::endl; return true; - }); + }, {}, std::move(w)); } else if (op == "p") { std::string v;