Skip to content
Snippets Groups Projects
Commit dc9a211f authored by Adrien Béraud's avatar Adrien Béraud
Browse files

crypto: protect Id caching with atomic_bool

parent d49b394e
Branches
Tags
No related merge requests found
...@@ -32,6 +32,7 @@ extern "C" { ...@@ -32,6 +32,7 @@ extern "C" {
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <atomic>
#ifdef _WIN32 #ifdef _WIN32
#include <iso646.h> #include <iso646.h>
...@@ -140,6 +141,8 @@ struct OPENDHT_PUBLIC PublicKey ...@@ -140,6 +141,8 @@ struct OPENDHT_PUBLIC PublicKey
private: private:
mutable InfoHash cachedId_ {}; mutable InfoHash cachedId_ {};
mutable PkId cachedLongId_ {}; mutable PkId cachedLongId_ {};
mutable std::atomic_bool idCached_ {false};
mutable std::atomic_bool longIdCached_ {false};
PublicKey(const PublicKey&) = delete; PublicKey(const PublicKey&) = delete;
PublicKey& operator=(const PublicKey&) = delete; PublicKey& operator=(const PublicKey&) = delete;
...@@ -609,6 +612,8 @@ private: ...@@ -609,6 +612,8 @@ private:
Certificate& operator=(const Certificate&) = delete; Certificate& operator=(const Certificate&) = delete;
mutable InfoHash cachedId_ {}; mutable InfoHash cachedId_ {};
mutable PkId cachedLongId_ {}; mutable PkId cachedLongId_ {};
mutable std::atomic_bool idCached_ {false};
mutable std::atomic_bool longIdCached_ {false};
struct crlNumberCmp { struct crlNumberCmp {
bool operator() (const std::shared_ptr<RevocationList>& lhs, const std::shared_ptr<RevocationList>& rhs) const { bool operator() (const std::shared_ptr<RevocationList>& lhs, const std::shared_ptr<RevocationList>& rhs) const {
......
...@@ -529,7 +529,7 @@ PublicKey::encrypt(const uint8_t* data, size_t data_len) const ...@@ -529,7 +529,7 @@ PublicKey::encrypt(const uint8_t* data, size_t data_len) const
const InfoHash& const InfoHash&
PublicKey::getId() const PublicKey::getId() const
{ {
if (pk and not cachedId_) { if (pk and not idCached_.load()) {
InfoHash id; InfoHash id;
size_t sz = id.size(); size_t sz = id.size();
if (auto err = gnutls_pubkey_get_key_id(pk, 0, id.data(), &sz)) if (auto err = gnutls_pubkey_get_key_id(pk, 0, id.data(), &sz))
...@@ -537,6 +537,7 @@ PublicKey::getId() const ...@@ -537,6 +537,7 @@ PublicKey::getId() const
if (sz != id.size()) if (sz != id.size())
throw CryptoException("Can't get public key ID: wrong output length."); throw CryptoException("Can't get public key ID: wrong output length.");
cachedId_ = id; cachedId_ = id;
idCached_.store(true);
} }
return cachedId_; return cachedId_;
} }
...@@ -544,7 +545,7 @@ PublicKey::getId() const ...@@ -544,7 +545,7 @@ PublicKey::getId() const
const PkId& const PkId&
PublicKey::getLongId() const PublicKey::getLongId() const
{ {
if (pk and not cachedLongId_) { if (pk and not longIdCached_.load()) {
PkId h; PkId h;
size_t sz = h.size(); size_t sz = h.size();
if (auto err = gnutls_pubkey_get_key_id(pk, GNUTLS_KEYID_USE_SHA256, h.data(), &sz)) if (auto err = gnutls_pubkey_get_key_id(pk, GNUTLS_KEYID_USE_SHA256, h.data(), &sz))
...@@ -552,6 +553,7 @@ PublicKey::getLongId() const ...@@ -552,6 +553,7 @@ PublicKey::getLongId() const
if (sz != h.size()) if (sz != h.size())
throw CryptoException("Can't get 256 bits public key ID: wrong output length."); throw CryptoException("Can't get 256 bits public key ID: wrong output length.");
cachedLongId_ = h; cachedLongId_ = h;
longIdCached_.store(true);
} }
return cachedLongId_; return cachedLongId_;
} }
...@@ -833,7 +835,7 @@ Certificate::getPublicKey() const ...@@ -833,7 +835,7 @@ Certificate::getPublicKey() const
const InfoHash& const InfoHash&
Certificate::getId() const Certificate::getId() const
{ {
if (cert and not cachedId_) { if (cert and not idCached_.load()) {
InfoHash id; InfoHash id;
size_t sz = id.size(); size_t sz = id.size();
if (auto err = gnutls_x509_crt_get_key_id(cert, 0, id.data(), &sz)) if (auto err = gnutls_x509_crt_get_key_id(cert, 0, id.data(), &sz))
...@@ -841,6 +843,7 @@ Certificate::getId() const ...@@ -841,6 +843,7 @@ Certificate::getId() const
if (sz != id.size()) if (sz != id.size())
throw CryptoException("Can't get certificate public key ID: wrong output length."); throw CryptoException("Can't get certificate public key ID: wrong output length.");
cachedId_ = id; cachedId_ = id;
idCached_.store(true);
} }
return cachedId_; return cachedId_;
} }
...@@ -848,7 +851,7 @@ Certificate::getId() const ...@@ -848,7 +851,7 @@ Certificate::getId() const
const PkId& const PkId&
Certificate::getLongId() const Certificate::getLongId() const
{ {
if (cert and not cachedLongId_) { if (cert and not longIdCached_.load()) {
PkId id; PkId id;
size_t sz = id.size(); size_t sz = id.size();
if (auto err = gnutls_x509_crt_get_key_id(cert, GNUTLS_KEYID_USE_SHA256, id.data(), &sz)) if (auto err = gnutls_x509_crt_get_key_id(cert, GNUTLS_KEYID_USE_SHA256, id.data(), &sz))
...@@ -856,6 +859,7 @@ Certificate::getLongId() const ...@@ -856,6 +859,7 @@ Certificate::getLongId() const
if (sz != id.size()) if (sz != id.size())
throw CryptoException("Can't get certificate 256 bits public key ID: wrong output length."); throw CryptoException("Can't get certificate 256 bits public key ID: wrong output length.");
cachedLongId_ = id; cachedLongId_ = id;
longIdCached_.store(true);
} }
return cachedLongId_; return cachedLongId_;
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment