diff --git a/include/opendht/dhtrunner.h b/include/opendht/dhtrunner.h index 8998e62c7b1ba66ed851cc11360f49ae8b431967..2a14ac197557b7d616e2035dfb5c6846ea2098ed 100644 --- a/include/opendht/dhtrunner.h +++ b/include/opendht/dhtrunner.h @@ -223,6 +223,10 @@ public: std::lock_guard<std::mutex> lck(dht_mtx); dht_->registerCertificate(cert); } + void setLocalCertificateStore(SecureDht::CertificateStoreQuery&& query_method) { + std::lock_guard<std::mutex> lck(dht_mtx); + dht_->setLocalCertificateStore(std::move(query_method)); + } /** * If threaded is false, loop() must be called periodically. diff --git a/include/opendht/securedht.h b/include/opendht/securedht.h index d1e29a37db6f153d87f5d33211cd5291efd770dc..2ba2d2928161d2bf1ce84a64a8b7db5d2d6f6a47 100644 --- a/include/opendht/securedht.h +++ b/include/opendht/securedht.h @@ -121,6 +121,18 @@ public: const std::shared_ptr<crypto::Certificate> getCertificate(const InfoHash& node) const; + + using CertificateStoreQuery = std::function<std::vector<std::shared_ptr<crypto::Certificate>>(const InfoHash& pk_id)>; + + /** + * Allows to set a custom callback called by the library to find a locally-stored certificate. + * The search key used is the public key ID, so there may be multiple certificates retured, signed with + * the same private key. + */ + void setLocalCertificateStore(CertificateStoreQuery&& query_method) { + localQueryMethod_ = std::move(query_method); + } + private: // prevent copy SecureDht(const SecureDht&) = delete; @@ -131,6 +143,10 @@ private: std::shared_ptr<crypto::PrivateKey> key_ {}; std::shared_ptr<crypto::Certificate> certificate_ {}; + // method to query the local certificate store + CertificateStoreQuery localQueryMethod_ {}; + + // our certificate cache std::map<InfoHash, std::shared_ptr<crypto::Certificate>> nodesCertificates_ {}; std::uniform_int_distribution<Value::Id> rand_id {}; diff --git a/src/securedht.cpp b/src/securedht.cpp index f74fec7d18b4092323301b1e5d46e0664940412a..09fd4cf1bfea908eb1db713d43489efee8871983 100644 --- a/src/securedht.cpp +++ b/src/securedht.cpp @@ -144,15 +144,16 @@ SecureDht::registerCertificate(const InfoHash& node, const Blob& data) InfoHash h = crt->getPublicKey().getId(); if (node == h) { DHT_DEBUG("Registering public key for %s", h.toString().c_str()); - nodesCertificates_[h] = crt; + auto it = nodesCertificates_.find(h); + if (it == nodesCertificates_.end()) + std::tie(it, std::ignore) = nodesCertificates_.emplace(h, std::move(crt)); + else + it->second = std::move(crt); + return it->second; } else { DHT_DEBUG("Certificate %s for node %s does not match node id !", h.toString().c_str(), node.toString().c_str()); return nullptr; } - auto it = nodesCertificates_.find(h); - if (it == nodesCertificates_.end()) - return nullptr; - return it->second; } void @@ -172,6 +173,17 @@ SecureDht::findCertificate(const InfoHash& node, std::function<void(const std::s cb(b); return; } + if (localQueryMethod_) { + auto res = localQueryMethod_(node); + if (not res.empty()) { + DHT_DEBUG("Registering public key from local store for %s", node.toString().c_str()); + nodesCertificates_.emplace(node, res.front()); + if (cb) + cb(res.front()); + return; + } + } + auto found = std::make_shared<bool>(false); Dht::get(node, [cb,node,found,this](const std::vector<std::shared_ptr<Value>>& vals) { if (*found)