From 01a2a8d4d9c743bbc536ee2de8e92066812c8225 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrien=20B=C3=A9raud?= <adrien.beraud@savoirfairelinux.com>
Date: Tue, 25 May 2021 14:22:23 -0400
Subject: [PATCH] securedht: add cert store query with sha256 PkId

---
 include/opendht/callbacks.h |  1 +
 include/opendht/dhtrunner.h |  3 ++-
 include/opendht/securedht.h |  4 +++-
 src/dhtrunner.cpp           | 12 ++++++------
 src/securedht.cpp           | 17 ++++++++++++++---
 5 files changed, 26 insertions(+), 11 deletions(-)

diff --git a/include/opendht/callbacks.h b/include/opendht/callbacks.h
index 2820ed9b..b20e4a12 100644
--- a/include/opendht/callbacks.h
+++ b/include/opendht/callbacks.h
@@ -168,6 +168,7 @@ using ShutdownCallback = std::function<void()>;
 using IdentityAnnouncedCb = std::function<void(bool)>;
 
 using CertificateStoreQuery = std::function<std::vector<std::shared_ptr<crypto::Certificate>>(const InfoHash& pk_id)>;
+using CertificateStoreLongQuery = std::function<std::vector<std::shared_ptr<crypto::Certificate>>(const PkId& pk_id)>;
 
 typedef bool (*GetCallbackRaw)(std::shared_ptr<Value>, void *user_data);
 typedef bool (*ValueCallbackRaw)(std::shared_ptr<Value>, bool expired, void *user_data);
diff --git a/include/opendht/dhtrunner.h b/include/opendht/dhtrunner.h
index febf14a8..b4511797 100644
--- a/include/opendht/dhtrunner.h
+++ b/include/opendht/dhtrunner.h
@@ -74,6 +74,7 @@ public:
         std::shared_ptr<PeerDiscovery> peerDiscovery {};
         StatusCallback statusChangedCallback {};
         CertificateStoreQuery certificateStore {};
+        CertificateStoreLongQuery certificateStoreSha256 {};
         IdentityAnnouncedCb identityAnnouncedCb {};
         Context() {}
     };
@@ -368,7 +369,7 @@ public:
 
     void findCertificate(InfoHash hash, std::function<void(const std::shared_ptr<crypto::Certificate>&)>);
     void registerCertificate(std::shared_ptr<crypto::Certificate> cert);
-    void setLocalCertificateStore(CertificateStoreQuery&& query_method);
+    void setLocalCertificateStore(CertificateStoreQuery&& query_method, CertificateStoreLongQuery&& query_sha256_method = {});
 
     /**
      * @param port: Local port to bind. Both IPv4 and IPv6 will be tried (ANY).
diff --git a/include/opendht/securedht.h b/include/opendht/securedht.h
index ff021890..7cffdcf8 100644
--- a/include/opendht/securedht.h
+++ b/include/opendht/securedht.h
@@ -145,8 +145,9 @@ public:
      * The search key used is the public key ID, so there may be multiple certificates retured, signed with
      * the same private key.
      */
-    void setLocalCertificateStore(CertificateStoreQuery&& query_method) {
+    void setLocalCertificateStore(CertificateStoreQuery&& query_method, CertificateStoreLongQuery&& query_sha256_method) {
         localQueryMethod_ = std::move(query_method);
+        localQuerySha256Method_ = std::move(query_sha256_method);
     }
 
     /**
@@ -357,6 +358,7 @@ private:
 
     // method to query the local certificate store
     CertificateStoreQuery localQueryMethod_ {};
+    CertificateStoreLongQuery localQuerySha256Method_ {};
 
     // our certificate cache
     std::map<InfoHash, Sp<crypto::Certificate>> nodesCertificates_ {};
diff --git a/src/dhtrunner.cpp b/src/dhtrunner.cpp
index 645acc27..6c948fe1 100644
--- a/src/dhtrunner.cpp
+++ b/src/dhtrunner.cpp
@@ -185,10 +185,10 @@ DhtRunner::run(const Config& config, Context&& context)
     if (context.statusChangedCallback) {
         statusCb = std::move(context.statusChangedCallback);
     }
-    if (context.certificateStore) {
-        dht_->setLocalCertificateStore(std::move(context.certificateStore));
+    if (context.certificateStore || context.certificateStoreSha256) {
+        dht_->setLocalCertificateStore(std::move(context.certificateStore), std::move(context.certificateStoreSha256));
         if (dht_via_proxy_)
-            dht_via_proxy_->setLocalCertificateStore(std::move(context.certificateStore));
+            dht_via_proxy_->setLocalCertificateStore(std::move(context.certificateStore), std::move(context.certificateStoreSha256));
     }
 
     if (not config.threaded)
@@ -627,14 +627,14 @@ DhtRunner::registerCertificate(std::shared_ptr<crypto::Certificate> cert) {
     activeDht()->registerCertificate(cert);
 }
 void
-DhtRunner::setLocalCertificateStore(CertificateStoreQuery&& query_method) {
+DhtRunner::setLocalCertificateStore(CertificateStoreQuery&& query_method, CertificateStoreLongQuery&& query_sha256_method) {
     std::lock_guard<std::mutex> lck(dht_mtx);
 #ifdef OPENDHT_PROXY_CLIENT
     if (dht_via_proxy_)
-        dht_via_proxy_->setLocalCertificateStore(std::forward<CertificateStoreQuery>(query_method));
+        dht_via_proxy_->setLocalCertificateStore(std::move(query_method), std::move(query_sha256_method));
 #endif
     if (dht_)
-        dht_->setLocalCertificateStore(std::forward<CertificateStoreQuery>(query_method));
+        dht_->setLocalCertificateStore(std::move(query_method), std::move(query_sha256_method));
 }
 
 time_point
diff --git a/src/securedht.cpp b/src/securedht.cpp
index e64d77ca..b044ae6e 100644
--- a/src/securedht.cpp
+++ b/src/securedht.cpp
@@ -254,10 +254,21 @@ SecureDht::findCertificate(const PkId& node, const std::function<void(const Sp<c
         if (cb)
             cb(b);
         return;
-    } else {
-        if (cb)
-            cb(nullptr);
     }
+    if (localQueryMethod_) {
+        auto res = localQuerySha256Method_(node);
+        if (not res.empty()) {
+            if (logger_)
+                logger_->d("Registering certificate from local store for %s", node.to_c_str());
+            nodesCertificates_.emplace(res.front()->getId(), res.front());
+            nodesCertificatesLong_.emplace(node, res.front());
+            if (cb)
+                cb(res.front());
+            return;
+        }
+    }
+    if (cb)
+        cb(nullptr);
 }
 
 void
-- 
GitLab