diff --git a/include/opendht/crypto.h b/include/opendht/crypto.h index 4c67dbecc986e4e75db40f6bc8a380aa7d70db7c..718a650decff4fab765c7ca8e03818d6663cef2c 100644 --- a/include/opendht/crypto.h +++ b/include/opendht/crypto.h @@ -100,12 +100,12 @@ struct OPENDHT_PUBLIC PublicKey /** * Get public key fingerprint */ - InfoHash getId() const; + const InfoHash& getId() const; /** * Get public key long fingerprint */ - PkId getLongId() const; + const PkId& getLongId() const; bool checkSignature(const uint8_t* data, size_t data_len, const uint8_t* signature, size_t signature_len) const; inline bool checkSignature(const Blob& data, const Blob& signature) const { @@ -138,6 +138,9 @@ struct OPENDHT_PUBLIC PublicKey gnutls_pubkey_t pk {nullptr}; private: + mutable InfoHash cachedId_ {}; + mutable PkId cachedLongId_ {}; + PublicKey(const PublicKey&) = delete; PublicKey& operator=(const PublicKey&) = delete; void encryptBloc(const uint8_t* src, size_t src_size, uint8_t* dst, size_t dst_size) const; @@ -486,9 +489,9 @@ struct OPENDHT_PUBLIC Certificate { PublicKey getPublicKey() const; /** Same as getPublicKey().getId() */ - InfoHash getId() const; + const InfoHash& getId() const; /** Same as getPublicKey().getLongId() */ - PkId getLongId() const; + const PkId& getLongId() const; Blob getSerialNumber() const; diff --git a/src/crypto.cpp b/src/crypto.cpp index 5a1b8ed2b522b02580cc4b9cd7630ccda924b120..9840c04339a923a8d08ecb51b30e1d9883d4f68e 100644 --- a/src/crypto.cpp +++ b/src/crypto.cpp @@ -361,12 +361,12 @@ PrivateKey::getPublicKey() const const std::shared_ptr<PublicKey>& PrivateKey::getSharedPublicKey() const { - if (publicKey_) - return publicKey_; - auto pk = std::make_shared<PublicKey>(); - if (auto err = gnutls_pubkey_import_privkey(pk->pk, key, GNUTLS_KEY_KEY_CERT_SIGN | GNUTLS_KEY_CRL_SIGN, 0)) - throw CryptoException(std::string("Can't retreive public key: ") + gnutls_strerror(err)); - publicKey_ = pk; + if (not publicKey_) { + auto pk = std::make_shared<PublicKey>(); + if (auto err = gnutls_pubkey_import_privkey(pk->pk, key, GNUTLS_KEY_KEY_CERT_SIGN | GNUTLS_KEY_CRL_SIGN, 0)) + throw CryptoException(std::string("Can't retreive public key: ") + gnutls_strerror(err)); + publicKey_ = pk; + } return publicKey_; } @@ -526,36 +526,34 @@ PublicKey::encrypt(const uint8_t* data, size_t data_len) const return ret; } -InfoHash +const InfoHash& PublicKey::getId() const { - if (not pk) - return {}; - InfoHash id; - size_t sz = id.size(); - if (auto err = gnutls_pubkey_get_key_id(pk, 0, id.data(), &sz)) - throw CryptoException(std::string("Can't get public key ID: ") + gnutls_strerror(err)); - if (sz != id.size()) - throw CryptoException("Can't get public key ID: wrong output length."); - return id; + if (pk and not cachedId_) { + InfoHash id; + size_t sz = id.size(); + if (auto err = gnutls_pubkey_get_key_id(pk, 0, id.data(), &sz)) + throw CryptoException(std::string("Can't get public key ID: ") + gnutls_strerror(err)); + if (sz != id.size()) + throw CryptoException("Can't get public key ID: wrong output length."); + cachedId_ = id; + } + return cachedId_; } -PkId +const PkId& PublicKey::getLongId() const { - if (not pk) - return {}; -#if GNUTLS_VERSION_NUMBER < 0x030401 - throw CryptoException("Can't get 256 bits public key ID: GnuTLS 3.4.1 or higher required."); -#else - PkId h; - size_t sz = h.size(); - if (auto err = gnutls_pubkey_get_key_id(pk, GNUTLS_KEYID_USE_SHA256, h.data(), &sz)) - throw CryptoException(std::string("Can't get 256 bits public key ID: ") + gnutls_strerror(err)); - if (sz != h.size()) - throw CryptoException("Can't get 256 bits public key ID: wrong output length."); - return h; -#endif + if (pk and not cachedLongId_) { + PkId h; + size_t sz = h.size(); + if (auto err = gnutls_pubkey_get_key_id(pk, GNUTLS_KEYID_USE_SHA256, h.data(), &sz)) + throw CryptoException(std::string("Can't get 256 bits public key ID: ") + gnutls_strerror(err)); + if (sz != h.size()) + throw CryptoException("Can't get 256 bits public key ID: wrong output length."); + cachedLongId_ = h; + } + return cachedLongId_; } gnutls_digest_algorithm_t @@ -832,42 +830,34 @@ Certificate::getPublicKey() const return pk_ret; } -InfoHash +const InfoHash& Certificate::getId() const { - if (not cert) - return {}; - if (cachedId_) - return cachedId_; - InfoHash id; - size_t sz = id.size(); - if (auto err = gnutls_x509_crt_get_key_id(cert, 0, id.data(), &sz)) - throw CryptoException(std::string("Can't get certificate public key ID: ") + gnutls_strerror(err)); - if (sz != id.size()) - throw CryptoException("Can't get certificate public key ID: wrong output length."); - cachedId_ = id; - return id; -} - -PkId + if (cert and not cachedId_) { + InfoHash id; + size_t sz = id.size(); + if (auto err = gnutls_x509_crt_get_key_id(cert, 0, id.data(), &sz)) + throw CryptoException(std::string("Can't get certificate public key ID: ") + gnutls_strerror(err)); + if (sz != id.size()) + throw CryptoException("Can't get certificate public key ID: wrong output length."); + cachedId_ = id; + } + return cachedId_; +} + +const PkId& Certificate::getLongId() const { - if (not cert) - return {}; - if (cachedLongId_) - return cachedLongId_; -#if GNUTLS_VERSION_NUMBER < 0x030401 - throw CryptoException("Can't get certificate 256 bits public key ID: GnuTLS 3.4.1 or higher required."); -#else - PkId id; - size_t sz = id.size(); - if (auto err = gnutls_x509_crt_get_key_id(cert, GNUTLS_KEYID_USE_SHA256, id.data(), &sz)) - throw CryptoException(std::string("Can't get certificate 256 bits public key ID: ") + gnutls_strerror(err)); - if (sz != id.size()) - throw CryptoException("Can't get certificate 256 bits public key ID: wrong output length."); - cachedLongId_ = id; - return id; -#endif + if (cert and not cachedLongId_) { + PkId id; + size_t sz = id.size(); + if (auto err = gnutls_x509_crt_get_key_id(cert, GNUTLS_KEYID_USE_SHA256, id.data(), &sz)) + throw CryptoException(std::string("Can't get certificate 256 bits public key ID: ") + gnutls_strerror(err)); + if (sz != id.size()) + throw CryptoException("Can't get certificate 256 bits public key ID: wrong output length."); + cachedLongId_ = id; + } + return cachedLongId_; } Blob