diff --git a/src/base64.cpp b/src/base64.cpp index 8d88aa085b1cf5aca1e5a2177988598f0a8c08d9..2e77ed688c08810d9a419d1646c4f3f0e96cff90 100644 --- a/src/base64.cpp +++ b/src/base64.cpp @@ -20,97 +20,9 @@ #include <stdint.h> #include <stdlib.h> - -/* Mainly based on the following stackoverflow question: - * http://stackoverflow.com/questions/342409/how-do-i-base64-encode-decode-in-c - */ -static const char encoding_table[] = { - 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', - 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', - 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', - 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', - 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '0', '1', '2', - '3', '4', '5', '6', '7', '8', '9', '+', '/' -}; - -static const size_t mod_table[] = { 0, 2, 1 }; - -char *ring_base64_encode(const uint8_t *input, size_t input_length, - char *output, size_t *output_length) -{ - size_t i, j; - size_t out_sz = *output_length; - *output_length = 4 * ((input_length + 2) / 3); - if (out_sz < *output_length || output == NULL) - return NULL; - - for (i = 0, j = 0; i < input_length; ) { - uint8_t octet_a = i < input_length ? input[i++] : 0; - uint8_t octet_b = i < input_length ? input[i++] : 0; - uint8_t octet_c = i < input_length ? input[i++] : 0; - - uint32_t triple = (octet_a << 0x10) + (octet_b << 0x08) + octet_c; - - output[j++] = encoding_table[(triple >> 3 * 6) & 0x3F]; - output[j++] = encoding_table[(triple >> 2 * 6) & 0x3F]; - output[j++] = encoding_table[(triple >> 1 * 6) & 0x3F]; - output[j++] = encoding_table[(triple >> 0 * 6) & 0x3F]; - } - - for (i = 0; i < mod_table[input_length % 3]; i++) - output[*output_length - 1 - i] = '='; - - return output; -} - -uint8_t *ring_base64_decode(const char *input, size_t input_length, - uint8_t *output, size_t *output_length) -{ - size_t i, j; - uint8_t decoding_table[256]; - - uint8_t c; - for (c = 0; c < 64; c++) - decoding_table[static_cast<int>(encoding_table[c])] = c; - - if (input_length % 4 != 0) - return NULL; - - size_t out_sz = *output_length; - *output_length = input_length / 4 * 3; - if (input[input_length - 1] == '=') - (*output_length)--; - if (input[input_length - 2] == '=') - (*output_length)--; - - if (out_sz < *output_length || output == NULL) - return NULL; - - for (i = 0, j = 0; i < input_length;) { - uint8_t sextet_a = input[i] == '=' ? 0 & i++ - : decoding_table[static_cast<int>(input[i++])]; - uint8_t sextet_b = input[i] == '=' ? 0 & i++ - : decoding_table[static_cast<int>(input[i++])]; - uint8_t sextet_c = input[i] == '=' ? 0 & i++ - : decoding_table[static_cast<int>(input[i++])]; - uint8_t sextet_d = input[i] == '=' ? 0 & i++ - : decoding_table[static_cast<int>(input[i++])]; - - uint32_t triple = (sextet_a << 3 * 6) + - (sextet_b << 2 * 6) + - (sextet_c << 1 * 6) + - (sextet_d << 0 * 6); - - if (j < *output_length) - output[j++] = (triple >> 2 * 8) & 0xFF; - if (j < *output_length) - output[j++] = (triple >> 1 * 8) & 0xFF; - if (j < *output_length) - output[j++] = (triple >> 0 * 8) & 0xFF; - } - - return output; -} +#include <iostream> +#include <pjlib.h> +#include <pjlib-util/base64.h> namespace ring { namespace base64 { @@ -119,11 +31,15 @@ std::string encode(const std::vector<uint8_t>::const_iterator begin, const std::vector<uint8_t>::const_iterator end) { - size_t output_length = 4 * ((std::distance(begin, end) + 2) / 3); + int input_length = std::distance(begin, end); + int output_length = 4 * ((input_length + 2) / 3); std::string out; out.resize(output_length); - ring_base64_encode(&(*begin), std::distance(begin, end), - &(*out.begin()), &output_length); + + if(pj_base64_encode( &(*begin), input_length, &(*out.begin()), &output_length) != PJ_SUCCESS) { + throw base64_exception(); + } + out.resize(output_length); return out; } @@ -137,12 +53,19 @@ encode(const std::vector<uint8_t>& dat) std::vector<uint8_t> decode(const std::string& str) { - size_t output_length = str.length() / 4 * 3 + 2; - std::vector<uint8_t> output; - output.resize(output_length); - ring_base64_decode(str.data(), str.size(), output.data(), &output_length); - output.resize(output_length); - return output; + int output_length = str.length() / 4 * 3 + 2; + pj_str_t input; + pj_strset(&input, (char*) &(*str.begin()), str.length()); + + std::vector<uint8_t> out; + out.resize(output_length); + + if(pj_base64_decode(&input, &(*out.begin()), &output_length) != PJ_SUCCESS) { + throw base64_exception(); + } + + out.resize(output_length); + return out; } }} // namespace ring::base64 diff --git a/src/base64.h b/src/base64.h index 2ac51231375a198205bebb66e66d543cb69ed9cd..1e44fbb924d327d28c1d09c69255fefa81e62e66 100644 --- a/src/base64.h +++ b/src/base64.h @@ -21,38 +21,15 @@ #include <stdint.h> #include <stddef.h> -/** - * Encode a buffer in base64. - * - * @param data the input buffer - * @param input_length the input length - * @param output_length the resulting output length - * @return a base64-encoded buffer - * - * @note callers should free the returned memory - */ -char *ring_base64_encode(const uint8_t *input, size_t input_length, - char *output, size_t *output_length); - -/** - * Decode a base64 buffer. - * - * @param data the input buffer - * @param input_length the input length - * @param output_length the resulting output length - * @return a buffer - * - * @note callers should free the returned memory - */ -uint8_t *ring_base64_decode(const char *input, size_t input_length, - uint8_t *output, size_t *output_length); - #include <string> #include <vector> +#include <exception> namespace ring { namespace base64 { +class base64_exception : public std::exception { }; + std::string encode(const std::vector<uint8_t>::const_iterator begin, const std::vector<uint8_t>::const_iterator end); std::string encode(const std::vector<uint8_t>& dat); std::vector<uint8_t> decode(const std::string& str);