From e68a2fd9672f06fb781f16f68c307c82d2d3d6d9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrien=20B=C3=A9raud?= <adrien.beraud@savoirfairelinux.com>
Date: Tue, 17 May 2016 19:31:49 -0400
Subject: [PATCH] securedht: add public key cache

---
 include/opendht/securedht.h |  3 +++
 src/securedht.cpp           | 42 +++++++++++++++++++++++++++++++++----
 2 files changed, 41 insertions(+), 4 deletions(-)

diff --git a/include/opendht/securedht.h b/include/opendht/securedht.h
index d997495c..b61b3238 100644
--- a/include/opendht/securedht.h
+++ b/include/opendht/securedht.h
@@ -126,11 +126,13 @@ public:
     Value decrypt(const Value& v);
 
     void findCertificate(const InfoHash& node, std::function<void(const std::shared_ptr<crypto::Certificate>)> cb);
+    void findPublicKey(const InfoHash& node, std::function<void(const std::shared_ptr<crypto::PublicKey>)> cb);
 
     const std::shared_ptr<crypto::Certificate> registerCertificate(const InfoHash& node, const Blob& cert);
     void registerCertificate(std::shared_ptr<crypto::Certificate>& cert);
 
     const std::shared_ptr<crypto::Certificate> getCertificate(const InfoHash& node) const;
+    const std::shared_ptr<crypto::PublicKey> getPublicKey(const InfoHash& node) const;
 
 
     
@@ -159,6 +161,7 @@ private:
 
     // our certificate cache
     std::map<InfoHash, std::shared_ptr<crypto::Certificate>> nodesCertificates_ {};
+    std::map<InfoHash, std::shared_ptr<crypto::PublicKey>> nodesPubKeys_ {};
 
     std::uniform_int_distribution<Value::Id> rand_id {};
 };
diff --git a/src/securedht.cpp b/src/securedht.cpp
index 9ace496b..8ca9ac96 100644
--- a/src/securedht.cpp
+++ b/src/securedht.cpp
@@ -141,6 +141,18 @@ SecureDht::getCertificate(const InfoHash& node) const
         return it->second;
 }
 
+const std::shared_ptr<crypto::PublicKey>
+SecureDht::getPublicKey(const InfoHash& node) const
+{
+    if (node == getId())
+        return std::make_shared<crypto::PublicKey>(certificate_->getPublicKey());
+    auto it = nodesPubKeys_.find(node);
+    if (it == nodesPubKeys_.end())
+        return nullptr;
+    else
+        return it->second;
+}
+
 const std::shared_ptr<crypto::Certificate>
 SecureDht::registerCertificate(const InfoHash& node, const Blob& data)
 {
@@ -213,6 +225,26 @@ SecureDht::findCertificate(const InfoHash& node, std::function<void(const std::s
     }, Value::TypeFilter(CERTIFICATE_TYPE));
 }
 
+void
+SecureDht::findPublicKey(const InfoHash& node, std::function<void(const std::shared_ptr<crypto::PublicKey>)> cb)
+{
+    auto pk = getPublicKey(node);
+    if (pk && *pk) {
+        DHT_LOG.DEBUG("Found public key from cache for %s", node.toString().c_str());
+        if (cb)
+            cb(pk);
+        return;
+    }
+    findCertificate(node, [=](const std::shared_ptr<crypto::Certificate> crt) {
+        if (crt && *crt) {
+            auto pk = std::make_shared<crypto::PublicKey>(crt->getPublicKey());
+            nodesPubKeys_[pk->getId()] = pk;
+            if (cb) cb(pk);
+        } else {
+            if (cb) cb(nullptr);
+        }
+    });
+}
 
 GetCallback
 SecureDht::getCallbackFilter(GetCallback cb, Value::Filter&& filter)
@@ -227,6 +259,7 @@ SecureDht::getCallbackFilter(GetCallback cb, Value::Filter&& filter)
                 try {
                     Value decrypted_val (decrypt(*v));
                     if (decrypted_val.recipient == getId()) {
+                        nodesPubKeys_[decrypted_val.owner->getId()] = decrypted_val.owner;
                         if (not filter or filter(decrypted_val))
                             tmpvals.push_back(std::make_shared<Value>(std::move(decrypted_val)));
                     }
@@ -238,6 +271,7 @@ SecureDht::getCallbackFilter(GetCallback cb, Value::Filter&& filter)
             // Check signed values
             else if (v->isSigned()) {
                 if (v->owner and v->owner->checkSignature(v->getToSign(), v->signature)) {
+                    nodesPubKeys_[v->owner->getId()] = v->owner;
                     if (not filter  or filter(*v))
                         tmpvals.push_back(v);
                 }
@@ -308,15 +342,15 @@ SecureDht::putSigned(const InfoHash& hash, std::shared_ptr<Value> val, DoneCallb
 void
 SecureDht::putEncrypted(const InfoHash& hash, const InfoHash& to, std::shared_ptr<Value> val, DoneCallback callback)
 {
-    findCertificate(to, [=](const std::shared_ptr<crypto::Certificate> crt) {
-        if(!crt || !*crt) {
+    findPublicKey(to, [=](const std::shared_ptr<crypto::PublicKey> pk) {
+        if(!pk || !*pk) {
             if (callback)
                 callback(false, {});
             return;
         }
-        DHT_LOG.WARN("Encrypting data for PK: %s", crt->getPublicKey().getId().toString().c_str());
+        DHT_LOG.WARN("Encrypting data for PK: %s", pk->getId().toString().c_str());
         try {
-            put(hash, encrypt(*val, crt->getPublicKey()), callback);
+            put(hash, encrypt(*val, *pk), callback);
         } catch (const std::exception& e) {
             DHT_LOG.ERROR("Error putting encrypted data: %s", e.what());
             if (callback)
-- 
GitLab