diff --git a/src/crypto.cpp b/src/crypto.cpp index d601247860d9a71eb3848180af3d058072d55063..6546ed36fc55816ccf8d6d85d79748040fec22ed 100644 --- a/src/crypto.cpp +++ b/src/crypto.cpp @@ -52,6 +52,9 @@ static std::uniform_int_distribution<uint8_t> rand_byte; #define GNUTLS_PKCS_PBES2_AES_256 GNUTLS_PKCS_USE_PBES2_AES_256 #endif +#define DHT_AES_LEGACY_ENCRYPT 1 +#define DHT_AES_LEGACY_DECRYPT 1 + namespace dht { namespace crypto { @@ -105,7 +108,9 @@ Blob aesEncrypt(const uint8_t* data, size_t data_length, const Blob& key) struct gcm_aes_ctx aes; gcm_aes_set_key(&aes, key.size(), key.data()); gcm_aes_set_iv(&aes, GCM_IV_SIZE, ret.data()); +#if DHT_AES_LEGACY_ENCRYPT gcm_aes_update(&aes, data_length, data); +#endif gcm_aes_encrypt(&aes, data_length, ret.data() + GCM_IV_SIZE, data); gcm_aes_digest(&aes, GCM_DIGEST_SIZE, ret.data() + GCM_IV_SIZE + data_length); @@ -137,21 +142,26 @@ Blob aesDecrypt(const Blob& data, const Blob& key) size_t data_sz = data.size() - GCM_IV_SIZE - GCM_DIGEST_SIZE; Blob ret(data_sz); - //gcm_aes_update(&aes, data_sz, data.data() + GCM_IV_SIZE); gcm_aes_decrypt(&aes, data_sz, ret.data(), data.data() + GCM_IV_SIZE); - //gcm_aes_digest(aes, GCM_DIGEST_SIZE, digest.data()); - - // TODO compute the proper digest directly from the decryption pass - Blob ret_tmp(data_sz); - struct gcm_aes_ctx aes_d; - gcm_aes_set_key(&aes_d, key.size(), key.data()); - gcm_aes_set_iv(&aes_d, GCM_IV_SIZE, data.data()); - gcm_aes_update(&aes_d, ret.size(), ret.data()); - gcm_aes_encrypt(&aes_d, ret.size(), ret_tmp.data(), ret.data()); - gcm_aes_digest(&aes_d, GCM_DIGEST_SIZE, digest.data()); - - if (not std::equal(digest.begin(), digest.end(), data.end() - GCM_DIGEST_SIZE)) + gcm_aes_digest(&aes, GCM_DIGEST_SIZE, digest.data()); + + if (not std::equal(digest.begin(), digest.end(), data.end() - GCM_DIGEST_SIZE)) { +#if DHT_AES_LEGACY_DECRYPT + //gcm_aes_decrypt(&aes, data_sz, ret.data(), data.data() + GCM_IV_SIZE); + Blob ret_tmp(data_sz); + struct gcm_aes_ctx aes_d; + gcm_aes_set_key(&aes_d, key.size(), key.data()); + gcm_aes_set_iv(&aes_d, GCM_IV_SIZE, data.data()); + gcm_aes_update(&aes_d, ret.size(), ret.data()); + gcm_aes_encrypt(&aes_d, ret.size(), ret_tmp.data(), ret.data()); + gcm_aes_digest(&aes_d, GCM_DIGEST_SIZE, digest.data()); + + if (not std::equal(digest.begin(), digest.end(), data.end() - GCM_DIGEST_SIZE)) + throw DecryptError("Can't decrypt data"); +#else throw DecryptError("Can't decrypt data"); +#endif + } return ret; } diff --git a/tests/cryptotester.cpp b/tests/cryptotester.cpp index a69ad26406a27198afbfd1f8687884e6938ed2fe..b556823cdb7e085065135dc960a41164c4251309 100644 --- a/tests/cryptotester.cpp +++ b/tests/cryptotester.cpp @@ -34,16 +34,28 @@ CryptoTester::testSignatureEncryption() { auto key = dht::crypto::PrivateKey::generate(); auto public_key = key.getPublicKey(); - std::vector<uint8_t> data {5, 10}; - std::vector<uint8_t> signature = key.sign(data); + std::vector<uint8_t> data1 {5, 10}; + std::vector<uint8_t> data2(64 * 1024, 10); + + std::vector<uint8_t> signature1 = key.sign(data1); + std::vector<uint8_t> signature2 = key.sign(data2); // check signature - CPPUNIT_ASSERT(public_key.checkSignature(data, signature)); + CPPUNIT_ASSERT(public_key.checkSignature(data1, signature1)); + CPPUNIT_ASSERT(public_key.checkSignature(data2, signature2)); // encrypt data - std::vector<uint8_t> encrypted = public_key.encrypt(data); - std::vector<uint8_t> decrypted = key.decrypt(encrypted); - CPPUNIT_ASSERT(data == decrypted); + { + std::vector<uint8_t> encrypted = public_key.encrypt(data1); + std::vector<uint8_t> decrypted = key.decrypt(encrypted); + CPPUNIT_ASSERT(data1 == decrypted); + } + + { + std::vector<uint8_t> encrypted = public_key.encrypt(data2); + std::vector<uint8_t> decrypted = key.decrypt(encrypted); + CPPUNIT_ASSERT(data2 == decrypted); + } } void