From e88e78a2d95f46fee8466095865b3ca801c70ff7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrien=20B=C3=A9raud?= <adrien.beraud@savoirfairelinux.com>
Date: Thu, 25 Jun 2015 11:40:49 -0400
Subject: [PATCH] securedht: support external certificate store

With setLocalCertificateStore, add an optional
custom callback to find an immediately available
certificate to use for encryption etc.
---
 include/opendht/dhtrunner.h |  4 ++++
 include/opendht/securedht.h | 16 ++++++++++++++++
 src/securedht.cpp           | 22 +++++++++++++++++-----
 3 files changed, 37 insertions(+), 5 deletions(-)

diff --git a/include/opendht/dhtrunner.h b/include/opendht/dhtrunner.h
index 8998e62c..2a14ac19 100644
--- a/include/opendht/dhtrunner.h
+++ b/include/opendht/dhtrunner.h
@@ -223,6 +223,10 @@ public:
         std::lock_guard<std::mutex> lck(dht_mtx);
         dht_->registerCertificate(cert);
     }
+    void setLocalCertificateStore(SecureDht::CertificateStoreQuery&& query_method) {
+        std::lock_guard<std::mutex> lck(dht_mtx);
+        dht_->setLocalCertificateStore(std::move(query_method));
+    }
 
     /**
      * If threaded is false, loop() must be called periodically.
diff --git a/include/opendht/securedht.h b/include/opendht/securedht.h
index d1e29a37..2ba2d292 100644
--- a/include/opendht/securedht.h
+++ b/include/opendht/securedht.h
@@ -121,6 +121,18 @@ public:
 
     const std::shared_ptr<crypto::Certificate> getCertificate(const InfoHash& node) const;
 
+
+    using CertificateStoreQuery = std::function<std::vector<std::shared_ptr<crypto::Certificate>>(const InfoHash& pk_id)>;
+
+    /**
+     * Allows to set a custom callback called by the library to find a locally-stored certificate.
+     * The search key used is the public key ID, so there may be multiple certificates retured, signed with
+     * the same private key.
+     */
+    void setLocalCertificateStore(CertificateStoreQuery&& query_method) {
+        localQueryMethod_ = std::move(query_method);
+    }
+
 private:
     // prevent copy
     SecureDht(const SecureDht&) = delete;
@@ -131,6 +143,10 @@ private:
     std::shared_ptr<crypto::PrivateKey> key_ {};
     std::shared_ptr<crypto::Certificate> certificate_ {};
 
+    // method to query the local certificate store
+    CertificateStoreQuery localQueryMethod_ {};
+
+    // our certificate cache
     std::map<InfoHash, std::shared_ptr<crypto::Certificate>> nodesCertificates_ {};
 
     std::uniform_int_distribution<Value::Id> rand_id {};
diff --git a/src/securedht.cpp b/src/securedht.cpp
index f74fec7d..09fd4cf1 100644
--- a/src/securedht.cpp
+++ b/src/securedht.cpp
@@ -144,15 +144,16 @@ SecureDht::registerCertificate(const InfoHash& node, const Blob& data)
     InfoHash h = crt->getPublicKey().getId();
     if (node == h) {
         DHT_DEBUG("Registering public key for %s", h.toString().c_str());
-        nodesCertificates_[h] = crt;
+        auto it = nodesCertificates_.find(h);
+        if (it == nodesCertificates_.end())
+            std::tie(it, std::ignore) = nodesCertificates_.emplace(h, std::move(crt));
+        else
+            it->second = std::move(crt);
+        return it->second;
     } else {
         DHT_DEBUG("Certificate %s for node %s does not match node id !", h.toString().c_str(), node.toString().c_str());
         return nullptr;
     }
-    auto it = nodesCertificates_.find(h);
-    if (it == nodesCertificates_.end())
-        return nullptr;
-    return it->second;
 }
 
 void
@@ -172,6 +173,17 @@ SecureDht::findCertificate(const InfoHash& node, std::function<void(const std::s
             cb(b);
         return;
     }
+    if (localQueryMethod_) {
+        auto res = localQueryMethod_(node);
+        if (not res.empty()) {
+            DHT_DEBUG("Registering public key from local store for %s", node.toString().c_str());
+            nodesCertificates_.emplace(node, res.front());
+            if (cb)
+                cb(res.front());
+            return;
+        }
+    }
+
     auto found = std::make_shared<bool>(false);
     Dht::get(node, [cb,node,found,this](const std::vector<std::shared_ptr<Value>>& vals) {
         if (*found)
-- 
GitLab