From dc9a211fb0568d2ac3fc14f692eed78795be6706 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20B=C3=A9raud?= <adrien.beraud@savoirfairelinux.com> Date: Mon, 27 Jun 2022 15:27:31 -0400 Subject: [PATCH] crypto: protect Id caching with atomic_bool --- include/opendht/crypto.h | 5 +++++ src/crypto.cpp | 12 ++++++++---- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/include/opendht/crypto.h b/include/opendht/crypto.h index 718a650d..64d43092 100644 --- a/include/opendht/crypto.h +++ b/include/opendht/crypto.h @@ -32,6 +32,7 @@ extern "C" { #include <vector> #include <memory> +#include <atomic> #ifdef _WIN32 #include <iso646.h> @@ -140,6 +141,8 @@ struct OPENDHT_PUBLIC PublicKey private: mutable InfoHash cachedId_ {}; mutable PkId cachedLongId_ {}; + mutable std::atomic_bool idCached_ {false}; + mutable std::atomic_bool longIdCached_ {false}; PublicKey(const PublicKey&) = delete; PublicKey& operator=(const PublicKey&) = delete; @@ -609,6 +612,8 @@ private: Certificate& operator=(const Certificate&) = delete; mutable InfoHash cachedId_ {}; mutable PkId cachedLongId_ {}; + mutable std::atomic_bool idCached_ {false}; + mutable std::atomic_bool longIdCached_ {false}; struct crlNumberCmp { bool operator() (const std::shared_ptr<RevocationList>& lhs, const std::shared_ptr<RevocationList>& rhs) const { diff --git a/src/crypto.cpp b/src/crypto.cpp index 9840c043..f1b18c15 100644 --- a/src/crypto.cpp +++ b/src/crypto.cpp @@ -529,7 +529,7 @@ PublicKey::encrypt(const uint8_t* data, size_t data_len) const const InfoHash& PublicKey::getId() const { - if (pk and not cachedId_) { + if (pk and not idCached_.load()) { InfoHash id; size_t sz = id.size(); if (auto err = gnutls_pubkey_get_key_id(pk, 0, id.data(), &sz)) @@ -537,6 +537,7 @@ PublicKey::getId() const if (sz != id.size()) throw CryptoException("Can't get public key ID: wrong output length."); cachedId_ = id; + idCached_.store(true); } return cachedId_; } @@ -544,7 +545,7 @@ PublicKey::getId() const const PkId& PublicKey::getLongId() const { - if (pk and not cachedLongId_) { + if (pk and not longIdCached_.load()) { PkId h; size_t sz = h.size(); if (auto err = gnutls_pubkey_get_key_id(pk, GNUTLS_KEYID_USE_SHA256, h.data(), &sz)) @@ -552,6 +553,7 @@ PublicKey::getLongId() const if (sz != h.size()) throw CryptoException("Can't get 256 bits public key ID: wrong output length."); cachedLongId_ = h; + longIdCached_.store(true); } return cachedLongId_; } @@ -833,7 +835,7 @@ Certificate::getPublicKey() const const InfoHash& Certificate::getId() const { - if (cert and not cachedId_) { + if (cert and not idCached_.load()) { InfoHash id; size_t sz = id.size(); if (auto err = gnutls_x509_crt_get_key_id(cert, 0, id.data(), &sz)) @@ -841,6 +843,7 @@ Certificate::getId() const if (sz != id.size()) throw CryptoException("Can't get certificate public key ID: wrong output length."); cachedId_ = id; + idCached_.store(true); } return cachedId_; } @@ -848,7 +851,7 @@ Certificate::getId() const const PkId& Certificate::getLongId() const { - if (cert and not cachedLongId_) { + if (cert and not longIdCached_.load()) { PkId id; size_t sz = id.size(); if (auto err = gnutls_x509_crt_get_key_id(cert, GNUTLS_KEYID_USE_SHA256, id.data(), &sz)) @@ -856,6 +859,7 @@ Certificate::getLongId() const if (sz != id.size()) throw CryptoException("Can't get certificate 256 bits public key ID: wrong output length."); cachedLongId_ = id; + longIdCached_.store(true); } return cachedLongId_; } -- GitLab