From 76a2bd234a3e45cbaef82a74dadcbc9d02a635c3 Mon Sep 17 00:00:00 2001
From: Matthew Webb <mwebbmwebb@gmail.com>
Date: Wed, 10 Jul 2024 00:17:04 -0700
Subject: [PATCH] Get off of deprecated GCM AES methods
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Multiple compiler warnings note that the gcm_aes_* family of functions are deprecated. They have been replaced with gcm_aes<key_length>_*. This change uses the correct set of functions based on the given key size. Resolves #571.

Sample compiler warning:
/home/noviv/opendht/src/crypto.cpp: In function ‘dht::Blob dht::crypto::aesEncrypt(const uint8_t*, size_t, const dht::Blob&)’: /home/noviv/opendht/src/crypto.cpp:97:20: warning: ‘void nettle_gcm_aes_set_key(gcm_aes_ctx*, size_t, const uint8_t*)’ is deprecated [-Wdeprecated-declarations]
   97 |     gcm_aes_set_key(&aes, key.size(), key.data());
         |                    ^
         In file included from /home/noviv/opendht/src/crypto.cpp:27:
         /usr/include/nettle/gcm.h:276:1: note: declared here
           276 | gcm_aes_set_key(struct gcm_aes_ctx *ctx,
                 | ^~~~~~~~~~~~~~~)

https://github.com/gnutls/nettle/commit/6a19845e6f71791ca98765d490ec08e776494bee marked the functions as deprecated.
---
 src/crypto.cpp         | 52 +++++++++++++++++++++++++++++++++---------
 tests/cryptotester.cpp | 21 +++++++++++++++++
 tests/cryptotester.h   |  2 ++
 3 files changed, 64 insertions(+), 11 deletions(-)

diff --git a/src/crypto.cpp b/src/crypto.cpp
index b578702c..9c5f2c45 100644
--- a/src/crypto.cpp
+++ b/src/crypto.cpp
@@ -93,11 +93,27 @@ Blob aesEncrypt(const uint8_t* data, size_t data_length, const Blob& key)
         std::random_device rdev;
         std::generate_n(ret.begin(), GCM_IV_SIZE, std::bind(rand_byte, std::ref(rdev)));
     }
-    struct gcm_aes_ctx aes;
-    gcm_aes_set_key(&aes, key.size(), key.data());
-    gcm_aes_set_iv(&aes, GCM_IV_SIZE, ret.data());
-    gcm_aes_encrypt(&aes, data_length, ret.data() + GCM_IV_SIZE, data);
-    gcm_aes_digest(&aes, GCM_DIGEST_SIZE, ret.data() + GCM_IV_SIZE + data_length);
+
+    if (key.size() == AES_LENGTHS[0]) {
+        struct gcm_aes128_ctx aes;
+        gcm_aes128_set_key(&aes, key.data());
+        gcm_aes128_set_iv(&aes, GCM_IV_SIZE, ret.data());
+        gcm_aes128_encrypt(&aes, data_length, ret.data() + GCM_IV_SIZE, data);
+        gcm_aes128_digest(&aes, GCM_DIGEST_SIZE, ret.data() + GCM_IV_SIZE + data_length);
+    } else if (key.size() == AES_LENGTHS[1]) {
+        struct gcm_aes192_ctx aes;
+        gcm_aes192_set_key(&aes, key.data());
+        gcm_aes192_set_iv(&aes, GCM_IV_SIZE, ret.data());
+        gcm_aes192_encrypt(&aes, data_length, ret.data() + GCM_IV_SIZE, data);
+        gcm_aes192_digest(&aes, GCM_DIGEST_SIZE, ret.data() + GCM_IV_SIZE + data_length);
+    } else if (key.size() == AES_LENGTHS[2]) {
+        struct gcm_aes256_ctx aes;
+        gcm_aes256_set_key(&aes, key.data());
+        gcm_aes256_set_iv(&aes, GCM_IV_SIZE, ret.data());
+        gcm_aes256_encrypt(&aes, data_length, ret.data() + GCM_IV_SIZE, data);
+        gcm_aes256_digest(&aes, GCM_DIGEST_SIZE, ret.data() + GCM_IV_SIZE + data_length);
+    }
+
     return ret;
 }
 
@@ -118,14 +134,28 @@ Blob aesDecrypt(const uint8_t* data, size_t data_length, const Blob& key)
 
     std::array<uint8_t, GCM_DIGEST_SIZE> digest;
 
-    struct gcm_aes_ctx aes;
-    gcm_aes_set_key(&aes, key.size(), key.data());
-    gcm_aes_set_iv(&aes, GCM_IV_SIZE, data);
-
     size_t data_sz = data_length - GCM_IV_SIZE - GCM_DIGEST_SIZE;
     Blob ret(data_sz);
-    gcm_aes_decrypt(&aes, data_sz, ret.data(), data + GCM_IV_SIZE);
-    gcm_aes_digest(&aes, GCM_DIGEST_SIZE, digest.data());
+
+    if (key.size() == AES_LENGTHS[0]) {
+        struct gcm_aes128_ctx aes;
+        gcm_aes128_set_key(&aes, key.data());
+        gcm_aes128_set_iv(&aes, GCM_IV_SIZE, data);
+        gcm_aes128_decrypt(&aes, data_sz, ret.data(), data + GCM_IV_SIZE);
+        gcm_aes128_digest(&aes, GCM_DIGEST_SIZE, digest.data());
+    } else if (key.size() == AES_LENGTHS[1]) {
+        struct gcm_aes192_ctx aes;
+        gcm_aes192_set_key(&aes, key.data());
+        gcm_aes192_set_iv(&aes, GCM_IV_SIZE, data);
+        gcm_aes192_decrypt(&aes, data_sz, ret.data(), data + GCM_IV_SIZE);
+        gcm_aes192_digest(&aes, GCM_DIGEST_SIZE, digest.data());
+    } else if (key.size() == AES_LENGTHS[2]) {
+        struct gcm_aes256_ctx aes;
+        gcm_aes256_set_key(&aes, key.data());
+        gcm_aes256_set_iv(&aes, GCM_IV_SIZE, data);
+        gcm_aes256_decrypt(&aes, data_sz, ret.data(), data + GCM_IV_SIZE);
+        gcm_aes256_digest(&aes, GCM_DIGEST_SIZE, digest.data());
+    }
 
     if (not std::equal(digest.begin(), digest.end(), data + data_length - GCM_DIGEST_SIZE)) {
         throw DecryptError("Can't decrypt data");
diff --git a/tests/cryptotester.cpp b/tests/cryptotester.cpp
index c33a038f..1673784c 100644
--- a/tests/cryptotester.cpp
+++ b/tests/cryptotester.cpp
@@ -239,6 +239,27 @@ void CryptoTester::testAesEncryption() {
     CPPUNIT_ASSERT(data2 == decrypted2);
 }
 
+void CryptoTester::testAesEncryptionWithMultipleKeySizes() {
+    auto data = std::vector<uint8_t>(rand(), rand());
+
+    // Valid key sizes
+    for (auto key_length : {16, 24, 32}) {
+        auto key = std::vector<uint8_t>(key_length, rand());
+
+        auto encrypted_data = dht::crypto::aesEncrypt(data, key);
+        auto decrypted_data = dht::crypto::aesDecrypt(encrypted_data, key);
+
+        CPPUNIT_ASSERT(data == decrypted_data);
+    }
+
+    // Invalid key sizes
+    for (auto key_length : {12, 28, 36}) {
+        auto key = std::vector<uint8_t>(key_length, rand());
+
+        CPPUNIT_ASSERT_THROW(dht::crypto::aesEncrypt(data, key), dht::crypto::DecryptError);
+    }
+}
+
 void
 CryptoTester::tearDown() {
 
diff --git a/tests/cryptotester.h b/tests/cryptotester.h
index 89019057..6cd552c4 100644
--- a/tests/cryptotester.h
+++ b/tests/cryptotester.h
@@ -34,6 +34,7 @@ class CryptoTester : public CppUnit::TestFixture {
     CPPUNIT_TEST(testCertificateSerialNumber);
     CPPUNIT_TEST(testOcsp);
     CPPUNIT_TEST(testAesEncryption);
+    CPPUNIT_TEST(testAesEncryptionWithMultipleKeySizes);
     CPPUNIT_TEST_SUITE_END();
 
  public:
@@ -69,6 +70,7 @@ class CryptoTester : public CppUnit::TestFixture {
      * Test key streching and aes encryption/decryption
      */
     void testAesEncryption();
+    void testAesEncryptionWithMultipleKeySizes();
 };
 
 }  // namespace test
-- 
GitLab