diff --git a/include/opendht/securedht.h b/include/opendht/securedht.h index d997495cc2e2ad6ab06e4390a2818a6fc1a5d87c..b61b3238932ec8ec870adecbdd8dea13ad1a1f82 100644 --- a/include/opendht/securedht.h +++ b/include/opendht/securedht.h @@ -126,11 +126,13 @@ public: Value decrypt(const Value& v); void findCertificate(const InfoHash& node, std::function<void(const std::shared_ptr<crypto::Certificate>)> cb); + void findPublicKey(const InfoHash& node, std::function<void(const std::shared_ptr<crypto::PublicKey>)> cb); const std::shared_ptr<crypto::Certificate> registerCertificate(const InfoHash& node, const Blob& cert); void registerCertificate(std::shared_ptr<crypto::Certificate>& cert); const std::shared_ptr<crypto::Certificate> getCertificate(const InfoHash& node) const; + const std::shared_ptr<crypto::PublicKey> getPublicKey(const InfoHash& node) const; @@ -159,6 +161,7 @@ private: // our certificate cache std::map<InfoHash, std::shared_ptr<crypto::Certificate>> nodesCertificates_ {}; + std::map<InfoHash, std::shared_ptr<crypto::PublicKey>> nodesPubKeys_ {}; std::uniform_int_distribution<Value::Id> rand_id {}; }; diff --git a/src/securedht.cpp b/src/securedht.cpp index 9ace496b6c2199f19b46cf686c10a3d67b088546..8ca9ac9663fd8548aea96e80a65a81819a2400d1 100644 --- a/src/securedht.cpp +++ b/src/securedht.cpp @@ -141,6 +141,18 @@ SecureDht::getCertificate(const InfoHash& node) const return it->second; } +const std::shared_ptr<crypto::PublicKey> +SecureDht::getPublicKey(const InfoHash& node) const +{ + if (node == getId()) + return std::make_shared<crypto::PublicKey>(certificate_->getPublicKey()); + auto it = nodesPubKeys_.find(node); + if (it == nodesPubKeys_.end()) + return nullptr; + else + return it->second; +} + const std::shared_ptr<crypto::Certificate> SecureDht::registerCertificate(const InfoHash& node, const Blob& data) { @@ -213,6 +225,26 @@ SecureDht::findCertificate(const InfoHash& node, std::function<void(const std::s }, Value::TypeFilter(CERTIFICATE_TYPE)); } +void +SecureDht::findPublicKey(const InfoHash& node, std::function<void(const std::shared_ptr<crypto::PublicKey>)> cb) +{ + auto pk = getPublicKey(node); + if (pk && *pk) { + DHT_LOG.DEBUG("Found public key from cache for %s", node.toString().c_str()); + if (cb) + cb(pk); + return; + } + findCertificate(node, [=](const std::shared_ptr<crypto::Certificate> crt) { + if (crt && *crt) { + auto pk = std::make_shared<crypto::PublicKey>(crt->getPublicKey()); + nodesPubKeys_[pk->getId()] = pk; + if (cb) cb(pk); + } else { + if (cb) cb(nullptr); + } + }); +} GetCallback SecureDht::getCallbackFilter(GetCallback cb, Value::Filter&& filter) @@ -227,6 +259,7 @@ SecureDht::getCallbackFilter(GetCallback cb, Value::Filter&& filter) try { Value decrypted_val (decrypt(*v)); if (decrypted_val.recipient == getId()) { + nodesPubKeys_[decrypted_val.owner->getId()] = decrypted_val.owner; if (not filter or filter(decrypted_val)) tmpvals.push_back(std::make_shared<Value>(std::move(decrypted_val))); } @@ -238,6 +271,7 @@ SecureDht::getCallbackFilter(GetCallback cb, Value::Filter&& filter) // Check signed values else if (v->isSigned()) { if (v->owner and v->owner->checkSignature(v->getToSign(), v->signature)) { + nodesPubKeys_[v->owner->getId()] = v->owner; if (not filter or filter(*v)) tmpvals.push_back(v); } @@ -308,15 +342,15 @@ SecureDht::putSigned(const InfoHash& hash, std::shared_ptr<Value> val, DoneCallb void SecureDht::putEncrypted(const InfoHash& hash, const InfoHash& to, std::shared_ptr<Value> val, DoneCallback callback) { - findCertificate(to, [=](const std::shared_ptr<crypto::Certificate> crt) { - if(!crt || !*crt) { + findPublicKey(to, [=](const std::shared_ptr<crypto::PublicKey> pk) { + if(!pk || !*pk) { if (callback) callback(false, {}); return; } - DHT_LOG.WARN("Encrypting data for PK: %s", crt->getPublicKey().getId().toString().c_str()); + DHT_LOG.WARN("Encrypting data for PK: %s", pk->getId().toString().c_str()); try { - put(hash, encrypt(*val, crt->getPublicKey()), callback); + put(hash, encrypt(*val, *pk), callback); } catch (const std::exception& e) { DHT_LOG.ERROR("Error putting encrypted data: %s", e.what()); if (callback)