From dc9a211fb0568d2ac3fc14f692eed78795be6706 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrien=20B=C3=A9raud?= <adrien.beraud@savoirfairelinux.com>
Date: Mon, 27 Jun 2022 15:27:31 -0400
Subject: [PATCH] crypto: protect Id caching with atomic_bool

---
 include/opendht/crypto.h |  5 +++++
 src/crypto.cpp           | 12 ++++++++----
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/include/opendht/crypto.h b/include/opendht/crypto.h
index 718a650d..64d43092 100644
--- a/include/opendht/crypto.h
+++ b/include/opendht/crypto.h
@@ -32,6 +32,7 @@ extern "C" {
 
 #include <vector>
 #include <memory>
+#include <atomic>
 
 #ifdef _WIN32
 #include <iso646.h>
@@ -140,6 +141,8 @@ struct OPENDHT_PUBLIC PublicKey
 private:
     mutable InfoHash cachedId_ {};
     mutable PkId cachedLongId_ {};
+    mutable std::atomic_bool idCached_ {false};
+    mutable std::atomic_bool longIdCached_ {false};
 
     PublicKey(const PublicKey&) = delete;
     PublicKey& operator=(const PublicKey&) = delete;
@@ -609,6 +612,8 @@ private:
     Certificate& operator=(const Certificate&) = delete;
     mutable InfoHash cachedId_ {};
     mutable PkId cachedLongId_ {};
+    mutable std::atomic_bool idCached_ {false};
+    mutable std::atomic_bool longIdCached_ {false};
 
     struct crlNumberCmp {
         bool operator() (const std::shared_ptr<RevocationList>& lhs, const std::shared_ptr<RevocationList>& rhs) const {
diff --git a/src/crypto.cpp b/src/crypto.cpp
index 9840c043..f1b18c15 100644
--- a/src/crypto.cpp
+++ b/src/crypto.cpp
@@ -529,7 +529,7 @@ PublicKey::encrypt(const uint8_t* data, size_t data_len) const
 const InfoHash&
 PublicKey::getId() const
 {
-    if (pk and not cachedId_) {
+    if (pk and not idCached_.load()) {
         InfoHash id;
         size_t sz = id.size();
         if (auto err = gnutls_pubkey_get_key_id(pk, 0, id.data(), &sz))
@@ -537,6 +537,7 @@ PublicKey::getId() const
         if (sz != id.size())
             throw CryptoException("Can't get public key ID: wrong output length.");
         cachedId_ = id;
+        idCached_.store(true);
     }
     return cachedId_;
 }
@@ -544,7 +545,7 @@ PublicKey::getId() const
 const PkId&
 PublicKey::getLongId() const
 {
-    if (pk and not cachedLongId_) {
+    if (pk and not longIdCached_.load()) {
         PkId h;
         size_t sz = h.size();
         if (auto err = gnutls_pubkey_get_key_id(pk, GNUTLS_KEYID_USE_SHA256, h.data(), &sz))
@@ -552,6 +553,7 @@ PublicKey::getLongId() const
         if (sz != h.size())
             throw CryptoException("Can't get 256 bits public key ID: wrong output length.");
         cachedLongId_ = h;
+        longIdCached_.store(true);
     }
     return cachedLongId_;
 }
@@ -833,7 +835,7 @@ Certificate::getPublicKey() const
 const InfoHash&
 Certificate::getId() const
 {
-    if (cert and not cachedId_) {
+    if (cert and not idCached_.load()) {
         InfoHash id;
         size_t sz = id.size();
         if (auto err = gnutls_x509_crt_get_key_id(cert, 0, id.data(), &sz))
@@ -841,6 +843,7 @@ Certificate::getId() const
         if (sz != id.size())
             throw CryptoException("Can't get certificate public key ID: wrong output length.");
         cachedId_ = id;
+        idCached_.store(true);
     }
     return cachedId_;
 }
@@ -848,7 +851,7 @@ Certificate::getId() const
 const PkId&
 Certificate::getLongId() const
 {
-    if (cert and not cachedLongId_) {
+    if (cert and not longIdCached_.load()) {
         PkId id;
         size_t sz = id.size();
         if (auto err = gnutls_x509_crt_get_key_id(cert, GNUTLS_KEYID_USE_SHA256, id.data(), &sz))
@@ -856,6 +859,7 @@ Certificate::getLongId() const
         if (sz != id.size())
             throw CryptoException("Can't get certificate 256 bits public key ID: wrong output length.");
         cachedLongId_ = id;
+        longIdCached_.store(true);
     }
     return cachedLongId_;
 }
-- 
GitLab