diff --git a/src/crypto.cpp b/src/crypto.cpp index d601247860d9a71eb3848180af3d058072d55063..6ab5645eb178468ea13bd7583160939f0e6e591d 100644 --- a/src/crypto.cpp +++ b/src/crypto.cpp @@ -105,7 +105,6 @@ Blob aesEncrypt(const uint8_t* data, size_t data_length, const Blob& key) struct gcm_aes_ctx aes; gcm_aes_set_key(&aes, key.size(), key.data()); gcm_aes_set_iv(&aes, GCM_IV_SIZE, ret.data()); - gcm_aes_update(&aes, data_length, data); gcm_aes_encrypt(&aes, data_length, ret.data() + GCM_IV_SIZE, data); gcm_aes_digest(&aes, GCM_DIGEST_SIZE, ret.data() + GCM_IV_SIZE + data_length); @@ -137,18 +136,8 @@ Blob aesDecrypt(const Blob& data, const Blob& key) size_t data_sz = data.size() - GCM_IV_SIZE - GCM_DIGEST_SIZE; Blob ret(data_sz); - //gcm_aes_update(&aes, data_sz, data.data() + GCM_IV_SIZE); gcm_aes_decrypt(&aes, data_sz, ret.data(), data.data() + GCM_IV_SIZE); - //gcm_aes_digest(aes, GCM_DIGEST_SIZE, digest.data()); - - // TODO compute the proper digest directly from the decryption pass - Blob ret_tmp(data_sz); - struct gcm_aes_ctx aes_d; - gcm_aes_set_key(&aes_d, key.size(), key.data()); - gcm_aes_set_iv(&aes_d, GCM_IV_SIZE, data.data()); - gcm_aes_update(&aes_d, ret.size(), ret.data()); - gcm_aes_encrypt(&aes_d, ret.size(), ret_tmp.data(), ret.data()); - gcm_aes_digest(&aes_d, GCM_DIGEST_SIZE, digest.data()); + gcm_aes_digest(&aes, GCM_DIGEST_SIZE, digest.data()); if (not std::equal(digest.begin(), digest.end(), data.end() - GCM_DIGEST_SIZE)) throw DecryptError("Can't decrypt data"); diff --git a/tests/cryptotester.cpp b/tests/cryptotester.cpp index a69ad26406a27198afbfd1f8687884e6938ed2fe..b556823cdb7e085065135dc960a41164c4251309 100644 --- a/tests/cryptotester.cpp +++ b/tests/cryptotester.cpp @@ -34,16 +34,28 @@ CryptoTester::testSignatureEncryption() { auto key = dht::crypto::PrivateKey::generate(); auto public_key = key.getPublicKey(); - std::vector<uint8_t> data {5, 10}; - std::vector<uint8_t> signature = key.sign(data); + std::vector<uint8_t> data1 {5, 10}; + std::vector<uint8_t> data2(64 * 1024, 10); + + std::vector<uint8_t> signature1 = key.sign(data1); + std::vector<uint8_t> signature2 = key.sign(data2); // check signature - CPPUNIT_ASSERT(public_key.checkSignature(data, signature)); + CPPUNIT_ASSERT(public_key.checkSignature(data1, signature1)); + CPPUNIT_ASSERT(public_key.checkSignature(data2, signature2)); // encrypt data - std::vector<uint8_t> encrypted = public_key.encrypt(data); - std::vector<uint8_t> decrypted = key.decrypt(encrypted); - CPPUNIT_ASSERT(data == decrypted); + { + std::vector<uint8_t> encrypted = public_key.encrypt(data1); + std::vector<uint8_t> decrypted = key.decrypt(encrypted); + CPPUNIT_ASSERT(data1 == decrypted); + } + + { + std::vector<uint8_t> encrypted = public_key.encrypt(data2); + std::vector<uint8_t> decrypted = key.decrypt(encrypted); + CPPUNIT_ASSERT(data2 == decrypted); + } } void