diff --git a/include/opendht/dht.h b/include/opendht/dht.h index aafd2f5606ea354d80b5e058641a8c8177771c7b..982d41c63807208ae2bdccbaa647ac28b2b21d91 100644 --- a/include/opendht/dht.h +++ b/include/opendht/dht.h @@ -230,8 +230,16 @@ public: * * @return a token to cancel the listener later. */ - virtual size_t listen(const InfoHash&, GetCallback, Value::Filter={}, Where w = {}); - virtual size_t listen(const InfoHash& key, GetCallbackSimple cb, Value::Filter f={}, Where w = {}) { + virtual size_t listen(const InfoHash&, ValueCallback, Value::Filter={}, Where={}); + + virtual size_t listen(const InfoHash& key, GetCallback cb, Value::Filter f={}, Where w={}) { + return listen(key, [cb](const std::vector<Sp<Value>>& vals, bool expired){ + if (not expired) + return cb(vals); + return true; + }, std::forward<Value::Filter>(f), std::forward<Where>(w)); + } + virtual size_t listen(const InfoHash& key, GetCallbackSimple cb, Value::Filter f={}, Where w={}) { return listen(key, bindGetCb(cb), std::forward<Value::Filter>(f), std::forward<Where>(w)); } @@ -459,7 +467,7 @@ private: Sp<Search> search(const InfoHash& id, sa_family_t af, GetCallback = {}, QueryCallback = {}, DoneCallback = {}, Value::Filter = {}, Query q = {}); void announce(const InfoHash& id, sa_family_t af, Sp<Value> value, DoneCallback callback, time_point created=time_point::max(), bool permanent = false); - size_t listenTo(const InfoHash& id, sa_family_t af, GetCallback cb, Value::Filter f = Value::AllFilter(), const Sp<Query>& q = {}); + size_t listenTo(const InfoHash& id, sa_family_t af, ValueCallback cb, Value::Filter f = Value::AllFilter(), const Sp<Query>& q = {}); /** * Refill the search with good nodes if possible. diff --git a/include/opendht/dht_interface.h b/include/opendht/dht_interface.h index a312a4c6071ee3907242c8cf9e2be25702a6af17..30d0a60bba1cd504ef3c2d834159da8414972251 100644 --- a/include/opendht/dht_interface.h +++ b/include/opendht/dht_interface.h @@ -165,6 +165,7 @@ public: */ virtual size_t listen(const InfoHash&, GetCallback, Value::Filter={}, Where w = {}) = 0; virtual size_t listen(const InfoHash& key, GetCallbackSimple cb, Value::Filter f={}, Where w = {}) = 0; + virtual size_t listen(const InfoHash&, ValueCallback, Value::Filter={}, Where w = {}) = 0; virtual bool cancelListen(const InfoHash&, size_t token) = 0; diff --git a/include/opendht/dht_proxy_client.h b/include/opendht/dht_proxy_client.h index a15c84c4d22ac386660aed448078099b59b94f78..e3f8287575a31b39af9e1425f296157d9aa95bf3 100644 --- a/include/opendht/dht_proxy_client.h +++ b/include/opendht/dht_proxy_client.h @@ -161,8 +161,16 @@ public: * * @return a token to cancel the listener later. */ - virtual size_t listen(const InfoHash&, GetCallback, Value::Filter={}, Where={}); - virtual size_t listen(const InfoHash& key, GetCallbackSimple cb, Value::Filter f={}, Where w = {}) { + virtual size_t listen(const InfoHash&, ValueCallback, Value::Filter={}, Where={}); + + virtual size_t listen(const InfoHash& key, GetCallback cb, Value::Filter f={}, Where w={}) { + return listen(key, [cb](const std::vector<Sp<Value>>& vals, bool expired){ + if (not expired) + return cb(vals); + return true; + }, std::forward<Value::Filter>(f), std::forward<Where>(w)); + } + virtual size_t listen(const InfoHash& key, GetCallbackSimple cb, Value::Filter f={}, Where w={}) { return listen(key, bindGetCb(cb), std::forward<Value::Filter>(f), std::forward<Where>(w)); } virtual bool cancelListen(const InfoHash& key, size_t token); diff --git a/include/opendht/dhtrunner.h b/include/opendht/dhtrunner.h index 89090330f8cb0bdf2a2350a5046615b5417e4434..d14a67257e77796a9e3d2d4d447f180c1dca3adb 100644 --- a/include/opendht/dhtrunner.h +++ b/include/opendht/dhtrunner.h @@ -136,7 +136,15 @@ public: query(hash, cb, bindDoneCb(done_cb), q); } - std::future<size_t> listen(InfoHash key, GetCallback vcb, Value::Filter f = Value::AllFilter(), Where w = {}); + std::future<size_t> listen(InfoHash key, ValueCallback vcb, Value::Filter f = Value::AllFilter(), Where w = {}); + + std::future<size_t> listen(InfoHash key, GetCallback cb, Value::Filter f={}, Where w={}) { + return listen(key, [cb](const std::vector<Sp<Value>>& vals, bool expired){ + if (not expired) + return cb(vals); + return true; + }, std::forward<Value::Filter>(f), std::forward<Where>(w)); + } std::future<size_t> listen(const std::string& key, GetCallback vcb, Value::Filter f = Value::AllFilter(), Where w = {}); std::future<size_t> listen(InfoHash key, GetCallbackSimple cb, Value::Filter f = Value::AllFilter(), Where w = {}) { return listen(key, bindGetCb(cb), f, w); @@ -446,15 +454,8 @@ private: /** * Store current listeners and translates global tokens for each client. */ - struct Listener { - size_t tokenClassicDht; - size_t tokenProxyDht; - GetCallback gcb; - InfoHash hash; - Value::Filter f; - Where w; - }; - std::map<size_t, Listener> listeners_ {}; + struct Listener; + std::map<size_t, Listener> listeners_; size_t listener_token_ {1}; mutable std::mutex dht_mtx {}; diff --git a/include/opendht/securedht.h b/include/opendht/securedht.h index e1d31a9a4127b472fcd47908c4a9032611abf45a..287dca53fe96a410c8f9b5aefa6f3ab2ee2bbacd 100644 --- a/include/opendht/securedht.h +++ b/include/opendht/securedht.h @@ -100,8 +100,6 @@ public: get(key, bindGetCb(cb), bindDoneCb(donecb), std::forward<Value::Filter>(f), std::forward<Where>(w)); } - size_t listen(const InfoHash& id, GetCallback cb, Value::Filter = {}, Where w = {}); - /** * Will take ownership of the value, sign it using our private key and put it in the DHT. */ @@ -280,8 +278,11 @@ public: bool cancelPut(const InfoHash& h, const Value::Id& vid) { return dht_->cancelPut(h, vid); } + + size_t listen(const InfoHash& key, ValueCallback, Value::Filter={}, Where={}); + size_t listen(const InfoHash& key, GetCallback cb, Value::Filter = {}, Where w = {}); size_t listen(const InfoHash& key, GetCallbackSimple cb, Value::Filter f={}, Where w = {}) { - return dht_->listen(key, cb, f, w); + return listen(key, bindGetCb(cb), f, w); } bool cancelListen(const InfoHash& h, size_t token) { return dht_->cancelListen(h, token); @@ -331,6 +332,8 @@ private: SecureDht(const SecureDht&) = delete; SecureDht& operator=(const SecureDht&) = delete; + Sp<Value> checkValue(const Sp<Value>& v); + ValueCallback getCallbackFilter(ValueCallback, Value::Filter&&); GetCallback getCallbackFilter(GetCallback, Value::Filter&&); Sp<crypto::PrivateKey> key_ {}; diff --git a/src/dht.cpp b/src/dht.cpp index 19bec3ae18766be96cb6d45278e9c1d09712237d..298267d43319b6514528d7df89279f674975716b 100644 --- a/src/dht.cpp +++ b/src/dht.cpp @@ -786,7 +786,7 @@ Dht::announce(const InfoHash& id, } size_t -Dht::listenTo(const InfoHash& id, sa_family_t af, GetCallback cb, Value::Filter f, const Sp<Query>& q) +Dht::listenTo(const InfoHash& id, sa_family_t af, ValueCallback cb, Value::Filter f, const Sp<Query>& q) { if (!isRunning(af)) return 0; @@ -803,7 +803,7 @@ Dht::listenTo(const InfoHash& id, sa_family_t af, GetCallback cb, Value::Filter } size_t -Dht::listen(const InfoHash& id, GetCallback cb, Value::Filter f, Where where) +Dht::listen(const InfoHash& id, ValueCallback cb, Value::Filter f, Where where) { scheduler.syncTime(); @@ -811,27 +811,7 @@ Dht::listen(const InfoHash& id, GetCallback cb, Value::Filter f, Where where) auto vals = std::make_shared<std::map<Value::Id, Sp<Value>>>(); auto token = ++listener_token; - auto gcb = [=](const std::vector<Sp<Value>>& values) { - std::vector<Sp<Value>> newvals; - for (const auto& v : values) { - auto it = vals->find(v->id); - if (it == vals->cend() || !(*it->second == *v)) - newvals.push_back(v); - } - if (!newvals.empty()) { - if (!cb(newvals)) { - // cancelListen is useful here, because we need to cancel on IPv4 and 6 - cancelListen(id, token); - return false; - } - for (const auto& v : newvals) { - auto it = vals->emplace(v->id, v); - if (not it.second) - it.first->second = v; - } - } - return true; - }; + auto gcb = OpValueCache::cacheCallback(std::move(cb)); auto query = std::make_shared<Query>(q); auto filter = f.chain(q.where.getFilter()); @@ -843,7 +823,7 @@ Dht::listen(const InfoHash& id, GetCallback cb, Value::Filter f, Where where) if (not st->second.empty()) { std::vector<Sp<Value>> newvals = st->second.get(filter); if (not newvals.empty()) { - if (!cb(newvals)) + if (!gcb(newvals, false)) return 0; for (const auto& v : newvals) { auto it = vals->emplace(v->id, v); @@ -1149,7 +1129,7 @@ Dht::storageChanged(const InfoHash& id, Storage& st, ValueStorage& v) { if (not st.local_listeners.empty()) { DHT_LOG.d(id, "[store %s] %lu local listeners", id.toString().c_str(), st.local_listeners.size()); - std::vector<std::pair<GetCallback, std::vector<Sp<Value>>>> cbs; + std::vector<std::pair<ValueCallback, std::vector<Sp<Value>>>> cbs; for (const auto& l : st.local_listeners) { std::vector<Sp<Value>> vals; if (not l.second.filter or l.second.filter(*v.data)) @@ -1163,7 +1143,7 @@ Dht::storageChanged(const InfoHash& id, Storage& st, ValueStorage& v) } // listeners are copied: they may be deleted by the callback for (auto& cb : cbs) - cb.first(cb.second); + cb.first(cb.second, false); } if (not st.listeners.empty()) { diff --git a/src/dht_proxy_client.cpp b/src/dht_proxy_client.cpp index 6f9cd390c04ca1a224052318ddc4ec69ba206eea..8562bb39cd46b04e18f9aa2debf12e9aaf0f07db 100644 --- a/src/dht_proxy_client.cpp +++ b/src/dht_proxy_client.cpp @@ -443,7 +443,7 @@ DhtProxyClient::getPublicAddress(sa_family_t family) } size_t -DhtProxyClient::listen(const InfoHash& key, GetCallback cb, Value::Filter filter, Where where) { +DhtProxyClient::listen(const InfoHash& key, ValueCallback cb, Value::Filter filter, Where where) { auto it = listeners_.find(key); if (it == listeners_.end()) { it = listeners_.emplace(key, ProxySearch{}).first; diff --git a/src/dhtrunner.cpp b/src/dhtrunner.cpp index 892fd111c9afb89780e39654980163354e741208..adeb5a06a1afa8b0a4d45767db93017f5c861feb 100644 --- a/src/dhtrunner.cpp +++ b/src/dhtrunner.cpp @@ -43,6 +43,15 @@ namespace dht { constexpr std::chrono::seconds DhtRunner::BOOTSTRAP_PERIOD; +struct DhtRunner::Listener { + size_t tokenClassicDht; + size_t tokenProxyDht; + ValueCallback gcb; + InfoHash hash; + Value::Filter f; + Where w; +}; + DhtRunner::DhtRunner() : dht_() #if OPENDHT_PROXY_CLIENT , dht_via_proxy_() @@ -542,7 +551,7 @@ DhtRunner::query(const InfoHash& hash, QueryCallback cb, DoneCallback done_cb, Q } std::future<size_t> -DhtRunner::listen(InfoHash hash, GetCallback vcb, Value::Filter f, Where w) +DhtRunner::listen(InfoHash hash, ValueCallback vcb, Value::Filter f, Where w) { auto ret_token = std::make_shared<std::promise<size_t>>(); { @@ -553,8 +562,8 @@ DhtRunner::listen(InfoHash hash, GetCallback vcb, Value::Filter f, Where w) listener.hash = hash; listener.f = std::move(f); listener.w = std::move(w); - listener.gcb = [hash,vcb,tokenbGlobal,this](const std::vector<Sp<Value>>& vals){ - if (not vcb(vals)) { + listener.gcb = [hash,vcb,tokenbGlobal,this](const std::vector<Sp<Value>>& vals, bool expired){ + if (not vcb(vals, expired)) { #if OPENDHT_PROXY_CLIENT cancelListen(hash, tokenbGlobal); #endif diff --git a/src/listener.h b/src/listener.h index ad001ecaa98d538260e034e3976caeccb3a3403e..f491929e03d2b0bfffbf2a096ce3278183fdb200 100644 --- a/src/listener.h +++ b/src/listener.h @@ -44,7 +44,7 @@ struct Listener { struct LocalListener { Sp<Query> query; Value::Filter filter; - GetCallback get_cb; + ValueCallback get_cb; }; diff --git a/src/op_cache.h b/src/op_cache.h index 8110aaae2f9c0675bf06d3caa288f6b1eace5f94..e20f8845462297cdce30d21b0a0cca030d3d91f7 100644 --- a/src/op_cache.h +++ b/src/op_cache.h @@ -30,6 +30,56 @@ struct OpCacheValueStorage OpCacheValueStorage(Sp<Value> val = {}) : data(val) {} }; +class OpValueCache { +public: + OpValueCache(ValueCallback&& cb) : callback(std::forward<ValueCallback>(cb)) {} + + static ValueCallback cacheCallback(ValueCallback&& cb) { + auto cache = std::make_shared<OpValueCache>(std::forward<ValueCallback>(cb)); + return [cache](const std::vector<Sp<Value>>& vals, bool expired){ + return cache->onValue(vals, expired); + }; + } + + bool onValue(const std::vector<Sp<Value>>& vals, bool expired) { + if (expired) + return onValuesExpired(vals); + else + return onValuesAdded(vals); + } + + bool onValuesAdded(const std::vector<Sp<Value>>& vals) { + std::vector<Sp<Value>> newValues; + for (const auto& v : vals) { + auto viop = values.emplace(v->id, OpCacheValueStorage{v}); + if (viop.second) { + newValues.emplace_back(v); + //std::cout << "onValuesAdded: new value " << v->id << std::endl; + } else { + viop.first->second.refCount++; + //std::cout << "onValuesAdded: " << viop.first->second.refCount << " refs for value " << v->id << std::endl; + } + } + return callback(newValues, false); + } + bool onValuesExpired(const std::vector<Sp<Value>>& vals) { + std::vector<Sp<Value>> expiredValues; + for (const auto& v : vals) { + auto vit = values.find(v->id); + if (vit != values.end()) { + vit->second.refCount--; + //std::cout << "onValuesExpired: " << vit->second.refCount << " refs remaining for value " << v->id << std::endl; + if (not vit->second.refCount) + values.erase(vit); + } + } + return callback(expiredValues, true); + } +private: + std::map<Value::Id, OpCacheValueStorage> values {}; + ValueCallback callback; +}; + class OpCache { public: bool onValue(const std::vector<Sp<Value>>& vals, bool expired) { @@ -53,11 +103,11 @@ public: } } auto list = listeners; - for (auto& l : list) { - l.second.get_cb(newValues); - } + for (auto& l : list) + l.second.get_cb(l.second.filter.filter(newValues), false); } void onValuesExpired(const std::vector<Sp<Value>>& vals) { + std::vector<Sp<Value>> expiredValues; for (const auto& v : vals) { auto vit = values.find(v->id); if (vit != values.end()) { @@ -67,15 +117,18 @@ public: values.erase(vit); } } + auto list = listeners; + for (auto& l : list) + l.second.get_cb(l.second.filter.filter(expiredValues), true); } - void addListener(size_t token, GetCallback get_cb, Sp<Query> q, Value::Filter filter) { - listeners.emplace(token, LocalListener{q, filter, get_cb}); + void addListener(size_t token, ValueCallback cb, Sp<Query> q, Value::Filter filter) { + listeners.emplace(token, LocalListener{q, filter, cb}); std::vector<Sp<Value>> newValues; newValues.reserve(values.size()); for (const auto& v : values) newValues.emplace_back(v.second.data); - get_cb(newValues); + cb(newValues, false); } bool removeListener(size_t token) { @@ -94,7 +147,7 @@ private: class SearchCache { public: - size_t listen(GetCallback get_cb, Sp<Query> q, Value::Filter filter, std::function<size_t(Sp<Query>, ValueCallback)> onListen) { + size_t listen(ValueCallback get_cb, Sp<Query> q, Value::Filter filter, std::function<size_t(Sp<Query>, ValueCallback)> onListen) { // find exact match auto op = ops.find(q); if (op == ops.end()) { diff --git a/src/search.h b/src/search.h index d59aff471885d1035c1a4bdef9cd2072fb84c12d..9a492eb4d9cac4837dc3e7c2f3655aadfa14e4ea 100644 --- a/src/search.h +++ b/src/search.h @@ -462,7 +462,7 @@ struct Dht::Search { bool isAnnounced(Value::Id id) const; bool isListening(time_point now) const; - size_t listen(GetCallback cb, Value::Filter f, const Sp<Query>& q, Scheduler& scheduler) { + size_t listen(ValueCallback cb, Value::Filter f, const Sp<Query>& q, Scheduler& scheduler) { //DHT_LOG.e(id, "[search %s IPv%c] listen", id.toString().c_str(), (af == AF_INET) ? '4' : '6'); return cache.listen(cb, q, f, [&](const Sp<Query>& q, ValueCallback vcb){ done = false; diff --git a/src/securedht.cpp b/src/securedht.cpp index 7dd8dc86909002e28fd1a0c3b2e0553bc7ed349f..3b67abdb7e9d47dce3cf6a4c5ccac3784485457f 100644 --- a/src/securedht.cpp +++ b/src/securedht.cpp @@ -221,50 +221,73 @@ SecureDht::findPublicKey(const InfoHash& node, std::function<void(const Sp<const }); } +Sp<Value> +SecureDht::checkValue(const Sp<Value>& v) +{ + // Decrypt encrypted values + if (v->isEncrypted()) { + if (not key_) { +#if OPENDHT_PROXY_SERVER + if (forward_all_) // We are currently a proxy, send messages to clients. + return v; +#endif + return {}; + } + try { + Value decrypted_val (decrypt(*v)); + if (decrypted_val.recipient == getId()) { + nodesPubKeys_[decrypted_val.owner->getId()] = decrypted_val.owner; + return std::make_shared<Value>(std::move(decrypted_val)); + } + // Ignore values belonging to other people + } catch (const std::exception& e) { + DHT_LOG.WARN("Could not decrypt value %s : %s", v->toString().c_str(), e.what()); + } + } + // Check signed values + else if (v->isSigned()) { + if (v->owner and v->owner->checkSignature(v->getToSign(), v->signature)) { + nodesPubKeys_[v->owner->getId()] = v->owner; + return v; + } + else + DHT_LOG.WARN("Signature verification failed for %s", v->toString().c_str()); + } + // Forward normal values + else { + return v; + } + return {}; +} + +ValueCallback +SecureDht::getCallbackFilter(ValueCallback cb, Value::Filter&& filter) +{ + return [=](const std::vector<Sp<Value>>& values, bool expired) { + std::vector<Sp<Value>> tmpvals {}; + for (const auto& v : values) { + if (auto nv = checkValue(v)) + if (not filter or filter(*nv)) + tmpvals.emplace_back(std::move(nv)); + } + if (cb and not tmpvals.empty()) + return cb(tmpvals, expired); + return true; + }; +} + + GetCallback SecureDht::getCallbackFilter(GetCallback cb, Value::Filter&& filter) { return [=](const std::vector<Sp<Value>>& values) { std::vector<Sp<Value>> tmpvals {}; for (const auto& v : values) { - // Decrypt encrypted values - if (v->isEncrypted()) { - if (not key_) { -#if OPENDHT_PROXY_SERVER - if (forward_all_) // We are currently a proxy, send messages to clients. - tmpvals.push_back(v); -#endif - continue; - } - try { - Value decrypted_val (decrypt(*v)); - if (decrypted_val.recipient == getId()) { - nodesPubKeys_[decrypted_val.owner->getId()] = decrypted_val.owner; - if (not filter or filter(decrypted_val)) - tmpvals.push_back(std::make_shared<Value>(std::move(decrypted_val))); - } - // Ignore values belonging to other people - } catch (const std::exception& e) { - DHT_LOG.WARN("Could not decrypt value %s : %s", v->toString().c_str(), e.what()); - } - } - // Check signed values - else if (v->isSigned()) { - if (v->owner and v->owner->checkSignature(v->getToSign(), v->signature)) { - nodesPubKeys_[v->owner->getId()] = v->owner; - if (not filter or filter(*v)) - tmpvals.push_back(v); - } - else - DHT_LOG.WARN("Signature verification failed for %s", v->toString().c_str()); - } - // Forward normal values - else { - if (not filter or filter(*v)) - tmpvals.push_back(v); - } + if (auto nv = checkValue(v)) + if (not filter or filter(*nv)) + tmpvals.emplace_back(std::move(nv)); } - if (cb && not tmpvals.empty()) + if (cb and not tmpvals.empty()) return cb(tmpvals); return true; }; @@ -276,6 +299,13 @@ SecureDht::get(const InfoHash& id, GetCallback cb, DoneCallback donecb, Value::F dht_->get(id, getCallbackFilter(cb, std::forward<Value::Filter>(f)), donecb, {}, std::forward<Where>(w)); } +size_t +SecureDht::listen(const InfoHash& id, ValueCallback cb, Value::Filter f, Where w) +{ + return dht_->listen(id, getCallbackFilter(cb, std::forward<Value::Filter>(f)), {}, std::forward<Where>(w)); +} + + size_t SecureDht::listen(const InfoHash& id, GetCallback cb, Value::Filter f, Where w) { diff --git a/tools/dhtnode.cpp b/tools/dhtnode.cpp index 7f0a95b84d6cc78c783e7dc89b82b4e6a1c48c54..777903a18492a09ee51db47ce0b93560885efe02 100644 --- a/tools/dhtnode.cpp +++ b/tools/dhtnode.cpp @@ -319,9 +319,10 @@ void cmd_loop(std::shared_ptr<DhtRunner>& dht, dht_params& params else if (op == "l") { std::string rem; std::getline(iss, rem); - auto token = dht->listen(id, [](std::shared_ptr<Value> value) { - std::cout << "Listen: found value:" << std::endl; - std::cout << "\t" << *value << std::endl; + auto token = dht->listen(id, [](const std::vector<std::shared_ptr<Value>>& values, bool expired) { + std::cout << "Listen: found " << values.size() << " values" << (expired ? " expired" : "") << std::endl; + for (const auto& value : values) + std::cout << "\t" << *value << std::endl; return true; }, {}, dht::Where {std::move(rem)}); auto t = token.get();