From cdcaebbc80d01e94443742b8e234b81c580413e7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrien=20B=C3=A9raud?= <adrien.beraud@savoirfairelinux.com>
Date: Mon, 3 Aug 2015 15:32:35 -0400
Subject: [PATCH] crypto: use AES-GCM + RSA for encryption rather than RSA-ECB

---
 include/opendht/crypto.h |  12 +++
 src/crypto.cpp           | 170 +++++++++++++++++++++++++++++++--------
 2 files changed, 150 insertions(+), 32 deletions(-)

diff --git a/include/opendht/crypto.h b/include/opendht/crypto.h
index 59f8fcfd..e1daf256 100644
--- a/include/opendht/crypto.h
+++ b/include/opendht/crypto.h
@@ -110,6 +110,7 @@ struct PublicKey
 private:
     PublicKey(const PublicKey&) = delete;
     PublicKey& operator=(const PublicKey&) = delete;
+    void encryptBloc(const uint8_t* src, size_t src_size, uint8_t* dst, size_t dst_size) const;
 };
 
 /**
@@ -145,6 +146,7 @@ struct PrivateKey
     /**
      * Generate a new RSA key pair
      * @param key_length : size of the modulus in bits
+     *      Minimim value: 2048
      *      Recommended values: 4096, 8192
      */
     static PrivateKey generate(unsigned key_length = 4096);
@@ -154,6 +156,7 @@ struct PrivateKey
 private:
     PrivateKey(const PrivateKey&) = delete;
     PrivateKey& operator=(const PrivateKey&) = delete;
+    Blob decryptBloc(const uint8_t* src, size_t src_size) const;
 
     friend dht::crypto::Identity dht::crypto::generateIdentity(const std::string&, dht::crypto::Identity, unsigned key_length);
 };
@@ -301,6 +304,15 @@ private:
     friend dht::crypto::Identity dht::crypto::generateIdentity(const std::string&, dht::crypto::Identity, unsigned key_length);
 };
 
+/**
+ * AES-GCM encryption. Key must be 128, 192 or 126 bits long (16, 24 or 32 bytes).
+ */
+Blob aesEncrypt(const Blob& data, const Blob& key);
+
+/**
+ * AES-GCM decryption.
+ */
+Blob aesDecrypt(const Blob& data, const Blob& key);
 
 }
 }
diff --git a/src/crypto.cpp b/src/crypto.cpp
index e0125964..c6bf16af 100644
--- a/src/crypto.cpp
+++ b/src/crypto.cpp
@@ -35,14 +35,17 @@ extern "C" {
 #include <gnutls/gnutls.h>
 #include <gnutls/abstract.h>
 #include <gnutls/x509.h>
+#include <nettle/gcm.h>
+#include <nettle/aes.h>
 }
 
 #include <random>
 #include <sstream>
-#include <random>
 #include <stdexcept>
 #include <cassert>
 
+static std::uniform_int_distribution<uint8_t> rand_byte;
+
 static gnutls_digest_algorithm_t get_dig_for_pub(gnutls_pubkey_t pubkey)
 {
     gnutls_digest_algorithm_t dig;
@@ -82,6 +85,87 @@ static gnutls_digest_algorithm_t get_dig(gnutls_x509_crt_t crt)
 namespace dht {
 namespace crypto {
 
+static constexpr std::array<size_t, 3> AES_LENGTHS {128/8, 192/8, 256/8};
+
+size_t aesKeySize(size_t max)
+{
+    unsigned aes_key_len = 0;
+    for (size_t s = 0; s < AES_LENGTHS.size(); s++) {
+        if (AES_LENGTHS[s] <= max)
+            aes_key_len = AES_LENGTHS[s];
+        else break;
+    }
+    return aes_key_len;
+}
+
+bool aesKeySizeGood(size_t key_size)
+{
+    for (auto& i : AES_LENGTHS)
+        if (key_size == i)
+            return true;
+    return false;
+}
+
+#ifndef GCM_DIGEST_SIZE
+#define GCM_DIGEST_SIZE GCM_BLOCK_SIZE
+#endif
+
+Blob
+aesEncrypt(const Blob& data, const Blob& key)
+{
+    std::array<uint8_t, GCM_IV_SIZE> iv;
+    {
+        crypto::random_device rdev;
+        std::generate_n(iv.begin(), iv.size(), std::bind(rand_byte, std::ref(rdev)));
+    }
+    struct gcm_aes_ctx aes;
+    gcm_aes_set_key(&aes, key.size(), key.data());
+    gcm_aes_set_iv(&aes, iv.size(), iv.data());
+    gcm_aes_update(&aes, data.size(), data.data());
+
+    Blob ret(data.size() + GCM_IV_SIZE + GCM_DIGEST_SIZE);
+    std::copy(iv.begin(), iv.end(), ret.begin());
+    gcm_aes_encrypt(&aes, data.size(), ret.data() + GCM_IV_SIZE, data.data());
+    gcm_aes_digest(&aes, GCM_DIGEST_SIZE, ret.data() + GCM_IV_SIZE + data.size());
+    return ret;
+}
+
+Blob
+aesDecrypt(const Blob& data, const Blob& key)
+{
+    if (not aesKeySizeGood(key.size()))
+        throw DecryptError("Wrong key size");
+
+    if (data.size() <= GCM_IV_SIZE + GCM_DIGEST_SIZE)
+        throw DecryptError("Wrong data size");
+
+    std::array<uint8_t, GCM_DIGEST_SIZE> digest;
+
+    struct gcm_aes_ctx aes;
+    gcm_aes_set_key(&aes, key.size(), key.data());
+    gcm_aes_set_iv(&aes, GCM_IV_SIZE, data.data());
+
+    size_t data_sz = data.size() - GCM_IV_SIZE - GCM_DIGEST_SIZE;
+    Blob ret(data_sz);
+    //gcm_aes_update(&aes, data_sz, data.data() + GCM_IV_SIZE);
+    gcm_aes_decrypt(&aes, data_sz, ret.data(), data.data() + GCM_IV_SIZE);
+    //gcm_aes_digest(aes, GCM_DIGEST_SIZE, digest.data());
+
+    // TODO compute the proper digest directly from the decryption pass
+    Blob ret_tmp(data_sz);
+    struct gcm_aes_ctx aes_d;
+    gcm_aes_set_key(&aes_d, key.size(), key.data());
+    gcm_aes_set_iv(&aes_d, GCM_IV_SIZE, data.data());
+    gcm_aes_update(&aes_d, ret.size() , ret.data());
+    gcm_aes_encrypt(&aes_d, ret.size(), ret_tmp.data(), ret.data());
+    gcm_aes_digest(&aes_d, GCM_DIGEST_SIZE, digest.data());
+
+    if (not std::equal(digest.begin(), digest.end(), data.end() - GCM_DIGEST_SIZE))
+        throw DecryptError("Can't decrypt data");
+
+    return ret;
+}
+
 PrivateKey::PrivateKey()
 {
 #if GNUTLS_VERSION_NUMBER < 0x030300
@@ -194,6 +278,19 @@ PrivateKey::sign(const Blob& data) const
     return ret;
 }
 
+Blob
+PrivateKey::decryptBloc(const uint8_t* src, size_t src_size) const
+{
+    const gnutls_datum_t dat {(uint8_t*)src, (unsigned)src_size};
+    gnutls_datum_t out;
+    int err = gnutls_privkey_decrypt_data(key, 0, &dat, &out);
+    if (err != GNUTLS_E_SUCCESS)
+        throw DecryptError(std::string("Can't decrypt data: ") + gnutls_strerror(err));
+    Blob ret {out.data, out.data+out.size};
+    gnutls_free(out.data);
+    return ret;
+}
+
 Blob
 PrivateKey::decrypt(const Blob& cipher) const
 {
@@ -208,20 +305,12 @@ PrivateKey::decrypt(const Blob& cipher) const
         throw CryptoException("Must be an RSA key");
 
     unsigned cypher_block_sz = key_len / 8;
-    if (cipher.size() % cypher_block_sz)
-        throw CryptoException("Unexpected cipher length");
+    if (cipher.size() < cypher_block_sz)
+        throw DecryptError("Unexpected cipher length");
+    else if (cipher.size() == cypher_block_sz)
+        return decryptBloc(cipher.data(), cypher_block_sz);
 
-    Blob ret;
-    for (auto cb = cipher.cbegin(), ce = cipher.cend(); cb < ce; cb += cypher_block_sz) {
-        const gnutls_datum_t dat {(uint8_t*)(&(*cb)), cypher_block_sz};
-        gnutls_datum_t out;
-        int err = gnutls_privkey_decrypt_data(key, 0, &dat, &out);
-        if (err != GNUTLS_E_SUCCESS)
-            throw DecryptError(std::string("Can't decrypt data: ") + gnutls_strerror(err));
-        ret.insert(ret.end(), out.data, out.data+out.size);
-        gnutls_free(out.data);
-    }
-    return ret;
+    return aesDecrypt(Blob {cipher.begin() + cypher_block_sz, cipher.end()}, decryptBloc(cipher.data(), cypher_block_sz));
 }
 
 Blob
@@ -321,6 +410,20 @@ PublicKey::checkSignature(const Blob& data, const Blob& signature) const {
     return rc >= 0;
 }
 
+void
+PublicKey::encryptBloc(const uint8_t* src, size_t src_size, uint8_t* dst, size_t dst_size) const
+{
+    const gnutls_datum_t key_dat {(uint8_t*)src, (unsigned)src_size};
+    gnutls_datum_t encrypted;
+    auto err = gnutls_pubkey_encrypt_data(pk, 0, &key_dat, &encrypted);
+    if (err != GNUTLS_E_SUCCESS)
+        throw CryptoException(std::string("Can't encrypt data: ") + gnutls_strerror(err));
+    if (encrypted.size != dst_size)
+        throw CryptoException("Unexpected cypherblock size");
+    std::copy_n(encrypted.data, encrypted.size, dst);
+    gnutls_free(encrypted.data);
+}
+
 Blob
 PublicKey::encrypt(const Blob& data) const
 {
@@ -334,27 +437,30 @@ PublicKey::encrypt(const Blob& data) const
     if (err != GNUTLS_PK_RSA)
         throw CryptoException("Must be an RSA key");
 
-    unsigned max_block_sz = key_len / 8 - 11;
-    unsigned cypher_block_sz = key_len / 8;
-    unsigned block_num = data.empty() ? 1 : 1 + (data.size() - 1) / max_block_sz;
+    const unsigned max_block_sz = key_len / 8 - 11;
+    const unsigned cypher_block_sz = key_len / 8;
+    if (data.size() <= max_block_sz) {
+        Blob ret(cypher_block_sz);
+        encryptBloc(data.data(), data.size(), ret.data(), cypher_block_sz);
+        return ret;
+    }
 
-    Blob ret;
-    auto eb = data.cbegin();
-    auto ee = data.cend();
-    for (unsigned i=0; i<block_num; i++) {
-        auto blk_sz = std::min<unsigned>(ee - eb, max_block_sz);
-        const gnutls_datum_t dat {(uint8_t*)&(*eb), blk_sz};
-        gnutls_datum_t encrypted;
-        err = gnutls_pubkey_encrypt_data(pk, 0, &dat, &encrypted);
-        if (err != GNUTLS_E_SUCCESS)
-            throw CryptoException(std::string("Can't encrypt data: ") + gnutls_strerror(err));
-        if (encrypted.size != cypher_block_sz)
-            throw CryptoException("Unexpected cypherblock size");
-        ret.insert(ret.end(), encrypted.data, encrypted.data+encrypted.size);
-        eb += blk_sz;
-        gnutls_free(encrypted.data);
+    unsigned aes_key_sz = aesKeySize(max_block_sz);
+    if (aes_key_sz == 0)
+        throw CryptoException("Key is not long enough for AES128");
+    Blob key(aes_key_sz);
+    {
+        crypto::random_device rdev;
+        std::generate_n(key.begin(), key.size(), std::bind(rand_byte, std::ref(rdev)));
     }
+    auto data_encrypted = aesEncrypt(data, key);
+
+    Blob ret;
+    ret.reserve(cypher_block_sz + data_encrypted.size());
 
+    ret.resize(cypher_block_sz);
+    encryptBloc(key.data(), key.size(), ret.data(), cypher_block_sz);
+    ret.insert(ret.end(), data_encrypted.begin(), data_encrypted.end());
     return ret;
 }
 
-- 
GitLab