From e68a2fd9672f06fb781f16f68c307c82d2d3d6d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20B=C3=A9raud?= <adrien.beraud@savoirfairelinux.com> Date: Tue, 17 May 2016 19:31:49 -0400 Subject: [PATCH] securedht: add public key cache --- include/opendht/securedht.h | 3 +++ src/securedht.cpp | 42 +++++++++++++++++++++++++++++++++---- 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/include/opendht/securedht.h b/include/opendht/securedht.h index d997495c..b61b3238 100644 --- a/include/opendht/securedht.h +++ b/include/opendht/securedht.h @@ -126,11 +126,13 @@ public: Value decrypt(const Value& v); void findCertificate(const InfoHash& node, std::function<void(const std::shared_ptr<crypto::Certificate>)> cb); + void findPublicKey(const InfoHash& node, std::function<void(const std::shared_ptr<crypto::PublicKey>)> cb); const std::shared_ptr<crypto::Certificate> registerCertificate(const InfoHash& node, const Blob& cert); void registerCertificate(std::shared_ptr<crypto::Certificate>& cert); const std::shared_ptr<crypto::Certificate> getCertificate(const InfoHash& node) const; + const std::shared_ptr<crypto::PublicKey> getPublicKey(const InfoHash& node) const; @@ -159,6 +161,7 @@ private: // our certificate cache std::map<InfoHash, std::shared_ptr<crypto::Certificate>> nodesCertificates_ {}; + std::map<InfoHash, std::shared_ptr<crypto::PublicKey>> nodesPubKeys_ {}; std::uniform_int_distribution<Value::Id> rand_id {}; }; diff --git a/src/securedht.cpp b/src/securedht.cpp index 9ace496b..8ca9ac96 100644 --- a/src/securedht.cpp +++ b/src/securedht.cpp @@ -141,6 +141,18 @@ SecureDht::getCertificate(const InfoHash& node) const return it->second; } +const std::shared_ptr<crypto::PublicKey> +SecureDht::getPublicKey(const InfoHash& node) const +{ + if (node == getId()) + return std::make_shared<crypto::PublicKey>(certificate_->getPublicKey()); + auto it = nodesPubKeys_.find(node); + if (it == nodesPubKeys_.end()) + return nullptr; + else + return it->second; +} + const std::shared_ptr<crypto::Certificate> SecureDht::registerCertificate(const InfoHash& node, const Blob& data) { @@ -213,6 +225,26 @@ SecureDht::findCertificate(const InfoHash& node, std::function<void(const std::s }, Value::TypeFilter(CERTIFICATE_TYPE)); } +void +SecureDht::findPublicKey(const InfoHash& node, std::function<void(const std::shared_ptr<crypto::PublicKey>)> cb) +{ + auto pk = getPublicKey(node); + if (pk && *pk) { + DHT_LOG.DEBUG("Found public key from cache for %s", node.toString().c_str()); + if (cb) + cb(pk); + return; + } + findCertificate(node, [=](const std::shared_ptr<crypto::Certificate> crt) { + if (crt && *crt) { + auto pk = std::make_shared<crypto::PublicKey>(crt->getPublicKey()); + nodesPubKeys_[pk->getId()] = pk; + if (cb) cb(pk); + } else { + if (cb) cb(nullptr); + } + }); +} GetCallback SecureDht::getCallbackFilter(GetCallback cb, Value::Filter&& filter) @@ -227,6 +259,7 @@ SecureDht::getCallbackFilter(GetCallback cb, Value::Filter&& filter) try { Value decrypted_val (decrypt(*v)); if (decrypted_val.recipient == getId()) { + nodesPubKeys_[decrypted_val.owner->getId()] = decrypted_val.owner; if (not filter or filter(decrypted_val)) tmpvals.push_back(std::make_shared<Value>(std::move(decrypted_val))); } @@ -238,6 +271,7 @@ SecureDht::getCallbackFilter(GetCallback cb, Value::Filter&& filter) // Check signed values else if (v->isSigned()) { if (v->owner and v->owner->checkSignature(v->getToSign(), v->signature)) { + nodesPubKeys_[v->owner->getId()] = v->owner; if (not filter or filter(*v)) tmpvals.push_back(v); } @@ -308,15 +342,15 @@ SecureDht::putSigned(const InfoHash& hash, std::shared_ptr<Value> val, DoneCallb void SecureDht::putEncrypted(const InfoHash& hash, const InfoHash& to, std::shared_ptr<Value> val, DoneCallback callback) { - findCertificate(to, [=](const std::shared_ptr<crypto::Certificate> crt) { - if(!crt || !*crt) { + findPublicKey(to, [=](const std::shared_ptr<crypto::PublicKey> pk) { + if(!pk || !*pk) { if (callback) callback(false, {}); return; } - DHT_LOG.WARN("Encrypting data for PK: %s", crt->getPublicKey().getId().toString().c_str()); + DHT_LOG.WARN("Encrypting data for PK: %s", pk->getId().toString().c_str()); try { - put(hash, encrypt(*val, crt->getPublicKey()), callback); + put(hash, encrypt(*val, *pk), callback); } catch (const std::exception& e) { DHT_LOG.ERROR("Error putting encrypted data: %s", e.what()); if (callback) -- GitLab