diff --git a/configure.ac b/configure.ac index 3e156ce5539f738fba15368fa5a6e03a501fcd25..9ac2f54918d50fa99ea15dd6864f5b4513787ea0 100644 --- a/configure.ac +++ b/configure.ac @@ -73,6 +73,7 @@ AX_CXX_COMPILE_STDCXX_11([noext],[mandatory]) PKG_PROG_PKG_CONFIG() PKG_CHECK_MODULES([GNUTLS], [gnutls >= 3.1]) +PKG_CHECK_MODULES([msgpack], [msgpack]) AC_ARG_ENABLE([tools], AS_HELP_STRING([--disable-tools],[Disable tools (CLI DHT node)]),,build_tools=yes) AM_CONDITIONAL(ENABLE_TOOLS, test x$build_tools == xyes) diff --git a/include/opendht/crypto.h b/include/opendht/crypto.h index 3df04c9bbfaf504bdecf7e3b545d0697e9f7d580..59f8fcfdc871631e8ef6b8252101237f65f1f0ff 100644 --- a/include/opendht/crypto.h +++ b/include/opendht/crypto.h @@ -31,7 +31,6 @@ #pragma once #include "infohash.h" -#include "serialize.h" extern "C" { #include <gnutls/gnutls.h> @@ -42,6 +41,8 @@ extern "C" { #include <vector> #include <memory> +typedef std::vector<uint8_t> Blob; + namespace dht { namespace crypto { @@ -75,7 +76,7 @@ Identity generateIdentity(const std::string& name = "dhtnode", Identity ca = {}, /** * A public key. */ -struct PublicKey : public Serializable +struct PublicKey { PublicKey() {} PublicKey(gnutls_pubkey_t k) : pk(k) {} @@ -91,9 +92,19 @@ struct PublicKey : public Serializable bool checkSignature(const Blob& data, const Blob& signature) const; Blob encrypt(const Blob&) const; - void pack(Blob& b) const override; + void pack(Blob& b) const; + void unpack(const uint8_t* dat, size_t dat_size); + + template <typename Packer> + void msgpack_pack(Packer& p) const + { + Blob b; + pack(b); + p.pack_bin(b.size()); + p.pack_bin_body((const char*)b.data(), b.size()); + } - void unpack(Blob::const_iterator& begin, Blob::const_iterator& end) override; + void msgpack_unpack(msgpack::object o); gnutls_pubkey_t pk {}; private: @@ -147,7 +158,7 @@ private: friend dht::crypto::Identity dht::crypto::generateIdentity(const std::string&, dht::crypto::Identity, unsigned key_length); }; -struct Certificate : public Serializable { +struct Certificate { Certificate() {} /** @@ -155,11 +166,16 @@ struct Certificate : public Serializable { */ Certificate(gnutls_x509_crt_t crt) : cert(crt) {} + Certificate(Certificate&& o) noexcept : cert(o.cert), issuer(std::move(o.issuer)) { o.cert = nullptr; }; + /** * Import certificate (PEM or DER) or certificate chain (PEM), * ordered from subject to issuer */ Certificate(const Blob& crt); + Certificate(const uint8_t* dat, size_t dat_size) { + unpack(dat, dat_size); + } /** * Import certificate chain (PEM or DER), @@ -179,12 +195,16 @@ struct Certificate : public Serializable { unpack(certs); } - Certificate(Certificate&& o) noexcept : cert(o.cert), issuer(std::move(o.issuer)) { o.cert = nullptr; }; Certificate& operator=(Certificate&& o) noexcept; ~Certificate(); - void pack(Blob& b) const override; - void unpack(Blob::const_iterator& begin, Blob::const_iterator& end) override; + void pack(Blob& b) const; + void unpack(const uint8_t* dat, size_t dat_size); + Blob getPacked() const { + Blob b; + pack(b); + return b; + } template<typename Iterator> void unpack(const Iterator& begin, const Iterator& end) @@ -227,6 +247,17 @@ struct Certificate : public Serializable { *this = tmp_issuer ? std::move(*tmp_issuer) : Certificate(); } + template <typename Packer> + void msgpack_pack(Packer& p) const + { + Blob b; + pack(b); + p.pack_bin(b.size()); + p.pack_bin_body((const char*)b.data(), b.size()); + } + + void msgpack_unpack(msgpack::object o); + operator bool() const { return cert; } PublicKey getPublicKey() const; diff --git a/include/opendht/default_types.h b/include/opendht/default_types.h index bfffd58f541c43d85588a310fb58c5cb4ef03148..b19714a08bd653bfe836f79a63f52782f94d1cfa 100644 --- a/include/opendht/default_types.h +++ b/include/opendht/default_types.h @@ -34,17 +34,14 @@ namespace dht { -struct DhtMessage : public ValueSerializable +struct DhtMessage : public ValueSerializable<DhtMessage> { DhtMessage(std::string s = {}, Blob msg = {}) : service(s), data(msg) {} - + std::string getService() const { return service; } - virtual void pack(Blob& res) const; - virtual void unpack(Blob::const_iterator& begin, Blob::const_iterator& end); - static const ValueType TYPE; virtual const ValueType& getType() const { return TYPE; @@ -61,14 +58,15 @@ struct DhtMessage : public ValueSerializable public: std::string service; Blob data; + MSGPACK_DEFINE(service, data); }; - -struct SignedValue : public ValueSerializable +template <typename Type> +struct SignedValue : public ValueSerializable<Type> { virtual void unpackValue(const Value& v) { from = v.owner.getId(); - ValueSerializable::unpackValue(v); + ValueSerializable<Type>::unpackValue(v); } static Value::Filter getFilter() { return [](const Value& v){ return v.isSigned(); }; @@ -77,15 +75,16 @@ public: dht::InfoHash from; }; -struct EncryptedValue : public SignedValue +template <typename Type> +struct EncryptedValue : public SignedValue<Type> { virtual void unpackValue(const Value& v) { to = v.recipient; - SignedValue::unpackValue(v); + SignedValue<Type>::unpackValue(v); } static Value::Filter getFilter() { return Value::Filter::chain( - SignedValue::getFilter(), + SignedValue<Type>::getFilter(), [](const Value& v){ return v.recipient != InfoHash(); } ); } @@ -94,7 +93,7 @@ public: dht::InfoHash to; }; -struct ImMessage : public SignedValue +struct ImMessage : public SignedValue<ImMessage> { ImMessage() {} ImMessage(std::string&& msg) @@ -107,14 +106,6 @@ struct ImMessage : public SignedValue static Value::Filter getFilter() { return SignedValue::getFilter(); } - virtual void pack(Blob& data) const { - serialize<std::chrono::system_clock::time_point>(std::chrono::system_clock::now(), data); - data.insert(data.end(), im_message.begin(), im_message.end()); - } - virtual void unpack(Blob::const_iterator& b, Blob::const_iterator& e) { - sent = deserialize<decltype(sent)>(b, e); - im_message = std::string(b, e); - } virtual void unpackValue(const Value& v) { to = v.recipient; SignedValue::unpackValue(v); @@ -123,9 +114,10 @@ struct ImMessage : public SignedValue dht::InfoHash to; std::chrono::system_clock::time_point sent; std::string im_message; + MSGPACK_DEFINE(im_message); }; -struct TrustRequest : public EncryptedValue +struct TrustRequest : public EncryptedValue<TrustRequest> { TrustRequest() {} TrustRequest(std::string s) : service(s) {} @@ -138,22 +130,16 @@ struct TrustRequest : public EncryptedValue static Value::Filter getFilter() { return EncryptedValue::getFilter(); } - virtual void pack(Blob& data) const { - serialize<std::string>(service, data); - serialize<Blob>(payload, data); - } - virtual void unpack(Blob::const_iterator& b, Blob::const_iterator& e) { - service = deserialize<std::string>(b, e); - payload = deserialize<Blob>(b, e); - } + std::string service; Blob payload; + MSGPACK_DEFINE(service, payload); }; -struct IceCandidates : public EncryptedValue +struct IceCandidates : public EncryptedValue<IceCandidates> { IceCandidates() {} - IceCandidates(Blob ice) : ice_data(ice) {} + IceCandidates(Value::Id msg_id, Blob ice) : id(msg_id), ice_data(ice) {} static const ValueType TYPE; virtual const ValueType& getType() const { @@ -162,12 +148,7 @@ struct IceCandidates : public EncryptedValue static Value::Filter getFilter() { return EncryptedValue::getFilter(); } - virtual void pack(Blob& data) const { - serialize<Blob>(ice_data, data); - } - virtual void unpack(Blob::const_iterator& b, Blob::const_iterator& e) { - ice_data = deserialize<Blob>(b, e); - } + virtual void unpackValue(const Value& v) { EncryptedValue::unpackValue(v); id = v.id; @@ -175,12 +156,13 @@ struct IceCandidates : public EncryptedValue Value::Id id; Blob ice_data; + MSGPACK_DEFINE(id, ice_data); }; /* "Peer" announcement */ -struct IpServiceAnnouncement : public ValueSerializable +struct IpServiceAnnouncement : public ValueSerializable<IpServiceAnnouncement> { IpServiceAnnouncement(in_port_t p = 0) { ss.ss_family = 0; @@ -193,12 +175,37 @@ struct IpServiceAnnouncement : public ValueSerializable } IpServiceAnnouncement(const Blob& b) { - unpackBlob(b); + msgpack_unpack(unpack(b).get()); + } + + template <typename Packer> + void msgpack_pack(Packer& pk) const + { + pk.pack_array(2); + pk.pack(getPort()); + if (ss.ss_family == AF_INET) { + pk.pack_bin(sizeof(in_addr)); + pk.pack_bin_body((const char*)&reinterpret_cast<const sockaddr_in*>(&ss)->sin_addr, sizeof(in_addr)); + } else if (ss.ss_family == AF_INET6) { + pk.pack_bin(sizeof(in6_addr)); + pk.pack_bin_body((const char*)&reinterpret_cast<const sockaddr_in6*>(&ss)->sin6_addr, sizeof(in6_addr)); + } + } + + void msgpack_unpack(msgpack::object o) + { + if (o.type != msgpack::type::ARRAY) throw msgpack::type_error(); + if (o.via.array.size < 2) throw msgpack::type_error(); + setPort(o.via.array.ptr[0].as<in_port_t>()); + auto ip_dat = o.via.array.ptr[1].as<Blob>(); + if (ip_dat.size() == sizeof(in_addr)) + std::copy(ip_dat.begin(), ip_dat.end(), (char*)&reinterpret_cast<sockaddr_in*>(&ss)->sin_addr); + else if (ip_dat.size() == sizeof(in6_addr)) + std::copy(ip_dat.begin(), ip_dat.end(), (char*)&reinterpret_cast<sockaddr_in6*>(&ss)->sin6_addr); + else + throw msgpack::type_error(); } - virtual void pack(Blob& res) const; - virtual void unpack(Blob::const_iterator& begin, Blob::const_iterator& end); - in_port_t getPort() const { return ntohs(reinterpret_cast<const sockaddr_in*>(&ss)->sin_port); } diff --git a/include/opendht/dht.h b/include/opendht/dht.h index 2194d8e1f47600e16d0f19cae74ce0a7f77c3330..f3b0bb783581bbcf5af105da2d8c2b3994e4c995 100644 --- a/include/opendht/dht.h +++ b/include/opendht/dht.h @@ -359,7 +359,7 @@ private: static constexpr unsigned MAX_HASHES {16384}; /* The maximum number of searches we keep data about. */ - static constexpr unsigned MAX_SEARCHES {1024}; + static constexpr unsigned MAX_SEARCHES {128}; /* The time after which we can send get requests for a search in case of no answers. */ @@ -383,6 +383,7 @@ private: static constexpr unsigned TOKEN_SIZE {64}; + static const std::string my_v; struct NodeCache { std::shared_ptr<Node> getNode(const InfoHash& id, sa_family_t family); @@ -673,6 +674,7 @@ private: */ struct TransId final : public std::array<uint8_t, 4> { TransId() {} + TransId(const std::array<char, 4>& o) { std::copy(o.begin(), o.end(), begin()); } TransId(const TransPrefix prefix, uint16_t seqno = 0) { std::copy_n(prefix.begin(), prefix.size(), begin()); *reinterpret_cast<uint16_t*>(data()+prefix.size()) = seqno; @@ -688,7 +690,7 @@ private: } bool matches(const TransPrefix prefix, uint16_t *seqno_return = nullptr) const { - if (std::equal(begin(), begin()+1, prefix.begin())) { + if ((*this)[0] == prefix[0] && (*this)[1] == prefix[1]) { if (seqno_return) *seqno_return = *reinterpret_cast<const uint16_t*>(&(*this)[2]); return true; @@ -708,7 +710,7 @@ private: int dht_socket6 {-1}; InfoHash myid {}; - static const uint8_t my_v[9]; + std::array<uint8_t, 8> secret {{}}; std::array<uint8_t, 8> oldsecret {{}}; @@ -785,16 +787,24 @@ private: int sendError(const sockaddr*, socklen_t, TransId tid, uint16_t code, const char *message, bool include_id=false); void processMessage(const uint8_t *buf, size_t buflen, const sockaddr *from, socklen_t fromlen); - MessageType parseMessage(const uint8_t *buf, size_t buflen, - TransId& tid, - InfoHash& id_return, InfoHash& info_hash_return, - InfoHash& target_return, in_port_t& port_return, - Blob& token, Value::Id& value_id, - uint8_t *nodes_return, unsigned *nodes_len, - uint8_t *nodes6_return, unsigned *nodes6_len, - std::vector<std::shared_ptr<Value>>& values_return, - want_t* want_return, uint16_t& error_code, bool& ring, - sockaddr* addr_return, socklen_t& addr_length_return); + + struct ParsedMessage { + MessageType type; + InfoHash id; + InfoHash info_hash; + InfoHash target; + TransId tid; + Blob token; + Value::Id value_id; + Blob nodes4; + Blob nodes6; + std::vector<std::shared_ptr<Value>> values; + want_t want; + uint16_t error_code; + std::string ua; + Address addr; + void msgpack_unpack(msgpack::object o); + }; void rotateSecrets(); diff --git a/include/opendht/dhtrunner.h b/include/opendht/dhtrunner.h index eab88bbf8a8968847de3935f91d32e387cc83da7..0f9003c44ba7bd441f7a88fad22ee7fc25e68306 100644 --- a/include/opendht/dhtrunner.h +++ b/include/opendht/dhtrunner.h @@ -85,14 +85,13 @@ public: { get(hash, [=](const std::vector<std::shared_ptr<Value>>& vals) { for (const auto& v : vals) { - T msg; try { - msg.unpackValue(*v); + auto msg = unpack<T>(v->data); + if (not cb(std::move(msg))) + return false; } catch (const std::exception&) { continue; } - if (not cb(std::move(msg))) - return false; } return true; }, @@ -145,14 +144,13 @@ public: { return listen(hash, [=](const std::vector<std::shared_ptr<Value>>& vals) { for (const auto& v : vals) { - T msg; try { - msg.unpackValue(*v); + auto msg = unpack<T>(v->data); + if (not cb(std::move(msg))) + return false; } catch (const std::exception&) { continue; } - if (not cb(std::move(msg))) - return false; } return true; }, diff --git a/include/opendht/infohash.h b/include/opendht/infohash.h index acc7a33552aaf0d051cb2026c62ff13d17bea755..2dadb14b1b5094214075fe0b47796bbcedd266f8 100644 --- a/include/opendht/infohash.h +++ b/include/opendht/infohash.h @@ -30,6 +30,8 @@ #pragma once +#include <msgpack.hpp> + #include <iostream> #include <iomanip> #include <array> @@ -73,6 +75,10 @@ public: */ explicit InfoHash(const std::string& hex); + InfoHash(const msgpack::object& o) { + msgpack_unpack(o); + } + /** * Find the lowest 1 bit in an id. * Result will allways be lower than 8*HASH_LEN @@ -193,6 +199,20 @@ public: friend std::ostream& operator<< (std::ostream& s, const InfoHash& h); std::string toString() const; + + template <typename Packer> + void msgpack_pack(Packer& pk) const + { + pk.pack_bin(HASH_LEN); + pk.pack_bin_body((char*)data(), HASH_LEN); + } + + void msgpack_unpack(msgpack::object o) { + if (o.type != msgpack::type::BIN or o.via.bin.size != HASH_LEN) + throw msgpack::type_error(); + std::copy_n(o.via.bin.ptr, HASH_LEN, data()); + } + }; } diff --git a/include/opendht/serialize.h b/include/opendht/serialize.h deleted file mode 100644 index 3e81f5b7cc25806e02cc65b6aa6e6d2911d950b1..0000000000000000000000000000000000000000 --- a/include/opendht/serialize.h +++ /dev/null @@ -1,339 +0,0 @@ -/** - * Copyright (c) 2013, Simone Pellegrini All rights reserved. - * Copyright (c) 2014 Savoir-Faire Linux. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * - Redistributions of source code must retain the above copyright notice, - * this list of conditions and the following disclaimer. - * - * - Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - * POSSIBILITY OF SUCH DAMAGE. - */ - -#pragma once - -#include <vector> -#include <string> -#include <tuple> -#include <numeric> -#include <limits> -#include <chrono> - -typedef std::vector<uint8_t> Blob; - -template <class T> -inline void serialize(const T&, Blob&); - -namespace detail { - - template<std::size_t> struct int_{}; - -} - -// get_size -template <class T> -size_t get_size(const T& obj); - -namespace detail { - - typedef uint16_t serialized_size_t; - - template <class T> - struct get_size_helper; - - template <class T> - struct get_size_helper<std::vector<T>> { - static size_t value(const std::vector<T>& obj) { - return std::accumulate(obj.begin(), obj.end(), sizeof(serialized_size_t), - [](const size_t& acc, const T& cur) { return acc+get_size(cur); }); - } - }; - - template <> - struct get_size_helper<std::string> { - static size_t value(const std::string& obj) { - return sizeof(serialized_size_t) + obj.length()*sizeof(uint8_t); - } - }; - - template <class T> - struct get_size_helper<std::chrono::time_point<T>> { - static constexpr size_t value(const std::chrono::time_point<T>&) { - return sizeof(typename std::chrono::time_point<T>::rep); - } - }; - - template <class tuple_type> - inline size_t get_tuple_size(const tuple_type& obj, int_<0>) { - constexpr size_t idx = std::tuple_size<tuple_type>::value-1; - return get_size(std::get<idx>(obj)); - } - - template <class tuple_type, size_t pos> - inline size_t get_tuple_size(const tuple_type& obj, int_<pos>) { - constexpr size_t idx = std::tuple_size<tuple_type>::value-pos-1; - size_t acc = get_size(std::get<idx>(obj)); - - // recur - return acc+get_tuple_size(obj, int_<pos-1>()); - } - - template <class ...T> - struct get_size_helper<std::tuple<T...>> { - static size_t value(const std::tuple<T...>& obj) { - return get_tuple_size(obj, int_<sizeof...(T)-1>()); - } - }; - - template <class T> - struct get_size_helper { - static size_t value(const T&) { return sizeof(T); } - }; - -} - -template <class T> -inline size_t get_size(const T& obj) { - return detail::get_size_helper<T>::value(obj); -} - -namespace detail { - - template <class T> - class serialize_helper; - - template <class T> - void serializer(const T& obj, Blob::iterator&); - - template <class tuple_type> - inline void serialize_tuple(const tuple_type& obj, Blob::iterator& res, int_<0>) { - constexpr size_t idx = std::tuple_size<tuple_type>::value-1; - serializer(std::get<idx>(obj), res); - } - - template <class tuple_type, size_t pos> - inline void serialize_tuple(const tuple_type& obj, Blob::iterator& res, int_<pos>) { - constexpr size_t idx = std::tuple_size<tuple_type>::value-pos-1; - serializer(std::get<idx>(obj), res); - - // recur - serialize_tuple(obj, res, int_<pos-1>()); - } - - template <class... T> - struct serialize_helper<std::tuple<T...>> { - static void apply(const std::tuple<T...>& obj, Blob::iterator& res) { - detail::serialize_tuple(obj, res, detail::int_<sizeof...(T)-1>()); - } - - }; - - template <> - struct serialize_helper<std::string> { - static void apply(const std::string& obj, Blob::iterator& res) { - // store the number of elements of this vector at the beginning - if (obj.length() > std::numeric_limits<serialized_size_t>::max()) - throw std::length_error("string is too long"); - serializer(static_cast<serialized_size_t>(obj.length()), res); - for(const auto& cur : obj) { serializer(cur, res); } - } - - }; - - template <class T> - struct serialize_helper<std::vector<T>> { - static void apply(const std::vector<T>& obj, Blob::iterator& res) { - // store the number of elements of this vector at the beginning - if (obj.size() > std::numeric_limits<serialized_size_t>::max()) - throw std::length_error("vector is too large"); - serializer(static_cast<serialized_size_t>(obj.size()), res); - for(const auto& cur : obj) { serializer(cur, res); } - } - - }; - - template <class T> - struct serialize_helper<std::chrono::time_point<T>> { - static void apply(const std::chrono::time_point<T>& obj, Blob::iterator& res) { - serializer(obj.time_since_epoch().count(), res); - } - }; - - template <class T> - struct serialize_helper { - static void apply(const T& obj, Blob::iterator& res) { - const uint8_t* ptr = reinterpret_cast<const uint8_t*>(&obj); - std::copy(ptr,ptr+sizeof(T),res); - res+=sizeof(T); - } - - }; - - template <class T> - inline void serializer(const T& obj, Blob::iterator& res) { - serialize_helper<T>::apply(obj,res); - } - -} // end detail namespace - -template <class T> -inline void serialize(const T& obj, Blob& res) { - - size_t offset = res.size(); - size_t size = get_size(obj); - res.resize(res.size() + size); - - Blob::iterator it = res.begin()+offset; - detail::serializer(obj,it); - if (res.begin() + offset + size != it) - throw std::logic_error("error serializing object"); -} - -namespace detail { - - template <class T> - struct deserialize_helper; - - template <class T> - struct deserialize_helper { - static T apply(Blob::const_iterator& begin, - Blob::const_iterator end) { - if (begin+sizeof(T)>end) - throw std::length_error("error deserializing object"); - T val; - std::copy(begin, begin+sizeof(T), reinterpret_cast<uint8_t*>(&val)); - begin+=sizeof(T); - return val; - } - }; - - template <class T> - struct deserialize_helper<std::chrono::time_point<T>> { - static std::chrono::time_point<T> apply(Blob::const_iterator& begin, - Blob::const_iterator end) { - return std::chrono::time_point<T>(typename T::duration(deserialize_helper<typename T::rep>::apply(begin,end))); - } - }; - - template <class T> - struct deserialize_helper<std::vector<T>> { - static std::vector<T> apply(Blob::const_iterator& begin, - Blob::const_iterator end) - { - // retrieve the number of elements - serialized_size_t size = deserialize_helper<serialized_size_t>::apply(begin,end); - - std::vector<T> vect(size); - for(size_t i=0; i<size; ++i) { - vect[i] = std::move(deserialize_helper<T>::apply(begin,end)); - } - return vect; - } - }; - - template <> - struct deserialize_helper<std::string> { - static std::string apply(Blob::const_iterator& begin, - Blob::const_iterator end) - { - // retrieve the number of elements - serialized_size_t size = deserialize_helper<serialized_size_t>::apply(begin,end); - - if (size == 0u) return std::string(); - std::string str(size,'\0'); - for(size_t i=0; i<size; ++i) { - str.at(i) = deserialize_helper<uint8_t>::apply(begin,end); - } - return str; - } - }; - - template <class tuple_type> - inline void deserialize_tuple(tuple_type& obj, - Blob::const_iterator& begin, - Blob::const_iterator end, int_<0>) { - constexpr size_t idx = std::tuple_size<tuple_type>::value-1; - typedef typename std::tuple_element<idx,tuple_type>::type T; - - std::get<idx>(obj) = std::move(deserialize_helper<T>::apply(begin, end)); - } - - template <class tuple_type, size_t pos> - inline void deserialize_tuple(tuple_type& obj, - Blob::const_iterator& begin, - Blob::const_iterator end, int_<pos>) { - constexpr size_t idx = std::tuple_size<tuple_type>::value-pos-1; - typedef typename std::tuple_element<idx,tuple_type>::type T; - std::get<idx>(obj) = std::move(deserialize_helper<T>::apply(begin, end)); - - // recur - deserialize_tuple(obj, begin, end, int_<pos-1>()); - } - - template <class... T> - struct deserialize_helper<std::tuple<T...>> { - static std::tuple<T...> apply(Blob::const_iterator& begin, - Blob::const_iterator end) - { - //return std::make_tuple(deserialize(begin,begin+sizeof(T),T())...); - std::tuple<T...> ret; - deserialize_tuple(ret, begin, end, int_<sizeof...(T)-1>()); - return ret; - } - - }; - -} - -template <class T> -inline T deserialize(Blob::const_iterator& begin, const Blob::const_iterator& end) { - return detail::deserialize_helper<T>::apply(begin, end); -} - -template <class T> -inline T deserialize(const Blob& res) { - Blob::const_iterator it = res.begin(); - return deserialize<T>(it, res.end()); -} - -namespace dht { - - struct Serializable { - /** - * Append serialized object to res. - */ - virtual void pack(Blob& res) const = 0; - Blob getPacked() const { - Blob ret; - pack(ret); - return ret; - } - - /** - * Read serialized object from {begin, end}. - */ - virtual void unpack(Blob::const_iterator& begin, Blob::const_iterator& end) = 0; - void unpackBlob(const Blob& data) { - auto cib = data.cbegin(), cie = data.cend(); - unpack(cib, cie); - } - - virtual ~Serializable() = default; -}; - -} diff --git a/include/opendht/value.h b/include/opendht/value.h index 3a0b0b8e70bf4a4125ea69dccd59d0fc40faa58b..a8e463f47d8d87a1ed18c1aba57c77699fbf04fc 100644 --- a/include/opendht/value.h +++ b/include/opendht/value.h @@ -32,7 +32,7 @@ #include "infohash.h" #include "crypto.h" -#include "serialize.h" +#include <msgpack.hpp> #ifndef _WIN32 #include <netinet/in.h> @@ -153,12 +153,23 @@ struct ValueType { EditPolicy editPolicy {DEFAULT_EDIT_POLICY}; }; -struct ValueSerializable : public Serializable -{ - virtual const ValueType& getType() const = 0; - virtual void unpackValue(const Value& v); - virtual Value packValue() const; -}; +template <typename Type> +Blob +pack(const Type& t) { + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack(t); + return {buffer.data(), buffer.data()+buffer.size()}; +} + +template <typename Type> +Type +unpack(Blob b) { + msgpack::unpacked msg_res = msgpack::unpack((const char*)b.data(), b.size()); + return msg_res.get().as<Type>(); +} + +msgpack::unpacked unpack(Blob b); /** * A "value" is data potentially stored on the Dht, with some metadata. @@ -169,7 +180,7 @@ struct ValueSerializable : public Serializable * Values are stored at a given InfoHash in the Dht, but also have a * unique ID to distinguish between values stored at the same location. */ -struct Value : public Serializable +struct Value { typedef uint64_t Id; static const Id INVALID_ID {0}; @@ -222,31 +233,11 @@ struct Value : public Serializable }; } - /** - * Hold information about how the data is signed/encrypted. - * Class is final because bitset have no virtual destructor. - */ - class ValueFlags final : public std::bitset<3> { - public: - using std::bitset<3>::bitset; - ValueFlags() {} - ValueFlags(bool sign, bool encrypted, bool have_recipient = false) : bitset<3>((sign ? 1:0) | (encrypted ? 2:0) | (have_recipient ? 4:0)) {} - bool isSigned() const { - return (*this)[0]; - } - bool isEncrypted() const { - return (*this)[1]; - } - bool haveRecipient() const { - return (*this)[2]; - } - }; - bool isEncrypted() const { - return flags.isEncrypted(); + return not cypher.empty(); } bool isSigned() const { - return flags.isSigned(); + return isEncrypted() or not signature.empty(); } Value() {} @@ -260,12 +251,14 @@ struct Value : public Serializable : id(id), type(t), data(std::move(data)) {} Value(ValueType::Id t, const uint8_t* dat_ptr, size_t dat_len, Id id = INVALID_ID) : id(id), type(t), data(dat_ptr, dat_ptr+dat_len) {} - Value(ValueType::Id t, const Serializable& d, Id id = INVALID_ID) - : id(id), type(t), data(d.getPacked()) {} - Value(const ValueType& t, const Serializable& d, Id id = INVALID_ID) - : id(id), type(t.id), data(d.getPacked()) {} - Value(const ValueSerializable& d, Id id = INVALID_ID) - : id(id), type(d.getType().id), data(d.getPacked()) {} + + template <typename Type> + Value(ValueType::Id t, const Type& d, Id id = INVALID_ID) + : id(id), type(t), data(pack(d)) {} + + template <typename Type> + Value(const ValueType& t, const Type& d, Id id = INVALID_ID) + : id(id), type(t.id), data(pack(d)) {} /** Custom user data constructor */ Value(const Blob& userdata) : data(userdata) {} @@ -273,41 +266,50 @@ struct Value : public Serializable Value(const uint8_t* dat_ptr, size_t dat_len) : data(dat_ptr, dat_ptr+dat_len) {} Value(Value&& o) noexcept - : id(o.id), flags(o.flags), owner(std::move(o.owner)), recipient(o.recipient), + : id(o.id), owner(std::move(o.owner)), recipient(o.recipient), type(o.type), data(std::move(o.data)), seq(o.seq), signature(std::move(o.signature)), cypher(std::move(o.cypher)) {} + template <typename ValueType> + Value(const ValueType& vs) + : Value(vs.packValue()) {} + + Value(const msgpack::object& o) { + msgpack_unpack(o); + } + inline bool operator== (const Value& o) { return id == o.id && - (flags.isEncrypted() ? cypher == o.cypher : + (isEncrypted() ? cypher == o.cypher : (owner == o.owner && type == o.type && data == o.data && signature == o.signature)); } void setRecipient(const InfoHash& r) { recipient = r; - flags[2] = true; } void setCypher(Blob&& c) { cypher = std::move(c); - flags = {true, true, true}; } /** - * Pack part of the data to be signed + * Pack part of the data to be signed (must always be done the same way) */ - void packToSign(Blob& res) const; - Blob getToSign() const; + Blob getToSign() const { + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + msgpack_pack_to_sign(pk); + return {buffer.data(), buffer.data()+buffer.size()}; + } /** * Pack part of the data to be encrypted */ - void packToEncrypt(Blob& res) const; - Blob getToEncrypt() const; - - void pack(Blob& res) const; - - void unpackBody(Blob::const_iterator& begin, Blob::const_iterator& end); - virtual void unpack(Blob::const_iterator& begin, Blob::const_iterator& end); + Blob getToEncrypt() const { + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + msgpack_pack_to_encrypt(pk); + return {buffer.data(), buffer.data()+buffer.size()}; + } /** print value for debugging */ friend std::ostream& operator<< (std::ostream& s, const Value& v); @@ -318,11 +320,53 @@ struct Value : public Serializable return ss.str(); } - Id id {INVALID_ID}; + template <typename Packer> + void msgpack_pack_to_sign(Packer& pk) const + { + pk.pack_map((user_type.empty()?0:1) + (owner?(recipient == InfoHash() ? 4 : 5):2)); + if (owner) { // isSigned + pk.pack(std::string("seq")); pk.pack(seq); + pk.pack(std::string("owner")); owner.msgpack_pack(pk); + if (recipient != InfoHash()) { + pk.pack(std::string("to")); pk.pack(recipient); + } + } + pk.pack(std::string("type")); pk.pack(type); + pk.pack(std::string("data")); pk.pack_bin(data.size()); + pk.pack_bin_body((const char*)data.data(), data.size()); + if (not user_type.empty()) { + pk.pack(std::string("utype")); pk.pack(user_type); + } + } - // data (part that is signed / encrypted) + template <typename Packer> + void msgpack_pack_to_encrypt(Packer& pk) const + { + if (isEncrypted()) { + pk.pack_bin(cypher.size()); + pk.pack_bin_body((const char*)cypher.data(), cypher.size()); + } else { + pk.pack_map(isSigned() ? 2 : 1); + pk.pack(std::string("body")); msgpack_pack_to_sign(pk); + if (isSigned()) { + pk.pack(std::string("sig")); pk.pack_bin(signature.size()); + pk.pack_bin_body((const char*)signature.data(), signature.size()); + } + } + } - ValueFlags flags {}; + template <typename Packer> + void msgpack_pack(Packer& pk) const + { + pk.pack_map(2); + pk.pack(std::string("id")); pk.pack(id); + pk.pack(std::string("dat")); msgpack_pack_to_encrypt(pk); + } + + void msgpack_unpack(msgpack::object o); + void msgpack_unpack_body(const msgpack::object& o); + + Id id {INVALID_ID}; /** * Public key of the signer. @@ -342,6 +386,11 @@ struct Value : public Serializable ValueType::Id type {ValueType::USER_DATA.id}; Blob data {}; + /** + * Custom user-defined type + */ + std::string user_type {}; + /** * Sequence number to avoid replay attacks */ @@ -358,6 +407,41 @@ struct Value : public Serializable Blob cypher {}; }; + + +template <typename Type> +struct ValueSerializable /* : public Serializable*/ +{ + //ValueSerializable() {}; + //ValueSerializable(const Type& t) : t(t) {}; + + virtual const ValueType& getType() const = 0; + //virtual void unpackValue(const Value& v); + virtual void unpackValue(const Value& v) { + auto msg = msgpack::unpack((const char*)v.data.data(), v.data.size()); + msgpack::object obj = msg.get(); + obj.convert(static_cast<Type*>(this)); + } + + virtual Value packValue() const { + return Value {getType(), static_cast<const Type&>(*this)}; + } + virtual ~ValueSerializable() = default; + + //Type t; +/* + Blob pack() { + return pack<Type>(*this); + } + void unpack(Blob&) { + return pack<Type>(*this); + } +*/ + //Blob getPacked() const; +}; + + + template <class T> std::vector<T> unpackVector(const std::vector<std::shared_ptr<Value>>& vals) { diff --git a/src/Makefile.am b/src/Makefile.am index a7842c68375908cb6e3f5e58877cb6baf3e623e6..ca1fff6d595ca0a1b478f78b5a9c88017ab0d4c6 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -24,6 +24,5 @@ nobase_include_HEADERS = \ ../include/opendht/crypto.h \ ../include/opendht/securedht.h \ ../include/opendht/dhtrunner.h \ - ../include/opendht/serialize.h \ ../include/opendht/default_types.h \ ../include/opendht/rng.h diff --git a/src/crypto.cpp b/src/crypto.cpp index 1164633932952bb3b3bed241ad1a073074df9ac0..e01259648119fbb020d71d23369110002d91bc3f 100644 --- a/src/crypto.cpp +++ b/src/crypto.cpp @@ -256,7 +256,7 @@ PrivateKey::getPublicKey() const PublicKey::PublicKey(const Blob& dat) : pk(nullptr) { - unpackBlob(dat); + unpack(dat.data(), dat.size()); } PublicKey::~PublicKey() @@ -286,17 +286,16 @@ PublicKey::pack(Blob& b) const if (err != GNUTLS_E_SUCCESS) throw CryptoException(std::string("Could not export public key: ") + gnutls_strerror(err)); tmp.resize(sz); - serialize<Blob>(tmp, b); + b.insert(b.end(), tmp.begin(), tmp.end()); } void -PublicKey::unpack(Blob::const_iterator& begin, Blob::const_iterator& end) +PublicKey::unpack(const uint8_t* data, size_t data_size) { - Blob tmp = deserialize<Blob>(begin, end); if (pk) gnutls_pubkey_deinit(pk); gnutls_pubkey_init(&pk); - const gnutls_datum_t dat {(uint8_t*)tmp.data(), (unsigned)tmp.size()}; + const gnutls_datum_t dat {(uint8_t*)data, (unsigned)data_size}; int err = gnutls_pubkey_import(pk, &dat, GNUTLS_X509_FMT_PEM); if (err != GNUTLS_E_SUCCESS) err = gnutls_pubkey_import(pk, &dat, GNUTLS_X509_FMT_DER); @@ -304,6 +303,14 @@ PublicKey::unpack(Blob::const_iterator& begin, Blob::const_iterator& end) throw CryptoException(std::string("Could not read public key: ") + gnutls_strerror(err)); } +void +PublicKey::msgpack_unpack(msgpack::object o) +{ + if (o.type != msgpack::type::BIN) + throw msgpack::type_error(); + unpack((const uint8_t*)o.via.bin.ptr, o.via.bin.size); +} + bool PublicKey::checkSignature(const Blob& data, const Blob& signature) const { if (!pk) @@ -363,7 +370,7 @@ PublicKey::getId() const Certificate::Certificate(const Blob& certData) : cert(nullptr) { - unpackBlob(certData); + unpack(certData.data(), certData.size()); } Certificate& @@ -378,7 +385,7 @@ Certificate::operator=(Certificate&& o) noexcept } void -Certificate::unpack(Blob::const_iterator& begin, Blob::const_iterator& end) +Certificate::unpack(const uint8_t* dat, size_t dat_size) { if (cert) { gnutls_x509_crt_deinit(cert); @@ -386,7 +393,7 @@ Certificate::unpack(Blob::const_iterator& begin, Blob::const_iterator& end) } gnutls_x509_crt_t* cert_list; unsigned cert_num; - const gnutls_datum_t crt_dt {(uint8_t*)&(*begin), (unsigned)(end-begin)}; + const gnutls_datum_t crt_dt {(uint8_t*)dat, (unsigned)dat_size}; int err = gnutls_x509_crt_list_import2(&cert_list, &cert_num, &crt_dt, GNUTLS_X509_FMT_PEM, GNUTLS_X509_CRT_LIST_FAIL_IF_UNSORTED); if (err != GNUTLS_E_SUCCESS) err = gnutls_x509_crt_list_import2(&cert_list, &cert_num, &crt_dt, GNUTLS_X509_FMT_DER, GNUTLS_X509_CRT_LIST_FAIL_IF_UNSORTED); @@ -405,6 +412,14 @@ Certificate::unpack(Blob::const_iterator& begin, Blob::const_iterator& end) gnutls_free(cert_list); } +void +Certificate::msgpack_unpack(msgpack::object o) +{ + if (o.type != msgpack::type::BIN) + throw msgpack::type_error(); + unpack((const uint8_t*)o.via.bin.ptr, o.via.bin.size); +} + void Certificate::pack(Blob& b) const { diff --git a/src/default_types.cpp b/src/default_types.cpp index 63b413f9a370e62865ee5e011d2234e13423c5e0..3ac4b653d5db6f883f3b3a3ef3d289d96fb606f7 100644 --- a/src/default_types.cpp +++ b/src/default_types.cpp @@ -38,29 +38,14 @@ std::ostream& operator<< (std::ostream& s, const DhtMessage& v) return s; } -void -DhtMessage::pack(Blob& res) const -{ - serialize<std::string>(service, res); - serialize<Blob>(data, res); -} - -void -DhtMessage::unpack(Blob::const_iterator& begin, Blob::const_iterator& end) -{ - service = deserialize<std::string>(begin, end); - data = deserialize<Blob>(begin, end); -} - bool DhtMessage::storePolicy(InfoHash, std::shared_ptr<Value>& v, InfoHash, const sockaddr*, socklen_t) { - DhtMessage request; try { - request.unpackBlob(v->data); + auto msg = unpack<DhtMessage>(v->data); + if (msg.service.empty()) + return false; } catch (const std::exception& e) {} - if (request.service.empty()) - return false; return true; } @@ -71,9 +56,7 @@ DhtMessage::ServiceFilter(std::string s) Value::TypeFilter(TYPE), [s](const Value& v) { try { - auto b = v.data.cbegin(), e = v.data.cend(); - auto service = deserialize<std::string>(b, e); - return service == s; + return unpack<DhtMessage>(v.data).service == s; } catch (const std::exception& e) { return false; } @@ -95,51 +78,20 @@ std::ostream& operator<< (std::ostream& s, const IpServiceAnnouncement& v) return s; } -void -IpServiceAnnouncement::pack(Blob& res) const -{ - serialize<in_port_t>(getPort(), res); - if (ss.ss_family == AF_INET) { - auto sa4 = reinterpret_cast<const sockaddr_in*>(&ss); - serialize<in_addr>(sa4->sin_addr, res); - } else if (ss.ss_family == AF_INET6) { - auto sa6 = reinterpret_cast<const sockaddr_in6*>(&ss); - serialize<in6_addr>(sa6->sin6_addr, res); - } -} - -void -IpServiceAnnouncement::unpack(Blob::const_iterator& begin, Blob::const_iterator& end) -{ - setPort(deserialize<in_port_t>(begin, end)); - size_t addr_size = end - begin; - if (addr_size < sizeof(in_addr)) { - ss.ss_family = 0; - } else if (addr_size == sizeof(in_addr)) { - auto sa4 = reinterpret_cast<sockaddr_in*>(&ss); - sa4->sin_family = AF_INET; - sa4->sin_addr = deserialize<in_addr>(begin, end); - } else if (addr_size == sizeof(in6_addr)) { - auto sa6 = reinterpret_cast<sockaddr_in6*>(&ss); - sa6->sin6_family = AF_INET6; - sa6->sin6_addr = deserialize<in6_addr>(begin, end); - } else { - throw std::runtime_error("ServiceAnnouncement parse error."); - } -} - bool IpServiceAnnouncement::storePolicy(InfoHash, std::shared_ptr<Value>& v, InfoHash, const sockaddr* from, socklen_t fromlen) { - IpServiceAnnouncement request {}; - request.unpackBlob(v->data); - if (request.getPort() == 0) - return false; - IpServiceAnnouncement sa_addr {from, fromlen}; - sa_addr.setPort(request.getPort()); - // argument v is modified (not the value). - v = std::make_shared<Value>(IpServiceAnnouncement::TYPE, sa_addr, v->id); - return true; + try { + auto msg = unpack<IpServiceAnnouncement>(v->data); + if (msg.getPort() == 0) + return false; + IpServiceAnnouncement sa_addr {from, fromlen}; + sa_addr.setPort(msg.getPort()); + // argument v is modified (not the value). + v = std::make_shared<Value>(IpServiceAnnouncement::TYPE, sa_addr, v->id); + return true; + } catch (const std::exception& e) {} + return false; } const ValueType DhtMessage::TYPE = {1, "DHT message", std::chrono::minutes(5), DhtMessage::storePolicy, ValueType::DEFAULT_EDIT_POLICY}; diff --git a/src/dht.cpp b/src/dht.cpp index 84a6a89e58b5fc0acf719f38c1066b2b1779a27c..348bcddd769f50ffbccc44e01f66d3cfe1db433f 100644 --- a/src/dht.cpp +++ b/src/dht.cpp @@ -27,6 +27,7 @@ THE SOFTWARE. #include "dht.h" #include "rng.h" +#include <msgpack.hpp> extern "C" { #include <gnutls/gnutls.h> } @@ -140,11 +141,11 @@ namespace dht { const Dht::TransPrefix Dht::TransPrefix::PING = {"pn"}; const Dht::TransPrefix Dht::TransPrefix::FIND_NODE = {"fn"}; -const Dht::TransPrefix Dht::TransPrefix::GET_VALUES = {"gp"}; -const Dht::TransPrefix Dht::TransPrefix::ANNOUNCE_VALUES = {"ap"}; -const Dht::TransPrefix Dht::TransPrefix::LISTEN = {"ls"}; +const Dht::TransPrefix Dht::TransPrefix::GET_VALUES = {"gt"}; +const Dht::TransPrefix Dht::TransPrefix::ANNOUNCE_VALUES = {"pt"}; +const Dht::TransPrefix Dht::TransPrefix::LISTEN = {"lt"}; -const uint8_t Dht::my_v[9] = "1:v4:RNG"; +const std::string my_v = "RNG"; static constexpr InfoHash zeroes {}; static constexpr InfoHash ones = {std::array<uint8_t, HASH_LEN>{ @@ -2180,26 +2181,6 @@ Dht::processMessage(const uint8_t *buf, size_t buflen, const sockaddr *from, soc if (buflen == 0) return; - //DHT_DEBUG("processMessage %p %lu %p %lu", buf, buflen, from, fromlen); - - MessageType message; - InfoHash id, info_hash, target; - TransId tid; - Blob token {}; - uint8_t nodes[26*16], nodes6[38*16]; - unsigned nodes_len = 26*16, nodes6_len = 38*16; - in_port_t port; - Value::Id value_id; - uint16_t error_code; - sockaddr_storage addr; - socklen_t addr_length = sizeof(sockaddr_storage); - - std::vector<std::shared_ptr<Value>> values; - - want_t want; - uint16_t ttid; - bool ring; - if (isMartian(from, fromlen)) return; @@ -2208,15 +2189,13 @@ Dht::processMessage(const uint8_t *buf, size_t buflen, const sockaddr *from, soc return; } - if (buf[buflen] != '\0') - throw DhtException("Unterminated message."); + //DHT_DEBUG("processMessage %p %lu %p %lu", buf, buflen, from, fromlen); + ParsedMessage msg; try { - message = parseMessage(buf, buflen, tid, id, info_hash, target, - port, token, value_id, - nodes, &nodes_len, nodes6, &nodes6_len, - values, &want, error_code, ring, (sockaddr*)&addr, addr_length); - if (message != MessageType::Error && id == zeroes) + msgpack::unpacked msg_res = msgpack::unpack((const char*)buf, buflen); + msg.msgpack_unpack(msg_res.get()); + if (msg.type != MessageType::Error && msg.id == zeroes) throw DhtException("no or invalid InfoHash"); } catch (const std::exception& e) { DHT_WARN("Can't process message of size %lu: %s.", buflen, e.what()); @@ -2224,16 +2203,12 @@ Dht::processMessage(const uint8_t *buf, size_t buflen, const sockaddr *from, soc return; } - // drop msg with unknown protocol - if (not ring) - return; - - if (id == myid) { + if (msg.id == myid) { DHT_DEBUG("Received message from self."); return; } - if (message > MessageType::Reply) { + if (msg.type > MessageType::Reply) { /* Rate limit requests. */ if (!rateLimit()) { DHT_WARN("Dropping request due to rate limiting."); @@ -2242,14 +2217,15 @@ Dht::processMessage(const uint8_t *buf, size_t buflen, const sockaddr *from, soc } //std::cout << "Message from " << id << " IPv" << (from->sa_family==AF_INET?'4':'6') << std::endl; + uint16_t ttid = 0; - switch (message) { + switch (msg.type) { case MessageType::Error: - if (tid.length != 4) return; - if (error_code == 401 && id != zeroes && (tid.matches(TransPrefix::ANNOUNCE_VALUES, &ttid) || tid.matches(TransPrefix::LISTEN, &ttid))) { + if (msg.tid.length != 4) return; + if (msg.error_code == 401 && msg.id != zeroes && (msg.tid.matches(TransPrefix::ANNOUNCE_VALUES, &ttid) || msg.tid.matches(TransPrefix::LISTEN, &ttid))) { auto esr = findSearch(ttid, from->sa_family); if (!esr) return; - auto ne = newNode(id, from, fromlen, 2); + auto ne = newNode(msg.id, from, fromlen, 2); unsigned cleared = 0; for (auto& sr : searches) { for (auto& n : sr.nodes) { @@ -2262,45 +2238,45 @@ Dht::processMessage(const uint8_t *buf, size_t buflen, const sockaddr *from, soc break; } } - DHT_WARN("Token flush for node %s (%d searches affected)", id.toString().c_str(), cleared); + DHT_WARN("Token flush for node %s (%d searches affected)",msg.id.toString().c_str(), cleared); } else { - DHT_WARN("Received unknown error message %u from %s:", error_code, id.toString().c_str()); + DHT_WARN("Received unknown error message %u from %s:", msg.error_code,msg.id.toString().c_str()); DHT_WARN.logPrintable(buf, buflen); } break; case MessageType::Reply: - if (tid.length != 4) { - DHT_ERROR("Broken node truncates transaction ids (len: %d): ", tid.length); + if (msg.tid.length != 4) { + DHT_ERROR("Broken node truncates transaction ids (len: %d): ", msg.tid.length); DHT_ERROR.logPrintable(buf, buflen); /* This is really annoying, as it means that we will time-out all our searches that go through this node. Kill it. */ - blacklistNode(&id, from, fromlen); + blacklistNode(&msg.id, from, fromlen); return; } - if (tid.matches(TransPrefix::PING)) { + if (msg.tid.matches(TransPrefix::PING)) { DHT_DEBUG("Pong!"); - newNode(id, from, fromlen, 2, (sockaddr*)&addr, addr_length); - } else if (tid.matches(TransPrefix::FIND_NODE) or tid.matches(TransPrefix::GET_VALUES)) { + newNode(msg.id, from, fromlen, 2, (sockaddr*)&msg.addr.first, msg.addr.second); + } else if (msg.tid.matches(TransPrefix::FIND_NODE) or msg.tid.matches(TransPrefix::GET_VALUES)) { bool gp = false; Search *sr = nullptr; std::shared_ptr<Node> n; - if (tid.matches(TransPrefix::GET_VALUES, &ttid)) { + if (msg.tid.matches(TransPrefix::GET_VALUES, &ttid)) { gp = true; sr = findSearch(ttid, from->sa_family); } - DHT_DEBUG("Nodes found (%u+%u)%s!", nodes_len/26, nodes6_len/38, gp ? " for get_values" : ""); - if (nodes_len % 26 != 0 || nodes6_len % 38 != 0) { + DHT_DEBUG("Nodes found (%u+%u)%s!", msg.nodes4.size()/26, msg.nodes6.size()/38, gp ? " for get_values" : ""); + if (msg.nodes4.size() % 26 != 0 || msg.nodes6.size() % 38 != 0) { DHT_WARN("Unexpected length for node info!"); - blacklistNode(&id, from, fromlen); + blacklistNode(&msg.id, from, fromlen); break; } else if (gp && sr == nullptr) { DHT_WARN("Unknown search with tid %u !", ttid); - n = newNode(id, from, fromlen, 1); + n = newNode(msg.id, from, fromlen, 1); } else { - n = newNode(id, from, fromlen, 2, (sockaddr*)&addr, addr_length); - for (unsigned i = 0; i < nodes_len / 26; i++) { - uint8_t *ni = nodes + i * 26; + n = newNode(msg.id, from, fromlen, 2, (sockaddr*)&msg.addr.first, msg.addr.second); + for (unsigned i = 0; i < msg.nodes4.size() / 26; i++) { + uint8_t *ni = msg.nodes4.data() + i * 26; const InfoHash& ni_id = *reinterpret_cast<InfoHash*>(ni); if (ni_id == myid) continue; @@ -2314,8 +2290,8 @@ Dht::processMessage(const uint8_t *buf, size_t buflen, const sockaddr *from, soc sr->insertNode(sn, now); } } - for (unsigned i = 0; i < nodes6_len / 38; i++) { - uint8_t *ni = nodes6 + i * 38; + for (unsigned i = 0; i < msg.nodes6.size() / 38; i++) { + uint8_t *ni = msg.nodes6.data() + i * 38; InfoHash* ni_id = reinterpret_cast<InfoHash*>(ni); if (*ni_id == myid) continue; @@ -2339,13 +2315,13 @@ Dht::processMessage(const uint8_t *buf, size_t buflen, const sockaddr *from, soc } } if (sr) { - sr->insertNode(n, now, token); - if (!values.empty()) { - DHT_DEBUG("Got %d values !", values.size()); + sr->insertNode(n, now, msg.token); + if (!msg.values.empty()) { + DHT_DEBUG("Got %d values !", msg.values.size()); for (auto& cb : sr->callbacks) { if (!cb.get_cb) continue; std::vector<std::shared_ptr<Value>> tmp; - std::copy_if(values.begin(), values.end(), std::back_inserter(tmp), [&](const std::shared_ptr<Value>& v) { + std::copy_if(msg.values.begin(), msg.values.end(), std::back_inserter(tmp), [&](const std::shared_ptr<Value>& v) { return not static_cast<bool>(cb.filter) or cb.filter(*v); }); if (not tmp.empty()) @@ -2355,7 +2331,7 @@ Dht::processMessage(const uint8_t *buf, size_t buflen, const sockaddr *from, soc for (auto& l : sr->listeners) { if (!l.second.get_cb) continue; std::vector<std::shared_ptr<Value>> tmp; - std::copy_if(values.begin(), values.end(), std::back_inserter(tmp), [&](const std::shared_ptr<Value>& v) { + std::copy_if(msg.values.begin(), msg.values.end(), std::back_inserter(tmp), [&](const std::shared_ptr<Value>& v) { return not static_cast<bool>(l.second.filter) or l.second.filter(*v); }); if (not tmp.empty()) @@ -2368,17 +2344,17 @@ Dht::processMessage(const uint8_t *buf, size_t buflen, const sockaddr *from, soc if (sr->isSynced(now)) search_time = now; } - } else if (tid.matches(TransPrefix::ANNOUNCE_VALUES, &ttid)) { + } else if (msg.tid.matches(TransPrefix::ANNOUNCE_VALUES, &ttid)) { DHT_DEBUG("Got reply to announce_values."); Search *sr = findSearch(ttid, from->sa_family); - if (!sr || value_id == Value::INVALID_ID) { + if (!sr || msg.value_id == Value::INVALID_ID) { DHT_DEBUG("Unknown search or announce!"); - newNode(id, from, fromlen, 1); + newNode(msg.id, from, fromlen, 1); } else { - auto n = newNode(id, from, fromlen, 2, (sockaddr*)&addr, addr_length); + auto n = newNode(msg.id, from, fromlen, 2, (sockaddr*)&msg.addr.first, msg.addr.second); for (auto& sn : sr->nodes) if (sn.node == n) { - auto it = sn.acked.emplace(value_id, SearchNode::RequestStatus{}); + auto it = sn.acked.emplace(msg.value_id, SearchNode::RequestStatus{}); it.first->second.reply_time = now; break; } @@ -2388,22 +2364,22 @@ Dht::processMessage(const uint8_t *buf, size_t buflen, const sockaddr *from, soc // If the value was just successfully announced, call the callback for (auto& a : sr->announce) { - if (!a.callback || !a.value || a.value->id != value_id) + if (!a.callback || !a.value || a.value->id != msg.value_id) continue; - if (sr->isAnnounced(value_id, getType(a.value->type), now)) { + if (sr->isAnnounced(msg.value_id, getType(a.value->type), now)) { a.callback(true, sr->getNodes()); a.callback = nullptr; } } } - } else if (tid.matches(TransPrefix::LISTEN, &ttid)) { + } else if (msg.tid.matches(TransPrefix::LISTEN, &ttid)) { DHT_DEBUG("Got reply to listen."); Search *sr = findSearch(ttid, from->sa_family); if (!sr) { DHT_DEBUG("Unknown search or announce!"); - newNode(id, from, fromlen, 1); + newNode(msg.id, from, fromlen, 1); } else { - auto n = newNode(id, from, fromlen, 2, (sockaddr*)&addr, addr_length); + auto n = newNode(msg.id, from, fromlen, 2, (sockaddr*)&msg.addr.first, msg.addr.second); for (auto& sn : sr->nodes) if (sn.node == n) { sn.listenStatus.reply_time = now; @@ -2419,73 +2395,72 @@ Dht::processMessage(const uint8_t *buf, size_t buflen, const sockaddr *from, soc } break; case MessageType::Ping: - //DHT_DEBUG("Got ping (%d)!", tid.length); - newNode(id, from, fromlen, 1); + newNode(msg.id, from, fromlen, 1); //DHT_DEBUG("Sending pong."); - sendPong(from, fromlen, tid); + sendPong(from, fromlen, msg.tid); break; case MessageType::FindNode: DHT_DEBUG("Got \"find node\" request"); - newNode(id, from, fromlen, 1); - DHT_DEBUG("Sending closest nodes (%d).", want); - sendClosestNodes(from, fromlen, tid, target, want); + newNode(msg.id, from, fromlen, 1); + DHT_DEBUG("Sending closest nodes (%d).", msg.want); + sendClosestNodes(from, fromlen, msg.tid, msg.target, msg.want); break; case MessageType::GetValues: DHT_DEBUG("Got \"get values\" request"); - newNode(id, from, fromlen, 1); - if (info_hash == zeroes) { - DHT_WARN("Eek! Got get_values with no info_hash from %s %s.", id.toString().c_str(), print_addr(from, fromlen).c_str()); - sendError(from, fromlen, tid, 203, "Get_values with no info_hash"); + newNode(msg.id, from, fromlen, 1); + if (msg.info_hash == zeroes) { + DHT_WARN("Eek! Got get_values with no info_hash from %s %s.", msg.id.toString().c_str(), print_addr(from, fromlen).c_str()); + sendError(from, fromlen, msg.tid, 203, "Get_values with no info_hash"); break; } else { - Storage* st = findStorage(info_hash); + Storage* st = findStorage(msg.info_hash); Blob ntoken = makeToken(from, false); if (st && st->values.size() > 0) { DHT_DEBUG("Sending found%s values.", from->sa_family == AF_INET6 ? " IPv6" : ""); - sendClosestNodes(from, fromlen, tid, info_hash, want, ntoken, st->values); + sendClosestNodes(from, fromlen, msg.tid, msg.info_hash, msg.want, ntoken, st->values); } else { DHT_DEBUG("Sending nodes for get_values."); - sendClosestNodes(from, fromlen, tid, info_hash, want, ntoken); + sendClosestNodes(from, fromlen, msg.tid, msg.info_hash, msg.want, ntoken); } } break; case MessageType::AnnounceValue: DHT_DEBUG("Got \"announce value\" request!"); - newNode(id, from, fromlen, 1); - if (info_hash == zeroes) { + newNode(msg.id, from, fromlen, 1); + if (msg.info_hash == zeroes) { DHT_WARN("Announce_value with no info_hash."); - sendError(from, fromlen, tid, 203, "Announce_value with no info_hash"); + sendError(from, fromlen, msg.tid, 203, "Announce_value with no info_hash"); break; } - if (!tokenMatch(token, from)) { - DHT_WARN("Incorrect token %s for announce_values.", to_hex(token.data(), token.size()).c_str()); - sendError(from, fromlen, tid, 401, "Announce_value with wrong token", true); + if (!tokenMatch(msg.token, from)) { + DHT_WARN("Incorrect token %s for announce_values.", to_hex(msg.token.data(), msg.token.size()).c_str()); + sendError(from, fromlen, msg.tid, 401, "Announce_value with wrong token", true); break; } - for (const auto& v : values) { + for (const auto& v : msg.values) { if (v->id == Value::INVALID_ID) { DHT_WARN("Incorrect value id "); - sendError(from, fromlen, tid, 203, "Announce_value with invalid id"); + sendError(from, fromlen, msg.tid, 203, "Announce_value with invalid id"); continue; } - auto lv = getLocalById(info_hash, v->id); + auto lv = getLocalById(msg.info_hash, v->id); std::shared_ptr<Value> vc = v; if (lv) { const auto& type = getType(lv->type); - if (type.editPolicy(info_hash, lv, vc, id, from, fromlen)) { - DHT_DEBUG("Editing value of type %s belonging to %s at %s.", type.name.c_str(), v->owner.getId().toString().c_str(), info_hash.toString().c_str()); - storageStore(info_hash, vc); + if (type.editPolicy(msg.info_hash, lv, vc, msg.id, from, fromlen)) { + DHT_DEBUG("Editing value of type %s belonging to %s at %s.", type.name.c_str(), v->owner.getId().toString().c_str(), msg.info_hash.toString().c_str()); + storageStore(msg.info_hash, vc); } else { - DHT_WARN("Rejecting edition of type %s belonging to %s at %s because of storage policy.", type.name.c_str(), v->owner.getId().toString().c_str(), info_hash.toString().c_str()); + DHT_WARN("Rejecting edition of type %s belonging to %s at %s because of storage policy.", type.name.c_str(), v->owner.getId().toString().c_str(), msg.info_hash.toString().c_str()); } } else { // Allow the value to be edited by the storage policy const auto& type = getType(vc->type); - if (type.storePolicy(info_hash, vc, id, from, fromlen)) { - DHT_DEBUG("Storing value of type %s belonging to %s at %s.", type.name.c_str(), v->owner.getId().toString().c_str(), info_hash.toString().c_str()); - storageStore(info_hash, vc); + if (type.storePolicy(msg.info_hash, vc, msg.id, from, fromlen)) { + DHT_DEBUG("Storing value of type %s belonging to %s at %s.", type.name.c_str(), v->owner.getId().toString().c_str(), msg.info_hash.toString().c_str()); + storageStore(msg.info_hash, vc); } else { - DHT_WARN("Rejecting storage of type %s belonging to %s at %s because of storage policy.", type.name.c_str(), v->owner.getId().toString().c_str(), info_hash.toString().c_str()); + DHT_WARN("Rejecting storage of type %s belonging to %s at %s because of storage policy.", type.name.c_str(), v->owner.getId().toString().c_str(), msg.info_hash.toString().c_str()); } } @@ -2493,26 +2468,26 @@ Dht::processMessage(const uint8_t *buf, size_t buflen, const sockaddr *from, soc This is to prevent them from backtracking, and hence polluting the DHT. */ DHT_DEBUG("Sending announceValue confirmation."); - sendValueAnnounced(from, fromlen, tid, v->id); + sendValueAnnounced(from, fromlen, msg.tid, v->id); } break; case MessageType::Listen: - if (info_hash == zeroes) { + if (msg.info_hash == zeroes) { DHT_WARN("Listen with no info_hash."); - sendError(from, fromlen, tid, 203, "Listen with no info_hash"); + sendError(from, fromlen, msg.tid, 203, "Listen with no info_hash"); break; } - if (!tokenMatch(token, from)) { - DHT_WARN("Incorrect token %s for announce_values.", to_hex(token.data(), token.size()).c_str()); - sendError(from, fromlen, tid, 401, "Listen with wrong token", true); + if (!tokenMatch(msg.token, from)) { + DHT_WARN("Incorrect token %s for announce_values.", to_hex(msg.token.data(), msg.token.size()).c_str()); + sendError(from, fromlen, msg.tid, 401, "Listen with wrong token", true); break; } - if (!tid.matches(TransPrefix::LISTEN, &ttid)) { + if (!msg.tid.matches(TransPrefix::LISTEN, &ttid)) { break; } - newNode(id, from, fromlen, 1); - storageAddListener(info_hash, id, from, fromlen, ttid); - sendListenConfirmation(from, fromlen, tid); + newNode(msg.id, from, fromlen, 1); + storageAddListener(msg.info_hash, msg.id, from, fromlen, ttid); + sendListenConfirmation(from, fromlen, msg.tid); break; } } @@ -2586,11 +2561,16 @@ Dht::exportValues() const for (const auto& h : store) { ValuesExport ve; ve.first = h.id; - serialize<uint16_t>(h.values.size(), ve.second); + + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack_array(h.values.size()); for (const auto& v : h.values) { - serialize<time_point>(v.time, ve.second); - v.data->pack(ve.second); + pk.pack_array(2); + pk.pack(v.time.time_since_epoch().count()); + v.data->msgpack_pack(pk); } + ve.second = {buffer.data(), buffer.data()+buffer.size()}; e.push_back(std::move(ve)); } return e; @@ -2602,16 +2582,22 @@ Dht::importValues(const std::vector<ValuesExport>& import) for (const auto& h : import) { if (h.second.empty()) continue; - auto b = h.second.begin(), - e = h.second.end(); + try { - const size_t n_vals = deserialize<uint16_t>(b, e); - for (unsigned i = 0; i < n_vals; i++) { + msgpack::unpacked msg; + msgpack::unpack(&msg, (const char*)h.second.data(), h.second.size()); + auto valarr = msg.get(); + if (valarr.type != msgpack::type::ARRAY) + throw msgpack::type_error(); + for (unsigned i = 0; i < valarr.via.array.size; i++) { + auto& valel = valarr.via.array.ptr[i]; + if (valel.via.array.size < 2) + throw msgpack::type_error(); time_point val_time; Value tmp_val; try { - val_time = deserialize<time_point>(b, e); - tmp_val.unpack(b, e); + val_time = time_point{time_point::duration{valel.via.array.ptr[0].as<time_point::duration::rep>()}}; + tmp_val.msgpack_unpack(valel.via.array.ptr[1]); } catch (const std::exception&) { DHT_ERROR("Error reading value at %s", h.first.toString().c_str()); continue; @@ -2679,24 +2665,22 @@ Dht::pingNode(const sockaddr *sa, socklen_t salen) return sendPing(sa, salen, TransId {TransPrefix::PING}); } -/* We could use a proper bencoding printer and parser, but the format of - DHT messages is fairly stylised, so this seemed simpler. */ - -#define CHECK(offset, delta, size) \ - if (offset + delta > size) throw std::length_error("Provided buffer is not large enough."); - -#define INC(offset, delta, size) \ - if (delta < 0) throw std::length_error("Provided buffer is not large enough."); \ - CHECK(offset, (size_t)delta, size); \ - offset += delta - -#define COPY(buf, offset, src, delta, size) \ - CHECK(offset, delta, size); \ - memcpy(buf + offset, src, delta); \ - offset += delta; +void +insertAddr(msgpack::packer<msgpack::sbuffer>& pk, const sockaddr *sa, socklen_t) +{ + size_t addr_len = (sa->sa_family == AF_INET) ? sizeof(in_addr) : sizeof(in6_addr); + void* addr_ptr = (sa->sa_family == AF_INET) ? (void*)&((sockaddr_in*)sa)->sin_addr + : (void*)&((sockaddr_in6*)sa)->sin6_addr; + pk.pack("sa"); + pk.pack_bin(addr_len); + pk.pack_bin_body((char*)addr_ptr, addr_len); +} -#define ADD_V(buf, offset, size) \ - COPY(buf, offset, my_v, sizeof(my_v), size); +void +insertV(msgpack::packer<msgpack::sbuffer>& pk) +{ + pk.pack("v"); pk.pack(my_v); +} int Dht::send(const char *buf, size_t len, int flags, const sockaddr *sa, socklen_t salen) @@ -2725,67 +2709,66 @@ Dht::send(const char *buf, size_t len, int flags, const sockaddr *sa, socklen_t int Dht::sendPing(const sockaddr *sa, socklen_t salen, TransId tid) { - char buf[512]; - int i = 0, rc; - rc = snprintf(buf + i, 512 - i, "d1:ad2:id20:"); INC(i, rc, 512); - COPY(buf, i, myid.data(), myid.size(), 512); - rc = snprintf(buf + i, 512 - i, "e1:q4:ping1:t%d:", tid.length); - INC(i, rc, 512); - COPY(buf, i, tid.data(), tid.length, 512); - ADD_V(buf, i, 512); - rc = snprintf(buf + i, 512 - i, "1:y1:qe"); INC(i, rc, 512); - return send(buf, i, 0, sa, salen); -} + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack_map(5); -void -insertAddr(char* buf, size_t buflen, size_t& p, const sockaddr *sa, socklen_t) -{ - size_t addr_len = (sa->sa_family == AF_INET) ? sizeof(in_addr) : sizeof(in6_addr); - void* addr_ptr = (sa->sa_family == AF_INET) ? (void*)&((sockaddr_in*)sa)->sin_addr - : (void*)&((sockaddr_in6*)sa)->sin6_addr; - int rc = snprintf(buf + p, buflen - p, "2:sa%lu:", addr_len); - INC(p, rc, buflen); - COPY(buf, p, addr_ptr, addr_len, buflen); + pk.pack(std::string("a")); pk.pack_map(1); + pk.pack(std::string("id")); pk.pack(myid); + + pk.pack(std::string("q")); pk.pack(std::string("ping")); + pk.pack(std::string("t")); pk.pack_bin(tid.size()); + pk.pack_bin_body((const char*)tid.data(), tid.size()); + pk.pack(std::string("y")); pk.pack(std::string("q")); + pk.pack(std::string("v")); pk.pack(std::string("RNG1")); + + return send(buffer.data(), buffer.size(), 0, sa, salen); } int Dht::sendPong(const sockaddr *sa, socklen_t salen, TransId tid) { - char buf[512]; - size_t i = 0; - auto rc = snprintf(buf + i, 512 - i, "d1:rd2:id20:"); INC(i, rc, 512); - COPY(buf, i, myid.data(), myid.size(), 512); - insertAddr(buf, 512, i, sa, salen); - rc = snprintf(buf + i, 512 - i, "e1:t%d:", tid.length); INC(i, rc, 512); - COPY(buf, i, tid.data(), tid.length, 512); - ADD_V(buf, i, 512); - rc = snprintf(buf + i, 512 - i, "1:y1:re"); INC(i, rc, 512); - return send(buf, i, 0, sa, salen); + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack_map(4); + + pk.pack(std::string("r")); pk.pack_map(2); + pk.pack(std::string("id")); pk.pack(myid); + insertAddr(pk, sa, salen); + + pk.pack(std::string("t")); pk.pack_bin(tid.size()); + pk.pack_bin_body((const char*)tid.data(), tid.size()); + pk.pack(std::string("y")); pk.pack(std::string("r")); + pk.pack(std::string("v")); pk.pack(std::string("RNG1")); + + return send(buffer.data(), buffer.size(), 0, sa, salen); } int Dht::sendFindNode(const sockaddr *sa, socklen_t salen, TransId tid, const InfoHash& target, want_t want, int confirm) { - constexpr const size_t BUF_SZ = 512; - char buf[BUF_SZ]; - int i = 0, rc; - rc = snprintf(buf + i, BUF_SZ - i, "d1:ad2:id20:"); INC(i, rc, BUF_SZ); - COPY(buf, i, myid.data(), myid.size(), BUF_SZ); - rc = snprintf(buf + i, BUF_SZ - i, "6:target20:"); INC(i, rc, BUF_SZ); - COPY(buf, i, target.data(), target.size(), BUF_SZ); + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack_map(5); + + pk.pack(std::string("a")); pk.pack_map(2 + (want>0?1:0)); + pk.pack(std::string("id")); pk.pack(myid); + pk.pack(std::string("target")); pk.pack(target); if (want > 0) { - rc = snprintf(buf + i, BUF_SZ - i, "4:wantl%s%se", - (want & WANT4) ? "2:n4" : "", - (want & WANT6) ? "2:n6" : ""); - INC(i, rc, BUF_SZ); + pk.pack(std::string("w")); + pk.pack_array(((want & WANT4)?1:0) + ((want & WANT6)?1:0)); + if (want & WANT4) pk.pack(AF_INET); + if (want & WANT6) pk.pack(AF_INET6); } - rc = snprintf(buf + i, BUF_SZ - i, "e1:q9:find_node1:t%d:", tid.length); - INC(i, rc, BUF_SZ); - COPY(buf, i, tid.data(), tid.length, BUF_SZ); - ADD_V(buf, i, BUF_SZ); - rc = snprintf(buf + i, BUF_SZ - i, "1:y1:qe"); INC(i, rc, BUF_SZ); - return send(buf, i, confirm ? MSG_CONFIRM : 0, sa, salen); + + pk.pack(std::string("q")); pk.pack(std::string("find")); + pk.pack(std::string("t")); pk.pack_bin(tid.size()); + pk.pack_bin_body((const char*)tid.data(), tid.size()); + pk.pack(std::string("y")); pk.pack(std::string("q")); + pk.pack(std::string("v")); pk.pack(std::string("RNG1")); + + return send(buffer.data(), buffer.size(), confirm ? 0 : MSG_CONFIRM, sa, salen); } int @@ -2794,56 +2777,51 @@ Dht::sendNodesValues(const sockaddr *sa, socklen_t salen, TransId tid, const uint8_t *nodes6, unsigned nodes6_len, const std::vector<ValueStorage>& st, const Blob& token) { - constexpr const size_t BUF_SZ = 2048 * 64; - char buf[BUF_SZ]; - size_t i = 0; + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack_map(4); - auto rc = snprintf(buf + i, BUF_SZ - i, "d1:rd2:id20:"); INC(i, rc, BUF_SZ); - COPY(buf, i, myid.data(), myid.size(), BUF_SZ); - insertAddr(buf, BUF_SZ, i, sa, salen); + pk.pack(std::string("r")); + pk.pack_map(2 + (not st.empty()?1:0) + (nodes_len>0?1:0) + (nodes6_len>0?1:0) + (not token.empty()?1:0)); + pk.pack(std::string("id")); pk.pack(myid); + insertAddr(pk, sa, salen); if (nodes_len > 0) { - rc = snprintf(buf + i, BUF_SZ - i, "5:nodes%u:", nodes_len); - INC(i, rc, BUF_SZ); - COPY(buf, i, nodes, nodes_len, BUF_SZ); + pk.pack(std::string("n4")); + pk.pack_bin(nodes_len); + pk.pack_bin_body((const char*)nodes, nodes_len); } if (nodes6_len > 0) { - rc = snprintf(buf + i, BUF_SZ - i, "6:nodes6%u:", nodes6_len); - INC(i, rc, BUF_SZ); - COPY(buf, i, nodes6, nodes6_len, BUF_SZ); + pk.pack(std::string("n6")); + pk.pack_bin(nodes6_len); + pk.pack_bin_body((const char*)nodes6, nodes6_len); } if (not token.empty()) { - rc = snprintf(buf + i, BUF_SZ - i, "5:token%lu:", token.size()); - INC(i, rc, BUF_SZ); - COPY(buf, i, token.data(), token.size(), BUF_SZ); + pk.pack(std::string("token")); pk.pack(token); } - - if (st.size() > 0) { - /* We treat the storage as a circular list, and serve a randomly - chosen slice. In order to make sure we fit, - we limit ourselves to 50 values. */ + if (not st.empty()) { + // We treat the storage as a circular list, and serve a randomly + // chosen slice. In order to make sure we fit, + // we limit ourselves to 50 values. std::uniform_int_distribution<> pos_dis(0, st.size()-1); unsigned j0 = pos_dis(rd); unsigned j = j0; unsigned k = 0; - rc = snprintf(buf + i, BUF_SZ - i, "6:valuesl"); INC(i, rc, BUF_SZ); + pk.pack(std::string("values")); + pk.pack_array(std::min(st.size(), 50ul)); do { - Blob packed_value; - st[j].data->pack(packed_value); - rc = snprintf(buf + i, BUF_SZ - i, "%lu:", packed_value.size()); INC(i, rc, BUF_SZ); - COPY(buf, i, packed_value.data(), packed_value.size(), BUF_SZ); + pk.pack(st[j].data); k++; j = (j + 1) % st.size(); } while (j != j0 && k < 50); - rc = snprintf(buf + i, BUF_SZ - i, "e"); INC(i, rc, BUF_SZ); } - rc = snprintf(buf + i, BUF_SZ - i, "e1:t%d:", tid.length); INC(i, rc, BUF_SZ); - COPY(buf, i, tid.data(), tid.length, BUF_SZ); - ADD_V(buf, i, BUF_SZ); - rc = snprintf(buf + i, BUF_SZ - i, "1:y1:re"); INC(i, rc, BUF_SZ); + pk.pack(std::string("t")); pk.pack_bin(tid.size()); + pk.pack_bin_body((const char*)tid.data(), tid.size()); + pk.pack(std::string("y")); pk.pack(std::string("r")); + pk.pack(std::string("v")); pk.pack(std::string("RNG1")); - return send(buf, i, 0, sa, salen); + return send(buffer.data(), buffer.size(), 0, sa, salen); } unsigned @@ -2951,69 +2929,68 @@ Dht::sendGetValues(const sockaddr *sa, socklen_t salen, TransId tid, const InfoHash& infohash, want_t want, int confirm) { - static constexpr const size_t BUF_SZ = 2048 * 4; - char buf[BUF_SZ]; - size_t i = 0; - int rc; + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack_map(5); - rc = snprintf(buf + i, BUF_SZ - i, "d1:ad2:id20:"); INC(i, rc, BUF_SZ); - COPY(buf, i, myid.data(), myid.size(), BUF_SZ); - rc = snprintf(buf + i, BUF_SZ - i, "9:info_hash20:"); INC(i, rc, BUF_SZ); - COPY(buf, i, infohash.data(), infohash.size(), BUF_SZ); + pk.pack(std::string("a")); pk.pack_map(2 + (want>0?1:0)); + pk.pack(std::string("id")); pk.pack(myid); + pk.pack(std::string("h")); pk.pack(infohash); if (want > 0) { - rc = snprintf(buf + i, BUF_SZ - i, "4:wantl%s%se", - (want & WANT4) ? "2:n4" : "", - (want & WANT6) ? "2:n6" : ""); - INC(i, rc, BUF_SZ); + pk.pack(std::string("w")); + pk.pack_array(((want & WANT4)?1:0) + ((want & WANT6)?1:0)); + if (want & WANT4) pk.pack(AF_INET); + if (want & WANT6) pk.pack(AF_INET6); } - rc = snprintf(buf + i, BUF_SZ - i, "e1:q9:get_peers1:t%d:", tid.length); - INC(i, rc, BUF_SZ); - COPY(buf, i, tid.data(), tid.length, BUF_SZ); - ADD_V(buf, i, BUF_SZ); - rc = snprintf(buf + i, BUF_SZ - i, "1:y1:qe"); INC(i, rc, BUF_SZ); - return send(buf, i, confirm ? MSG_CONFIRM : 0, sa, salen); + + pk.pack(std::string("q")); pk.pack(std::string("get")); + pk.pack(std::string("t")); pk.pack_bin(tid.size()); + pk.pack_bin_body((const char*)tid.data(), tid.size()); + pk.pack(std::string("y")); pk.pack(std::string("q")); + pk.pack(std::string("v")); pk.pack(std::string("RNG1")); + + return send(buffer.data(), buffer.size(), confirm ? 0 : MSG_CONFIRM, sa, salen); } int Dht::sendListen(const sockaddr* sa, socklen_t salen, TransId tid, const InfoHash& infohash, const Blob& token, int confirm) { - static constexpr const size_t BUF_SZ = 2048; - char buf[BUF_SZ]; - size_t i = 0; - int rc; + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack_map(5); - rc = snprintf(buf + i, BUF_SZ - i, "d1:ad2:id%lu:", myid.size()); INC(i, rc, BUF_SZ); - COPY(buf, i, myid.data(), myid.size(), BUF_SZ); - rc = snprintf(buf + i, BUF_SZ - i, "9:info_hash%lu:", infohash.size()); INC(i, rc, BUF_SZ); - COPY(buf, i, infohash.data(), infohash.size(), BUF_SZ); + pk.pack(std::string("a")); pk.pack_map(3); + pk.pack(std::string("id")); pk.pack(myid); + pk.pack(std::string("h")); pk.pack(infohash); + pk.pack(std::string("token")); pk.pack(token); - rc = snprintf(buf + i, BUF_SZ - i, "e5:token%lu:", token.size()); INC(i, rc, BUF_SZ); - COPY(buf, i, token.data(), token.size(), BUF_SZ); - rc = snprintf(buf + i, BUF_SZ - i, "e1:q6:listen1:t%u:", tid.length); INC(i, rc, BUF_SZ); - COPY(buf, i, tid.data(), tid.length, BUF_SZ); - ADD_V(buf, i, BUF_SZ); - rc = snprintf(buf + i, BUF_SZ - i, "1:y1:qe"); INC(i, rc, BUF_SZ); + pk.pack(std::string("q")); pk.pack(std::string("listen")); + pk.pack(std::string("t")); pk.pack_bin(tid.size()); + pk.pack_bin_body((const char*)tid.data(), tid.size()); + pk.pack(std::string("y")); pk.pack(std::string("q")); + pk.pack(std::string("v")); pk.pack(std::string("RNG1")); - return send(buf, i, confirm ? 0 : MSG_CONFIRM, sa, salen); + return send(buffer.data(), buffer.size(), confirm ? 0 : MSG_CONFIRM, sa, salen); } int Dht::sendListenConfirmation(const sockaddr* sa, socklen_t salen, TransId tid) { - static constexpr const size_t BUF_SZ = 512; - char buf[BUF_SZ]; - size_t i = 0; + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack_map(4); + + pk.pack(std::string("r")); pk.pack_map(2); + pk.pack(std::string("id")); pk.pack(myid); + insertAddr(pk, sa, salen); - auto rc = snprintf(buf + i, BUF_SZ - i, "d1:rd2:id20:"); INC(i, rc, BUF_SZ); - COPY(buf, i, myid.data(), myid.size(), BUF_SZ); - insertAddr(buf, BUF_SZ, i, sa, salen); + pk.pack(std::string("t")); pk.pack_bin(tid.size()); + pk.pack_bin_body((const char*)tid.data(), tid.size()); + pk.pack(std::string("y")); pk.pack(std::string("r")); + pk.pack(std::string("v")); pk.pack(std::string("RNG1")); - rc = snprintf(buf + i, BUF_SZ - i, "e1:t%u:", tid.length); INC(i, rc, BUF_SZ); - COPY(buf, i, tid.data(), tid.length, BUF_SZ); - ADD_V(buf, i, BUF_SZ); - rc = snprintf(buf + i, BUF_SZ - i, "1:y1:re"); INC(i, rc, BUF_SZ); - return send(buf, i, 0, sa, salen); + return send(buffer.data(), buffer.size(), 0, sa, salen); } int @@ -3021,292 +2998,210 @@ Dht::sendAnnounceValue(const sockaddr *sa, socklen_t salen, TransId tid, const InfoHash& infohash, const Value& value, const Blob& token, int confirm) { - constexpr const size_t BUF_SZ = 2048 * 4; - char buf[BUF_SZ]; - size_t i = 0; + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack_map(5); - int rc = snprintf(buf + i, BUF_SZ - i, "d1:ad2:id%lu:", myid.size()); INC(i, rc, BUF_SZ); - COPY(buf, i, myid.data(), myid.size(), BUF_SZ); - rc = snprintf(buf + i, BUF_SZ - i, "9:info_hash%lu:", infohash.size()); INC(i, rc, BUF_SZ); - COPY(buf, i, infohash.data(), infohash.size(), BUF_SZ); + pk.pack(std::string("a")); pk.pack_map(4); + pk.pack(std::string("id")); pk.pack(myid); + pk.pack(std::string("h")); pk.pack(infohash); + pk.pack(std::string("values")); pk.pack_array(1); pk.pack(value); + pk.pack(std::string("token")); pk.pack(token); - Blob packed_value; - value.pack(packed_value); - rc = snprintf(buf + i, BUF_SZ - i, "6:valuesl%lu:", packed_value.size()); INC(i, rc, BUF_SZ); - COPY(buf, i, packed_value.data(), packed_value.size(), BUF_SZ); - rc = snprintf(buf + i, BUF_SZ - i, "e5:token%lu:", token.size()); INC(i, rc, BUF_SZ); - COPY(buf, i, token.data(), token.size(), BUF_SZ); - rc = snprintf(buf + i, BUF_SZ - i, "e1:q13:announce_peer1:t%u:", tid.length); INC(i, rc, BUF_SZ); - COPY(buf, i, tid.data(), tid.length, BUF_SZ); - ADD_V(buf, i, BUF_SZ); - rc = snprintf(buf + i, BUF_SZ - i, "1:y1:qe"); INC(i, rc, BUF_SZ); + pk.pack(std::string("q")); pk.pack(std::string("put")); + pk.pack(std::string("t")); pk.pack_bin(tid.size()); + pk.pack_bin_body((const char*)tid.data(), tid.size()); + pk.pack(std::string("y")); pk.pack(std::string("q")); + pk.pack(std::string("v")); pk.pack(std::string("RNG1")); - return send(buf, i, confirm ? 0 : MSG_CONFIRM, sa, salen); + return send(buffer.data(), buffer.size(), confirm ? 0 : MSG_CONFIRM, sa, salen); } int Dht::sendValueAnnounced(const sockaddr *sa, socklen_t salen, TransId tid, Value::Id vid) { - char buf[512]; - size_t i = 0; + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack_map(4); + + pk.pack(std::string("r")); pk.pack_map(3); + pk.pack(std::string("id")); pk.pack(myid); + pk.pack(std::string("vid")); pk.pack(vid); + insertAddr(pk, sa, salen); - auto rc = snprintf(buf + i, 512 - i, "d1:rd2:id20:"); INC(i, rc, 512); - COPY(buf, i, myid.data(), myid.size(), 512); - insertAddr(buf, 512, i, sa, salen); + pk.pack(std::string("t")); pk.pack_bin(tid.size()); + pk.pack_bin_body((const char*)tid.data(), tid.size()); + pk.pack(std::string("y")); pk.pack(std::string("r")); + pk.pack(std::string("v")); pk.pack(std::string("RNG1")); - rc = snprintf(buf + i, 512 - i, "3:vid%lu:", sizeof(Value::Id)); INC(i, rc, 512); - COPY(buf, i, &vid, sizeof(Value::Id), 512); - rc = snprintf(buf + i, 512 - i, "e1:t%u:", tid.length); INC(i, rc, 512); - COPY(buf, i, tid.data(), tid.length, 512); - ADD_V(buf, i, 512); - rc = snprintf(buf + i, 512 - i, "1:y1:re"); INC(i, rc, 512); - return send(buf, i, 0, sa, salen); + return send(buffer.data(), buffer.size(), 0, sa, salen); } int Dht::sendError(const sockaddr *sa, socklen_t salen, TransId tid, uint16_t code, const char *message, bool include_id) { - constexpr const size_t BUF_SZ = 512; - char buf[BUF_SZ]; - int i = 0, rc; - - size_t msg_len = strlen(message); - rc = snprintf(buf + i, BUF_SZ - i, "d1:eli%ue%lu:", code, msg_len); - INC(i, rc, BUF_SZ); - COPY(buf, i, message, msg_len, BUF_SZ); - rc = snprintf(buf + i, BUF_SZ - i, "e1:t%d:", tid.length); INC(i, rc, BUF_SZ); - COPY(buf, i, tid.data(), tid.length, BUF_SZ); - ADD_V(buf, i, BUF_SZ); + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack_map(4 + (include_id?1:0)); + + pk.pack(std::string("e")); pk.pack_array(2); + pk.pack(code); + pk.pack_str(strlen(message)); + pk.pack_str_body(message, strlen(message)); + if (include_id) { - rc = snprintf(buf + i, BUF_SZ - i, "1:rd2:id20:"); INC(i, rc, BUF_SZ); - COPY(buf, i, myid.data(), myid.size(), BUF_SZ); - COPY(buf, i, "e", 1u, BUF_SZ); - } - rc = snprintf(buf + i, BUF_SZ - i, "1:y1:ee"); INC(i, rc, BUF_SZ); - return send(buf, i, 0, sa, salen); -} - -#undef CHECK -#undef INC -#undef COPY -#undef ADD_V - -Dht::MessageType -Dht::parseMessage(const uint8_t *buf, size_t buflen, - TransId& tid_return, - InfoHash& id_return, InfoHash& info_hash_return, - InfoHash& target_return, in_port_t& port_return, - Blob& token, Value::Id& value_id, - uint8_t *nodes_return, unsigned *nodes_len, - uint8_t *nodes6_return, unsigned *nodes6_len, - std::vector<std::shared_ptr<Value>>& values_return, - want_t* want_return, uint16_t& error_code, bool& ring, - sockaddr* addr_return, socklen_t& addr_length_return) -{ - const uint8_t *p; - - /* This code will happily crash if the buffer is not NUL-terminated. */ - if (buf[buflen] != '\0') - throw DhtException("Eek! parse_message with unterminated buffer."); - -#define CHECK(ptr, len) if (((uint8_t*)ptr) + (len) > (buf) + (buflen)) throw std::out_of_range("Truncated message."); - - p = (uint8_t*)dht_memmem(buf, buflen, "1:t", 3); - if (p) { - char *q; - size_t l = strtoul((char*)p + 3, &q, 10); - if (q && *q == ':') { - CHECK(q + 1, l); - tid_return = {q+1, l}; - } else - tid_return.length = 0; - } - - p = (uint8_t*)dht_memmem(buf, buflen, "2:id20:", 7); - if (p) { - CHECK(p + 7, HASH_LEN); - memcpy(id_return.data(), p + 7, HASH_LEN); - } else { - id_return = {}; - } - - if (addr_return and addr_length_return) { - p = (uint8_t*)dht_memmem(buf, buflen, "2:sa", 4); - if (p) { - char *q; - size_t l = strtoul((char*)p + 4, &q, 10); - if (q && *q == ':' && (l == sizeof(in_addr) or l == sizeof(in6_addr))) { - CHECK(q + 1, l); - if (l == sizeof(in_addr)) { - auto addr = (sockaddr_in*)addr_return; - std::fill_n((uint8_t*)addr, sizeof(sockaddr_in), 0); - addr->sin_family = AF_INET; - addr->sin_port = 0; - memcpy(&addr->sin_addr, q+1, l); - addr_length_return = sizeof(sockaddr_in); - } else if (l == sizeof(in6_addr)) { - auto addr = (sockaddr_in6*)addr_return; - std::fill_n((uint8_t*)addr, sizeof(sockaddr_in6), 0); - addr_return->sa_family = AF_INET6; - addr->sin6_port = 0; - memcpy(&addr->sin6_addr, q+1, l); - addr_length_return = sizeof(sockaddr_in6); - } - } else - addr_length_return = 0; - } else - addr_length_return = 0; + pk.pack(std::string("r")); pk.pack_map(1); + pk.pack(std::string("id")); pk.pack(myid); } - p = (uint8_t*)dht_memmem(buf, buflen, "9:info_hash20:", 14); - if (p) { - CHECK(p + 14, HASH_LEN); - memcpy(info_hash_return.data(), p + 14, HASH_LEN); - } else { - info_hash_return = {}; + pk.pack(std::string("t")); pk.pack_bin(tid.size()); + pk.pack_bin_body((const char*)tid.data(), tid.size()); + pk.pack(std::string("y")); pk.pack(std::string("e")); + pk.pack(std::string("v")); pk.pack(std::string("RNG1")); + + return send(buffer.data(), buffer.size(), 0, sa, salen); +} + +msgpack::object* +findMapValue(msgpack::object& map, const std::string& key) { + if (map.type != msgpack::type::MAP) throw msgpack::type_error(); + for (unsigned i = 0; i < map.via.map.size; i++) { + auto& o = map.via.map.ptr[i]; + if(o.key.type != msgpack::type::STR) + continue; + if (o.key.as<std::string>() == key) { + return &o.val; + } } + return nullptr; +} - p = (uint8_t*)dht_memmem(buf, buflen, "porti", 5); - if (p) { - char *q; - unsigned long l = strtoul((char*)p + 5, &q, 10); - if (q && *q == 'e' && l < 0x10000) - port_return = l; - else - port_return = 0; - } else - port_return = 0; +void +Dht::ParsedMessage::msgpack_unpack(msgpack::object msg) +{ + auto y = findMapValue(msg, "y"); + auto a = findMapValue(msg, "a"); + auto r = findMapValue(msg, "r"); + auto e = findMapValue(msg, "e"); - p = (uint8_t*)dht_memmem(buf, buflen, "6:target20:", 11); - if (p) { - CHECK(p + 11, HASH_LEN); - memcpy(target_return.data(), p + 11, HASH_LEN); - } else { - target_return = {}; + std::string query; + if (auto q = findMapValue(msg, "q")) { + if (q->type != msgpack::type::STR) + throw msgpack::type_error(); + query = q->as<std::string>(); } - p = (uint8_t*)dht_memmem(buf, buflen, "5:token", 7); - if (p) { - char *q; - size_t l = strtoul((char*)p + 7, &q, 10); - if (q && *q == ':' && l > 0 && l <= 128) { - CHECK(q + 1, l); - token.clear(); - token.insert(token.begin(), q + 1, q + 1 + l); - } + auto& req = a ? *a : (r ? *r : *e); + if (not &req) + throw msgpack::type_error(); + + if (e) { + if (e->type != msgpack::type::ARRAY) + throw msgpack::type_error(); + error_code = e->via.array.ptr[0].as<uint16_t>(); } - if (nodes_len) { - p = (uint8_t*)dht_memmem(buf, buflen, "5:nodes", 7); - if (p) { - char *q; - size_t l = strtoul((char*)p + 7, &q, 10); - if (q && *q == ':' && l > 0 && l <= *nodes_len) { - CHECK(q + 1, l); - memcpy(nodes_return, q + 1, l); - *nodes_len = l; - } else - *nodes_len = 0; - } else - *nodes_len = 0; - } - - if (nodes6_len) { - p = (uint8_t*)dht_memmem(buf, buflen, "6:nodes6", 8); - if (p) { - char *q; - size_t l = strtoul((char*)p + 8, &q, 10); - if (q && *q == ':' && l > 0 && l <= *nodes6_len) { - CHECK(q + 1, l); - memcpy(nodes6_return, q + 1, l); - *nodes6_len = l; - } else - *nodes6_len = 0; - } else - *nodes6_len = 0; - } - - p = (uint8_t*)dht_memmem(buf, buflen, "6:valuesl", 9); - if (p) { - unsigned i = p - buf + 9; - while (true) { - char *q; - size_t l = strtoul((char*)buf + i, &q, 10); - if (q && *q == ':' && l > 0) { - CHECK(q + 1, l); - i = q + 1 + l - (char*)buf; - Value v; - v.unpackBlob(Blob {q + 1, q + 1 + l}); - values_return.push_back(std::make_shared<Value>(std::move(v))); - } else - break; - } - if (i >= buflen || buf[i] != 'e') - DHT_DEBUG("eek... unexpected end for values."); + if (auto rid = findMapValue(req, "id")) + id = {*rid}; + + if (auto rh = findMapValue(req, "h")) + info_hash = {*rh}; + + if (auto rtarget = findMapValue(req, "target")) + target = {*rtarget}; + + if (auto otoken = findMapValue(req, "token")) + token = otoken->as<Blob>(); + + if (auto vid = findMapValue(req, "vid")) + value_id = vid->as<Value::Id>(); + + if (auto rnodes4 = findMapValue(req, "n4")) { + auto n4b = rnodes4->as<std::vector<char>>(); + nodes4 = {n4b.begin(), n4b.end()}; } - p = (uint8_t*)dht_memmem(buf, buflen, "3:vid8:", 7); - if (p) { - CHECK(p + 7, sizeof(value_id)); - memcpy(&value_id, p + 7, sizeof(value_id)); - } else { - value_id = Value::INVALID_ID; - } - - if (want_return) { - p = (uint8_t*)dht_memmem(buf, buflen, "4:wantl", 7); - if (p) { - unsigned i = p - buf + 7; - *want_return = 0; - while (buf[i] > '0' && buf[i] <= '9' && buf[i + 1] == ':' && - i + 2 + buf[i] - '0' < buflen) { - CHECK(buf + i + 2, buf[i] - '0'); - if (buf[i] == '2' && memcmp(buf + i + 2, "n4", 2) == 0) - *want_return |= WANT4; - else if (buf[i] == '2' && memcmp(buf + i + 2, "n6", 2) == 0) - *want_return |= WANT6; - else - DHT_DEBUG("eek... unexpected want flag (%c)", buf[i]); - i += 2 + buf[i] - '0'; - } - if (i >= buflen || buf[i] != 'e') - DHT_DEBUG("eek... unexpected end for want."); - } else { - *want_return = -1; + if (auto rnodes6 = findMapValue(req, "n6")) { + auto n6b = rnodes6->as<std::vector<char>>(); + nodes6 = {n6b.begin(), n6b.end()}; + } + + if (auto sa = findMapValue(req, "sa")) { + if (sa->type != msgpack::type::BIN) + throw msgpack::type_error(); + auto l = sa->via.bin.size; + if (l == sizeof(in_addr)) { + auto a = (sockaddr_in*)&addr.first; + std::fill_n((uint8_t*)a, sizeof(sockaddr_in), 0); + a->sin_family = AF_INET; + a->sin_port = 0; + std::copy_n(sa->via.bin.ptr, l, (char*)&a->sin_addr); + addr.second = sizeof(sockaddr_in); + } else if (l == sizeof(in6_addr)) { + auto a = (sockaddr_in6*)&addr.first; + std::fill_n((uint8_t*)a, sizeof(sockaddr_in6), 0); + a->sin6_family = AF_INET6; + a->sin6_port = 0; + std::copy_n(sa->via.bin.ptr, l, (char*)&a->sin6_addr); + addr.second = sizeof(sockaddr_in6); } + } else + addr.second = 0; + + if (auto rvalues = findMapValue(req, "values")) { + if (rvalues->type != msgpack::type::ARRAY) + throw msgpack::type_error(); + for (size_t i = 0; i < rvalues->via.array.size; i++) + try { + values.emplace_back(std::make_shared<Value>(rvalues->via.array.ptr[i])); + } catch (const std::exception& e) { + //DHT_WARN("Error reading value: %s", e.what()); + std::cout << "Error reading value: " << e.what() << std::endl; + } } - p = (uint8_t*)dht_memmem(buf, buflen, "1:eli", 5); - if (p) { - char *q; - unsigned long l = strtoul((char*)p + 5, &q, 10); - if (q && *q == 'e' && l < 0x10000) - error_code = l; + if (auto w = findMapValue(req, "w")) { + if (w->type != msgpack::type::ARRAY) + throw msgpack::type_error(); + want = 0; + for (unsigned i=0; i<w->via.array.size; i++) { + auto& val = w->via.array.ptr[i]; + try { + auto w = val.as<sa_family_t>(); + if (w == AF_INET) + want |= WANT4; + else if(w == AF_INET6) + want |= WANT6; + } catch (const std::exception& e) {}; + } } else { - error_code = 0; - } - -#undef CHECK - - ring = dht_memmem(buf, buflen, my_v, sizeof(my_v)); - - if (dht_memmem(buf, buflen, "1:y1:r", 6)) - return MessageType::Reply; - if (dht_memmem(buf, buflen, "1:y1:e", 6)) - return MessageType::Error; - if (!dht_memmem(buf, buflen, "1:y1:q", 6)) - throw DhtException("Parse error"); - if (dht_memmem(buf, buflen, "1:q4:ping", 9)) - return MessageType::Ping; - if (dht_memmem(buf, buflen, "1:q9:find_node", 14)) - return MessageType::FindNode; - if (dht_memmem(buf, buflen, "1:q9:get_peers", 14)) - return MessageType::GetValues; - if (dht_memmem(buf, buflen, "1:q13:announce_peer", 19)) - return MessageType::AnnounceValue; - if (dht_memmem(buf, buflen, "1:q6:listen", 11)) - return MessageType::Listen; - throw DhtException("Can't read message type."); + want = -1; + } + + if (auto t = findMapValue(msg, "t")) + tid = {t->as<std::array<char, 4>>()}; + + if (auto rv = findMapValue(msg, "v")) + ua = rv->as<std::string>(); + + if (r) + type = MessageType::Reply; + else if (e) + type = MessageType::Error; + else if (y and y->as<std::string>() != "q") + throw msgpack::type_error(); + else if (query == "ping") + type = MessageType::Ping; + else if (query == "find") + type = MessageType::FindNode; + else if (query == "get") + type = MessageType::GetValues; + else if (query == "listen") + type = MessageType::Listen; + else if (query == "put") + type = MessageType::AnnounceValue; + else + throw msgpack::type_error(); } #ifdef HAVE_MEMMEM diff --git a/src/dhtrunner.cpp b/src/dhtrunner.cpp index 39865a72576320dd91277ee0e9437be5e072d81c..779f6485db1da99bc1530dbbab81a0cea49a92fa 100644 --- a/src/dhtrunner.cpp +++ b/src/dhtrunner.cpp @@ -241,13 +241,12 @@ DhtRunner::doRun(const sockaddr_in* sin4, const sockaddr_in6* sin6, SecureDht::C if(rc > 0) { fromlen = sizeof(from); if(s4 >= 0 && FD_ISSET(s4, &readfds)) - rc = recvfrom(s4, (char*)buf, sizeof(buf) - 1, 0, (struct sockaddr*)&from, &fromlen); + rc = recvfrom(s4, (char*)buf, sizeof(buf), 0, (struct sockaddr*)&from, &fromlen); else if(s6 >= 0 && FD_ISSET(s6, &readfds)) - rc = recvfrom(s6, (char*)buf, sizeof(buf) - 1, 0, (struct sockaddr*)&from, &fromlen); + rc = recvfrom(s6, (char*)buf, sizeof(buf), 0, (struct sockaddr*)&from, &fromlen); else break; if (rc > 0) { - buf[rc] = 0; { std::lock_guard<std::mutex> lck(sock_mtx); rcv.emplace_back(Blob {buf, buf+rc+1}, std::make_pair(from, fromlen)); diff --git a/src/securedht.cpp b/src/securedht.cpp index 54a84de808b223194858e73303851a0c11dcb675..d2c2ed4cd05cb008710c78b764be8ef1f073d611 100644 --- a/src/securedht.cpp +++ b/src/securedht.cpp @@ -238,12 +238,8 @@ SecureDht::getCallbackFilter(GetCallback cb, Value::Filter&& filter) try { Value decrypted_val (decrypt(*v)); if (decrypted_val.recipient == getId()) { - if (decrypted_val.owner.checkSignature(decrypted_val.getToSign(), decrypted_val.signature)) { - if (not filter or filter(decrypted_val)) - tmpvals.push_back(std::make_shared<Value>(std::move(decrypted_val))); - } - else - DHT_WARN("Signature verification failed for %s", v->toString().c_str()); + if (not filter or filter(decrypted_val)) + tmpvals.push_back(std::make_shared<Value>(std::move(decrypted_val))); } // Ignore values belonging to other people } catch (const std::exception& e) { @@ -343,17 +339,16 @@ SecureDht::putEncrypted(const InfoHash& hash, const InfoHash& to, std::shared_pt void SecureDht::sign(Value& v) const { - if (v.flags.isEncrypted()) + if (v.isEncrypted()) throw DhtException("Can't sign encrypted data."); v.owner = key_->getPublicKey(); - v.flags = Value::ValueFlags(true, false, v.flags[2]); v.signature = key_->sign(v.getToSign()); } Value SecureDht::encrypt(Value& v, const crypto::PublicKey& to) const { - if (v.flags.isEncrypted()) + if (v.isEncrypted()) throw DhtException("Data is already encrypted."); v.setRecipient(to.getId()); sign(v); @@ -365,12 +360,18 @@ SecureDht::encrypt(Value& v, const crypto::PublicKey& to) const Value SecureDht::decrypt(const Value& v) { - if (not v.flags.isEncrypted()) + if (not v.isEncrypted()) throw DhtException("Data is not encrypted."); + auto decrypted = key_->decrypt(v.cypher); + Value ret {v.id}; - auto pb = decrypted.cbegin(), pe = decrypted.cend(); - ret.unpackBody(pb, pe); + auto msg = msgpack::unpack((const char*)decrypted.data(), decrypted.size()); + ret.msgpack_unpack_body(msg.get()); + + if (not ret.owner.checkSignature(ret.getToSign(), ret.signature)) + throw crypto::DecryptError("Signature mismatch"); + return ret; } diff --git a/src/value.cpp b/src/value.cpp index dc6c50f07a5a66d0328d846d58b650c5ed7f0b0d..ac79dbac78526e0f614fa3366d13acaa03fe9744 100644 --- a/src/value.cpp +++ b/src/value.cpp @@ -38,11 +38,17 @@ namespace dht { std::ostream& operator<< (std::ostream& s, const Value& v) { s << "Value[id:" << std::hex << v.id << std::dec << " "; - if (v.flags.isSigned()) + if (v.isSigned()) s << "signed (v" << v.seq << ") "; - if (v.flags.isEncrypted()) + if (v.isEncrypted()) s << "encrypted "; - else { + else if (v.isSigned()) { + if (v.recipient == InfoHash()) + s << "signed (v" << v.seq << ") "; + else + s << "decrypted "; + } + if (not v.isEncrypted()) { if (v.type == IpServiceAnnouncement::TYPE.id) { s << IpServiceAnnouncement(v.data); } else if (v.type == CERTIFICATE_TYPE.id) { @@ -57,7 +63,7 @@ std::ostream& operator<< (std::ostream& s, const Value& v) s << "Data (type: " << v.type << " ): "; s << std::hex; for (size_t i=0; i<v.data.size(); i++) - s << std::setfill('0') << std::setw(2) << (unsigned)v.data[i]; + s << std::setfill('0') << std::setw(2) << (unsigned)v.data[i] << " "; s << std::dec; } } @@ -68,62 +74,45 @@ std::ostream& operator<< (std::ostream& s, const Value& v) const ValueType ValueType::USER_DATA = {0, "User Data"}; -void -Value::packToSign(Blob& res) const -{ - res.push_back(flags.to_ulong()); - if (flags.isEncrypted()) { - res.insert(res.end(), cypher.begin(), cypher.end()); - } else { - if (flags.isSigned()) { - serialize<decltype(seq)>(seq, res); - owner.pack(res); - if (flags.haveRecipient()) - res.insert(res.end(), recipient.begin(), recipient.end()); - } - serialize<ValueType::Id>(type, res); - serialize<Blob>(data, res); - } -} - -Blob -Value::getToSign() const -{ - Blob ret; - packToSign(ret); - return ret; -} - -/** - * Pack part of the data to be encrypted - */ -void -Value::packToEncrypt(Blob& res) const -{ - packToSign(res); - if (!flags.isEncrypted() && flags.isSigned()) - serialize<Blob>(signature, res); +msgpack::unpacked +unpack(Blob b) { + return msgpack::unpack((const char*)b.data(), b.size()); } -Blob -Value::getToEncrypt() const -{ - Blob ret; - packToEncrypt(ret); - return ret; +msgpack::object* +findMapValue(const msgpack::object& map, const std::string& key) { + if (map.type != msgpack::type::MAP) throw msgpack::type_error(); + for (unsigned i = 0; i < map.via.map.size; i++) { + auto& o = map.via.map.ptr[i]; + if(o.key.type != msgpack::type::STR) + continue; + if (o.key.as<std::string>() == key) { + return &o.val; + } + } + return nullptr; } void -Value::pack(Blob& res) const +Value::msgpack_unpack(msgpack::object o) { - serialize<Id>(id, res); - packToEncrypt(res); + if (o.type != msgpack::type::MAP) throw msgpack::type_error(); + if (o.via.map.size < 2) throw msgpack::type_error(); + + if (auto rid = findMapValue(o, "id")) { + id = rid->as<Id>(); + } else + throw msgpack::type_error(); + + if (auto rdat = findMapValue(o, "dat")) { + msgpack_unpack_body(*rdat); + } else + throw msgpack::type_error(); } void -Value::unpackBody(Blob::const_iterator& begin, Blob::const_iterator& end) +Value::msgpack_unpack_body(const msgpack::object& o) { - // clear optional fields owner = {}; recipient = {}; cypher.clear(); @@ -131,39 +120,48 @@ Value::unpackBody(Blob::const_iterator& begin, Blob::const_iterator& end) data.clear(); type = 0; - flags = {deserialize<uint8_t>(begin, end)}; - if (flags.isEncrypted()) { - cypher = {begin, end}; - begin = end; + if (o.type == msgpack::type::BIN) { + auto dat = o.as<std::vector<char>>(); + cypher = {dat.begin(), dat.end()}; } else { - if(flags.isSigned()) { - seq = deserialize<decltype(seq)>(begin, end); - owner.unpack(begin, end); - if (flags.haveRecipient()) - recipient = deserialize<InfoHash>(begin, end); + if (o.type != msgpack::type::MAP) + throw msgpack::type_error(); + auto rbody = findMapValue(o, "body"); + if (not rbody) + throw msgpack::type_error(); + + if (auto rdata = findMapValue(*rbody, "data")) { + auto dat = rdata->as<std::vector<char>>(); + data = {dat.begin(), dat.end()}; + } else + throw msgpack::type_error(); + + if (auto rtype = findMapValue(*rbody, "type")) { + type = rtype->as<ValueType::Id>(); + } else + throw msgpack::type_error(); + + if (auto rutype = findMapValue(*rbody, "utype")) { + user_type = rutype->as<std::string>(); } - type = deserialize<ValueType::Id>(begin, end); - data = deserialize<Blob>(begin, end); - if (flags.isSigned()) - signature = deserialize<Blob>(begin, end); - } -} - -void -Value::unpack(Blob::const_iterator& begin, Blob::const_iterator& end) -{ - id = deserialize<Id>(begin, end); - unpackBody(begin, end); -} -void -ValueSerializable::unpackValue(const Value& v) { - unpackBlob(v.data); -} + if (auto rowner = findMapValue(*rbody, "owner")) { + if (auto rseq = findMapValue(*rbody, "seq")) + seq = rseq->as<decltype(seq)>(); + else + throw msgpack::type_error(); + owner.msgpack_unpack(*rowner); + if (auto rrecipient = findMapValue(*rbody, "to")) { + recipient = rrecipient->as<InfoHash>(); + } -Value -ValueSerializable::packValue() const { - return Value {getType(), *this}; + if (auto rsig = findMapValue(o, "sig")) { + auto dat = rsig->as<std::vector<char>>(); + signature = {dat.begin(), dat.end()}; + } else + throw msgpack::type_error(); + } + } } }