diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..463ce66f1b9994221757ecf621cff8d05b224ee2 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,83 @@ +cmake_minimum_required (VERSION 2.8.6) +project (opendht) + +list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") + +set (prefix ${CMAKE_INSTALL_PREFIX}) +set (exec_prefix "\${prefix}") +set (libdir "\${exec_prefix}/lib") +set (includedir "\${prefix}/include") + +option (OPENDHT_PYTHON "Build Python bindings" OFF) +option (OPENDHT_TOOLS "Build DHT tools" ON) +option (OPENDHT_DEBUG "Build with debug flags" OFF) + +set (CMAKE_CXX_FLAGS "-std=c++11 -Wno-return-type -Wall -Wextra -Wnon-virtual-dtor ${CMAKE_CXX_FLAGS}") + +find_package (GnuTLS 3.1 REQUIRED) +find_package (Msgpack 1.1 REQUIRED) + +list (APPEND opendht_SOURCES + src/infohash.cpp + src/crypto.cpp + src/default_types.cpp + src/value.cpp + src/dht.cpp + src/securedht.cpp + src/dhtrunner.cpp +) + +list (APPEND opendht_HEADERS + include/opendht/rng.h + include/opendht/crypto.h + include/opendht/infohash.h + include/opendht/default_types.h + include/opendht/value.h + include/opendht/dht.h + include/opendht/securedht.h + include/opendht.h +) + +configure_file ( + opendht.pc.in + opendht.pc + @ONLY +) + +include_directories ( + ./ + include/ + include/opendht/ + ${CMAKE_CURRENT_BINARY_DIR}/include/ +) + +if (OPENDHT_DEBUG) + set(CMAKE_BUILD_TYPE Debug) +else () + set(CMAKE_BUILD_TYPE Release) +endif () + +add_library (opendht SHARED + ${opendht_SOURCES} + ${opendht_HEADERS} +) +set_target_properties (opendht PROPERTIES IMPORT_SUFFIX "_import.lib") +#set_target_properties (opendht PROPERTIES SOVERSION 1 VERSION 1.0.0) + +add_library (opendht-static STATIC + ${opendht_SOURCES} + ${opendht_HEADERS} +) +set_target_properties (opendht-static PROPERTIES OUTPUT_NAME "opendht") + +if (NOT DEFINED CMAKE_INSTALL_LIBDIR) + set(CMAKE_INSTALL_LIBDIR lib) +endif () + +if (OPENDHT_TOOLS) + add_subdirectory(tools) +endif () + +install (TARGETS opendht opendht-static DESTINATION ${CMAKE_INSTALL_LIBDIR}) +install (DIRECTORY include DESTINATION ${CMAKE_INSTALL_PREFIX}) +install (FILES ${CMAKE_CURRENT_BINARY_DIR}/opendht.pc DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig) diff --git a/cmake/FindMsgpack.cmake b/cmake/FindMsgpack.cmake new file mode 100644 index 0000000000000000000000000000000000000000..7d8813791568c0a79df2ddb339ad3e9db9975c9f --- /dev/null +++ b/cmake/FindMsgpack.cmake @@ -0,0 +1,48 @@ +# - Try to find msgpack +# Once done this will define +# MSGPACK_FOUND - System has msgpack +# MSGPACK_INCLUDE_DIRS - The msgpack include directories +# MSGPACK_LIBRARIES - The libraries needed to use msgpack + +if(NOT MSGPACK_USE_BUNDLED) + find_package(PkgConfig) + if (PKG_CONFIG_FOUND) + pkg_check_modules(PC_MSGPACK QUIET msgpack) + endif() +else() + set(PC_MSGPACK_INCLUDEDIR) + set(PC_MSGPACK_INCLUDE_DIRS) + set(PC_MSGPACK_LIBDIR) + set(PC_MSGPACK_LIBRARY_DIRS) + set(LIMIT_SEARCH NO_DEFAULT_PATH) +endif() + +set(MSGPACK_DEFINITIONS ${PC_MSGPACK_CFLAGS_OTHER}) + +find_path(MSGPACK_INCLUDE_DIR msgpack.h + HINTS ${PC_MSGPACK_INCLUDEDIR} ${PC_MSGPACK_INCLUDE_DIRS} + ${LIMIT_SEARCH}) + +# If we're asked to use static linkage, add libmsgpack.a as a preferred library name. +if(MSGPACK_USE_STATIC) + list(APPEND MSGPACK_NAMES + "${CMAKE_STATIC_LIBRARY_PREFIX}msgpack${CMAKE_STATIC_LIBRARY_SUFFIX}") +endif() + +list(APPEND MSGPACK_NAMES msgpack) + +find_library(MSGPACK_LIBRARY NAMES ${MSGPACK_NAMES} + HINTS ${PC_MSGPACK_LIBDIR} ${PC_MSGPACK_LIBRARY_DIRS} + ${LIMIT_SEARCH}) + +mark_as_advanced(MSGPACK_INCLUDE_DIR MSGPACK_LIBRARY) + +set(MSGPACK_LIBRARIES ${MSGPACK_LIBRARY}) +set(MSGPACK_INCLUDE_DIRS ${MSGPACK_INCLUDE_DIR}) + +include(FindPackageHandleStandardArgs) +# handle the QUIETLY and REQUIRED arguments and set MSGPACK_FOUND to TRUE +# if all listed variables are TRUE +find_package_handle_standard_args(Msgpack DEFAULT_MSG + MSGPACK_LIBRARY MSGPACK_INCLUDE_DIR) + diff --git a/configure.ac b/configure.ac index 3e156ce5539f738fba15368fa5a6e03a501fcd25..c4a3f058128950b5ff6179a3d998ac0a4568380c 100644 --- a/configure.ac +++ b/configure.ac @@ -72,7 +72,9 @@ LT_LANG(C++) AX_CXX_COMPILE_STDCXX_11([noext],[mandatory]) PKG_PROG_PKG_CONFIG() +PKG_CHECK_MODULES([nettle], [nettle >= 2.4]) PKG_CHECK_MODULES([GNUTLS], [gnutls >= 3.1]) +PKG_CHECK_MODULES([msgpack], [msgpack >= 1.1]) AC_ARG_ENABLE([tools], AS_HELP_STRING([--disable-tools],[Disable tools (CLI DHT node)]),,build_tools=yes) AM_CONDITIONAL(ENABLE_TOOLS, test x$build_tools == xyes) diff --git a/include/opendht/crypto.h b/include/opendht/crypto.h index 3df04c9bbfaf504bdecf7e3b545d0697e9f7d580..e1daf2560d978a6b091b38cb875e7af487d82df0 100644 --- a/include/opendht/crypto.h +++ b/include/opendht/crypto.h @@ -31,7 +31,6 @@ #pragma once #include "infohash.h" -#include "serialize.h" extern "C" { #include <gnutls/gnutls.h> @@ -42,6 +41,8 @@ extern "C" { #include <vector> #include <memory> +typedef std::vector<uint8_t> Blob; + namespace dht { namespace crypto { @@ -75,7 +76,7 @@ Identity generateIdentity(const std::string& name = "dhtnode", Identity ca = {}, /** * A public key. */ -struct PublicKey : public Serializable +struct PublicKey { PublicKey() {} PublicKey(gnutls_pubkey_t k) : pk(k) {} @@ -91,14 +92,25 @@ struct PublicKey : public Serializable bool checkSignature(const Blob& data, const Blob& signature) const; Blob encrypt(const Blob&) const; - void pack(Blob& b) const override; + void pack(Blob& b) const; + void unpack(const uint8_t* dat, size_t dat_size); - void unpack(Blob::const_iterator& begin, Blob::const_iterator& end) override; + template <typename Packer> + void msgpack_pack(Packer& p) const + { + Blob b; + pack(b); + p.pack_bin(b.size()); + p.pack_bin_body((const char*)b.data(), b.size()); + } + + void msgpack_unpack(msgpack::object o); gnutls_pubkey_t pk {}; private: PublicKey(const PublicKey&) = delete; PublicKey& operator=(const PublicKey&) = delete; + void encryptBloc(const uint8_t* src, size_t src_size, uint8_t* dst, size_t dst_size) const; }; /** @@ -134,6 +146,7 @@ struct PrivateKey /** * Generate a new RSA key pair * @param key_length : size of the modulus in bits + * Minimim value: 2048 * Recommended values: 4096, 8192 */ static PrivateKey generate(unsigned key_length = 4096); @@ -143,11 +156,12 @@ struct PrivateKey private: PrivateKey(const PrivateKey&) = delete; PrivateKey& operator=(const PrivateKey&) = delete; + Blob decryptBloc(const uint8_t* src, size_t src_size) const; friend dht::crypto::Identity dht::crypto::generateIdentity(const std::string&, dht::crypto::Identity, unsigned key_length); }; -struct Certificate : public Serializable { +struct Certificate { Certificate() {} /** @@ -155,11 +169,16 @@ struct Certificate : public Serializable { */ Certificate(gnutls_x509_crt_t crt) : cert(crt) {} + Certificate(Certificate&& o) noexcept : cert(o.cert), issuer(std::move(o.issuer)) { o.cert = nullptr; }; + /** * Import certificate (PEM or DER) or certificate chain (PEM), * ordered from subject to issuer */ Certificate(const Blob& crt); + Certificate(const uint8_t* dat, size_t dat_size) { + unpack(dat, dat_size); + } /** * Import certificate chain (PEM or DER), @@ -179,12 +198,16 @@ struct Certificate : public Serializable { unpack(certs); } - Certificate(Certificate&& o) noexcept : cert(o.cert), issuer(std::move(o.issuer)) { o.cert = nullptr; }; Certificate& operator=(Certificate&& o) noexcept; ~Certificate(); - void pack(Blob& b) const override; - void unpack(Blob::const_iterator& begin, Blob::const_iterator& end) override; + void pack(Blob& b) const; + void unpack(const uint8_t* dat, size_t dat_size); + Blob getPacked() const { + Blob b; + pack(b); + return b; + } template<typename Iterator> void unpack(const Iterator& begin, const Iterator& end) @@ -227,6 +250,17 @@ struct Certificate : public Serializable { *this = tmp_issuer ? std::move(*tmp_issuer) : Certificate(); } + template <typename Packer> + void msgpack_pack(Packer& p) const + { + Blob b; + pack(b); + p.pack_bin(b.size()); + p.pack_bin_body((const char*)b.data(), b.size()); + } + + void msgpack_unpack(msgpack::object o); + operator bool() const { return cert; } PublicKey getPublicKey() const; @@ -270,6 +304,15 @@ private: friend dht::crypto::Identity dht::crypto::generateIdentity(const std::string&, dht::crypto::Identity, unsigned key_length); }; +/** + * AES-GCM encryption. Key must be 128, 192 or 126 bits long (16, 24 or 32 bytes). + */ +Blob aesEncrypt(const Blob& data, const Blob& key); + +/** + * AES-GCM decryption. + */ +Blob aesDecrypt(const Blob& data, const Blob& key); } } diff --git a/include/opendht/default_types.h b/include/opendht/default_types.h index bfffd58f541c43d85588a310fb58c5cb4ef03148..99fd70bebcc60e602d8360c0224e1b0bf0791805 100644 --- a/include/opendht/default_types.h +++ b/include/opendht/default_types.h @@ -34,17 +34,14 @@ namespace dht { -struct DhtMessage : public ValueSerializable +struct DhtMessage : public Value::Serializable<DhtMessage> { DhtMessage(std::string s = {}, Blob msg = {}) : service(s), data(msg) {} - + std::string getService() const { return service; } - virtual void pack(Blob& res) const; - virtual void unpack(Blob::const_iterator& begin, Blob::const_iterator& end); - static const ValueType TYPE; virtual const ValueType& getType() const { return TYPE; @@ -61,14 +58,15 @@ struct DhtMessage : public ValueSerializable public: std::string service; Blob data; + MSGPACK_DEFINE(service, data); }; - -struct SignedValue : public ValueSerializable +template <typename Type> +struct SignedValue : public Value::Serializable<Type> { virtual void unpackValue(const Value& v) { from = v.owner.getId(); - ValueSerializable::unpackValue(v); + Value::Serializable<Type>::unpackValue(v); } static Value::Filter getFilter() { return [](const Value& v){ return v.isSigned(); }; @@ -77,15 +75,16 @@ public: dht::InfoHash from; }; -struct EncryptedValue : public SignedValue +template <typename Type> +struct EncryptedValue : public SignedValue<Type> { virtual void unpackValue(const Value& v) { to = v.recipient; - SignedValue::unpackValue(v); + SignedValue<Type>::unpackValue(v); } static Value::Filter getFilter() { return Value::Filter::chain( - SignedValue::getFilter(), + SignedValue<Type>::getFilter(), [](const Value& v){ return v.recipient != InfoHash(); } ); } @@ -94,7 +93,7 @@ public: dht::InfoHash to; }; -struct ImMessage : public SignedValue +struct ImMessage : public SignedValue<ImMessage> { ImMessage() {} ImMessage(std::string&& msg) @@ -107,14 +106,6 @@ struct ImMessage : public SignedValue static Value::Filter getFilter() { return SignedValue::getFilter(); } - virtual void pack(Blob& data) const { - serialize<std::chrono::system_clock::time_point>(std::chrono::system_clock::now(), data); - data.insert(data.end(), im_message.begin(), im_message.end()); - } - virtual void unpack(Blob::const_iterator& b, Blob::const_iterator& e) { - sent = deserialize<decltype(sent)>(b, e); - im_message = std::string(b, e); - } virtual void unpackValue(const Value& v) { to = v.recipient; SignedValue::unpackValue(v); @@ -123,9 +114,10 @@ struct ImMessage : public SignedValue dht::InfoHash to; std::chrono::system_clock::time_point sent; std::string im_message; + MSGPACK_DEFINE(im_message); }; -struct TrustRequest : public EncryptedValue +struct TrustRequest : public EncryptedValue<TrustRequest> { TrustRequest() {} TrustRequest(std::string s) : service(s) {} @@ -138,22 +130,16 @@ struct TrustRequest : public EncryptedValue static Value::Filter getFilter() { return EncryptedValue::getFilter(); } - virtual void pack(Blob& data) const { - serialize<std::string>(service, data); - serialize<Blob>(payload, data); - } - virtual void unpack(Blob::const_iterator& b, Blob::const_iterator& e) { - service = deserialize<std::string>(b, e); - payload = deserialize<Blob>(b, e); - } + std::string service; Blob payload; + MSGPACK_DEFINE(service, payload); }; -struct IceCandidates : public EncryptedValue +struct IceCandidates : public EncryptedValue<IceCandidates> { IceCandidates() {} - IceCandidates(Blob ice) : ice_data(ice) {} + IceCandidates(Value::Id msg_id, Blob ice) : id(msg_id), ice_data(ice) {} static const ValueType TYPE; virtual const ValueType& getType() const { @@ -162,25 +148,16 @@ struct IceCandidates : public EncryptedValue static Value::Filter getFilter() { return EncryptedValue::getFilter(); } - virtual void pack(Blob& data) const { - serialize<Blob>(ice_data, data); - } - virtual void unpack(Blob::const_iterator& b, Blob::const_iterator& e) { - ice_data = deserialize<Blob>(b, e); - } - virtual void unpackValue(const Value& v) { - EncryptedValue::unpackValue(v); - id = v.id; - } Value::Id id; Blob ice_data; + MSGPACK_DEFINE(id, ice_data); }; /* "Peer" announcement */ -struct IpServiceAnnouncement : public ValueSerializable +struct IpServiceAnnouncement : public Value::Serializable<IpServiceAnnouncement> { IpServiceAnnouncement(in_port_t p = 0) { ss.ss_family = 0; @@ -193,12 +170,37 @@ struct IpServiceAnnouncement : public ValueSerializable } IpServiceAnnouncement(const Blob& b) { - unpackBlob(b); + msgpack_unpack(unpackMsg(b).get()); + } + + template <typename Packer> + void msgpack_pack(Packer& pk) const + { + pk.pack_array(2); + pk.pack(getPort()); + if (ss.ss_family == AF_INET) { + pk.pack_bin(sizeof(in_addr)); + pk.pack_bin_body((const char*)&reinterpret_cast<const sockaddr_in*>(&ss)->sin_addr, sizeof(in_addr)); + } else if (ss.ss_family == AF_INET6) { + pk.pack_bin(sizeof(in6_addr)); + pk.pack_bin_body((const char*)&reinterpret_cast<const sockaddr_in6*>(&ss)->sin6_addr, sizeof(in6_addr)); + } + } + + void msgpack_unpack(msgpack::object o) + { + if (o.type != msgpack::type::ARRAY) throw msgpack::type_error(); + if (o.via.array.size < 2) throw msgpack::type_error(); + setPort(o.via.array.ptr[0].as<in_port_t>()); + auto ip_dat = o.via.array.ptr[1].as<Blob>(); + if (ip_dat.size() == sizeof(in_addr)) + std::copy(ip_dat.begin(), ip_dat.end(), (char*)&reinterpret_cast<sockaddr_in*>(&ss)->sin_addr); + else if (ip_dat.size() == sizeof(in6_addr)) + std::copy(ip_dat.begin(), ip_dat.end(), (char*)&reinterpret_cast<sockaddr_in6*>(&ss)->sin6_addr); + else + throw msgpack::type_error(); } - virtual void pack(Blob& res) const; - virtual void unpack(Blob::const_iterator& begin, Blob::const_iterator& end); - in_port_t getPort() const { return ntohs(reinterpret_cast<const sockaddr_in*>(&ss)->sin_port); } diff --git a/include/opendht/dht.h b/include/opendht/dht.h index 2194d8e1f47600e16d0f19cae74ce0a7f77c3330..e062f8efe13f0c536694f6593dc03ffa3d9cc70b 100644 --- a/include/opendht/dht.h +++ b/include/opendht/dht.h @@ -138,12 +138,14 @@ public: static GetCallbackSimple bindGetCb(GetCallbackRaw raw_cb, void* user_data) { + if (not raw_cb) return {}; return [=](const std::shared_ptr<Value>& value) { return raw_cb(value, user_data); }; } static GetCallback bindGetCb(GetCallbackSimple cb) { + if (not cb) return {}; return [=](const std::vector<std::shared_ptr<Value>>& values) { for (const auto& v : values) if (not cb(v)) @@ -159,11 +161,13 @@ public: static DoneCallback bindDoneCb(DoneCallbackSimple donecb) { + if (not donecb) return {}; using namespace std::placeholders; return std::bind(donecb, _1); } static DoneCallback bindDoneCb(DoneCallbackRaw raw_cb, void* user_data) { + if (not raw_cb) return {}; return [=](bool success, const std::vector<std::shared_ptr<Node>>& nodes) { raw_cb(success, (std::vector<std::shared_ptr<Node>>*)&nodes, user_data); }; @@ -266,7 +270,7 @@ public: * reannounced on a regular basis. * User can call #cancelPut(InfoHash, Value::Id) to cancel a put operation. */ - void put(const InfoHash& key, const std::shared_ptr<Value>&, DoneCallback cb=nullptr); + void put(const InfoHash& key, std::shared_ptr<Value>, DoneCallback cb=nullptr); void put(const InfoHash& key, const std::shared_ptr<Value>& v, DoneCallbackSimple cb) { put(key, v, bindDoneCb(cb)); } @@ -359,7 +363,7 @@ private: static constexpr unsigned MAX_HASHES {16384}; /* The maximum number of searches we keep data about. */ - static constexpr unsigned MAX_SEARCHES {1024}; + static constexpr unsigned MAX_SEARCHES {128}; /* The time after which we can send get requests for a search in case of no answers. */ @@ -383,6 +387,7 @@ private: static constexpr unsigned TOKEN_SIZE {64}; + static const std::string my_v; struct NodeCache { std::shared_ptr<Node> getNode(const InfoHash& id, sa_family_t family); @@ -673,6 +678,7 @@ private: */ struct TransId final : public std::array<uint8_t, 4> { TransId() {} + TransId(const std::array<char, 4>& o) { std::copy(o.begin(), o.end(), begin()); } TransId(const TransPrefix prefix, uint16_t seqno = 0) { std::copy_n(prefix.begin(), prefix.size(), begin()); *reinterpret_cast<uint16_t*>(data()+prefix.size()) = seqno; @@ -688,7 +694,7 @@ private: } bool matches(const TransPrefix prefix, uint16_t *seqno_return = nullptr) const { - if (std::equal(begin(), begin()+1, prefix.begin())) { + if (std::equal(begin(), begin()+2, prefix.begin())) { if (seqno_return) *seqno_return = *reinterpret_cast<const uint16_t*>(&(*this)[2]); return true; @@ -708,7 +714,7 @@ private: int dht_socket6 {-1}; InfoHash myid {}; - static const uint8_t my_v[9]; + std::array<uint8_t, 8> secret {{}}; std::array<uint8_t, 8> oldsecret {{}}; @@ -785,16 +791,24 @@ private: int sendError(const sockaddr*, socklen_t, TransId tid, uint16_t code, const char *message, bool include_id=false); void processMessage(const uint8_t *buf, size_t buflen, const sockaddr *from, socklen_t fromlen); - MessageType parseMessage(const uint8_t *buf, size_t buflen, - TransId& tid, - InfoHash& id_return, InfoHash& info_hash_return, - InfoHash& target_return, in_port_t& port_return, - Blob& token, Value::Id& value_id, - uint8_t *nodes_return, unsigned *nodes_len, - uint8_t *nodes6_return, unsigned *nodes6_len, - std::vector<std::shared_ptr<Value>>& values_return, - want_t* want_return, uint16_t& error_code, bool& ring, - sockaddr* addr_return, socklen_t& addr_length_return); + + struct ParsedMessage { + MessageType type; + InfoHash id; + InfoHash info_hash; + InfoHash target; + TransId tid; + Blob token; + Value::Id value_id; + Blob nodes4; + Blob nodes6; + std::vector<std::shared_ptr<Value>> values; + want_t want; + uint16_t error_code; + std::string ua; + Address addr; + void msgpack_unpack(msgpack::object o); + }; void rotateSecrets(); @@ -859,7 +873,7 @@ private: * The values can be filtered by an arbitrary provided filter. */ Search* search(const InfoHash& id, sa_family_t af, GetCallback = nullptr, DoneCallback = nullptr, Value::Filter = Value::AllFilter()); - void announce(const InfoHash& id, sa_family_t af, const std::shared_ptr<Value>& value, DoneCallback callback); + void announce(const InfoHash& id, sa_family_t af, std::shared_ptr<Value> value, DoneCallback callback); size_t listenTo(const InfoHash& id, sa_family_t af, GetCallback cb, Value::Filter f = Value::AllFilter()); std::list<Search>::iterator newSearch(); diff --git a/include/opendht/dhtrunner.h b/include/opendht/dhtrunner.h index eab88bbf8a8968847de3935f91d32e387cc83da7..ce274cf18930f13a3c1870d31ac8b0a166ab3f9c 100644 --- a/include/opendht/dhtrunner.h +++ b/include/opendht/dhtrunner.h @@ -73,50 +73,49 @@ public: void get(const std::string& key, Dht::GetCallback vcb, Dht::DoneCallbackSimple dcb={}, Value::Filter f = Value::AllFilter()); template <class T> - void get(InfoHash hash, std::function<bool(std::vector<T>&&)> cb) + void get(InfoHash hash, std::function<bool(std::vector<T>&&)> cb, Dht::DoneCallbackSimple dcb={}) { get(hash, [=](const std::vector<std::shared_ptr<Value>>& vals) { return cb(unpackVector<T>(vals)); }, - T::getFilter()); + dcb, + getFilterSet<T>()); } template <class T> - void get(InfoHash hash, std::function<bool(T&&)> cb) + void get(InfoHash hash, std::function<bool(T&&)> cb, Dht::DoneCallbackSimple dcb={}) { get(hash, [=](const std::vector<std::shared_ptr<Value>>& vals) { for (const auto& v : vals) { - T msg; try { - msg.unpackValue(*v); + if (not cb(Value::unpack<T>(*v))) + return false; } catch (const std::exception&) { continue; } - if (not cb(std::move(msg))) - return false; } return true; }, - T::getFilter()); + dcb, + getFilterSet<T>()); } std::future<std::vector<std::shared_ptr<dht::Value>>> get(InfoHash key, Value::Filter f = Value::AllFilter()) { auto p = std::make_shared<std::promise<std::vector<std::shared_ptr< dht::Value >>>>(); auto values = std::make_shared<std::vector<std::shared_ptr< dht::Value >>>(); - get(key, [=](const std::vector<std::shared_ptr<dht::Value>>& vlist) { values->insert(values->end(), vlist.begin(), vlist.end()); return true; }, [=](bool) { p->set_value(std::move(*values)); - }, f); + }, + f); return p->get_future(); } template <class T> std::future<std::vector<T>> get(InfoHash key) { - auto p = std::make_shared<std::promise<std::vector<std::shared_ptr<dht::Value>>>>(); + auto p = std::make_shared<std::promise<std::vector<T>>>(); auto values = std::make_shared<std::vector<T>>(); - get<T>(key, [=](T&& v) { values->emplace_back(std::move(v)); return true; @@ -138,29 +137,23 @@ public: return listen(hash, [=](const std::vector<std::shared_ptr<Value>>& vals) { return cb(unpackVector<T>(vals)); }, - T::getFilter()); + getFilterSet<T>()); } - template <class T> + template <typename T> std::future<size_t> listen(InfoHash hash, std::function<bool(T&&)> cb, Value::Filter f = Value::AllFilter()) { return listen(hash, [=](const std::vector<std::shared_ptr<Value>>& vals) { for (const auto& v : vals) { - T msg; try { - msg.unpackValue(*v); + if (not cb(Value::unpack<T>(*v))) + return false; } catch (const std::exception&) { continue; } - if (not cb(std::move(msg))) - return false; } return true; }, - Value::Filter::chain({ - Value::TypeFilter(T::TYPE), - T::getFilter(), - f - })); + getFilterSet<T>(f)); } void cancelListen(InfoHash h, size_t token); @@ -171,12 +164,12 @@ public: void put(InfoHash hash, Value&& value, Dht::DoneCallbackSimple cb) { put(hash, std::forward<Value>(value), Dht::bindDoneCb(cb)); } - void put(const std::string& key, Value&& value, Dht::DoneCallback cb=nullptr); + void put(const std::string& key, Value&& value, Dht::DoneCallbackSimple cb=nullptr); void cancelPut(const InfoHash& h, const Value::Id& id); void putSigned(InfoHash hash, Value&& value, Dht::DoneCallback cb=nullptr); - void putSigned(const std::string& key, Value&& value, Dht::DoneCallback cb=nullptr); + void putSigned(const std::string& key, Value&& value, Dht::DoneCallbackSimple cb=nullptr); void putSigned(InfoHash hash, Value&& value, Dht::DoneCallbackSimple cb) { putSigned(hash, std::forward<Value>(value), Dht::bindDoneCb(cb)); } @@ -372,6 +365,10 @@ private: static std::vector<std::pair<sockaddr_storage, socklen_t>> getAddrInfo(const char* host, const char* service); + Dht::Status getStatus() const { + return std::max(status4, status6); + } + std::unique_ptr<SecureDht> dht_ {}; mutable std::mutex dht_mtx {}; std::thread dht_thread {}; @@ -381,6 +378,7 @@ private: std::mutex sock_mtx {}; std::vector<std::pair<Blob, std::pair<sockaddr_storage, socklen_t>>> rcv {}; + std::queue<std::function<void(SecureDht&)>> pending_ops_prio {}; std::queue<std::function<void(SecureDht&)>> pending_ops {}; std::mutex storage_mtx {}; diff --git a/include/opendht/infohash.h b/include/opendht/infohash.h index acc7a33552aaf0d051cb2026c62ff13d17bea755..2dadb14b1b5094214075fe0b47796bbcedd266f8 100644 --- a/include/opendht/infohash.h +++ b/include/opendht/infohash.h @@ -30,6 +30,8 @@ #pragma once +#include <msgpack.hpp> + #include <iostream> #include <iomanip> #include <array> @@ -73,6 +75,10 @@ public: */ explicit InfoHash(const std::string& hex); + InfoHash(const msgpack::object& o) { + msgpack_unpack(o); + } + /** * Find the lowest 1 bit in an id. * Result will allways be lower than 8*HASH_LEN @@ -193,6 +199,20 @@ public: friend std::ostream& operator<< (std::ostream& s, const InfoHash& h); std::string toString() const; + + template <typename Packer> + void msgpack_pack(Packer& pk) const + { + pk.pack_bin(HASH_LEN); + pk.pack_bin_body((char*)data(), HASH_LEN); + } + + void msgpack_unpack(msgpack::object o) { + if (o.type != msgpack::type::BIN or o.via.bin.size != HASH_LEN) + throw msgpack::type_error(); + std::copy_n(o.via.bin.ptr, HASH_LEN, data()); + } + }; } diff --git a/include/opendht/securedht.h b/include/opendht/securedht.h index 898fab2056cd5d9df2db5c5f146fbe1b4fe47b11..20e48666725c25ec66e9aabf28127ffe2f897547 100644 --- a/include/opendht/securedht.h +++ b/include/opendht/securedht.h @@ -99,7 +99,7 @@ public: /** * Will take ownership of the value, sign it using our private key and put it in the DHT. */ - void putSigned(const InfoHash& hash, const std::shared_ptr<Value>& val, DoneCallback callback); + void putSigned(const InfoHash& hash, std::shared_ptr<Value> val, DoneCallback callback); void putSigned(const InfoHash& hash, Value&& v, DoneCallback callback) { putSigned(hash, std::make_shared<Value>(std::move(v)), callback); } diff --git a/include/opendht/serialize.h b/include/opendht/serialize.h deleted file mode 100644 index 3e81f5b7cc25806e02cc65b6aa6e6d2911d950b1..0000000000000000000000000000000000000000 --- a/include/opendht/serialize.h +++ /dev/null @@ -1,339 +0,0 @@ -/** - * Copyright (c) 2013, Simone Pellegrini All rights reserved. - * Copyright (c) 2014 Savoir-Faire Linux. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * - Redistributions of source code must retain the above copyright notice, - * this list of conditions and the following disclaimer. - * - * - Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE - * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - * POSSIBILITY OF SUCH DAMAGE. - */ - -#pragma once - -#include <vector> -#include <string> -#include <tuple> -#include <numeric> -#include <limits> -#include <chrono> - -typedef std::vector<uint8_t> Blob; - -template <class T> -inline void serialize(const T&, Blob&); - -namespace detail { - - template<std::size_t> struct int_{}; - -} - -// get_size -template <class T> -size_t get_size(const T& obj); - -namespace detail { - - typedef uint16_t serialized_size_t; - - template <class T> - struct get_size_helper; - - template <class T> - struct get_size_helper<std::vector<T>> { - static size_t value(const std::vector<T>& obj) { - return std::accumulate(obj.begin(), obj.end(), sizeof(serialized_size_t), - [](const size_t& acc, const T& cur) { return acc+get_size(cur); }); - } - }; - - template <> - struct get_size_helper<std::string> { - static size_t value(const std::string& obj) { - return sizeof(serialized_size_t) + obj.length()*sizeof(uint8_t); - } - }; - - template <class T> - struct get_size_helper<std::chrono::time_point<T>> { - static constexpr size_t value(const std::chrono::time_point<T>&) { - return sizeof(typename std::chrono::time_point<T>::rep); - } - }; - - template <class tuple_type> - inline size_t get_tuple_size(const tuple_type& obj, int_<0>) { - constexpr size_t idx = std::tuple_size<tuple_type>::value-1; - return get_size(std::get<idx>(obj)); - } - - template <class tuple_type, size_t pos> - inline size_t get_tuple_size(const tuple_type& obj, int_<pos>) { - constexpr size_t idx = std::tuple_size<tuple_type>::value-pos-1; - size_t acc = get_size(std::get<idx>(obj)); - - // recur - return acc+get_tuple_size(obj, int_<pos-1>()); - } - - template <class ...T> - struct get_size_helper<std::tuple<T...>> { - static size_t value(const std::tuple<T...>& obj) { - return get_tuple_size(obj, int_<sizeof...(T)-1>()); - } - }; - - template <class T> - struct get_size_helper { - static size_t value(const T&) { return sizeof(T); } - }; - -} - -template <class T> -inline size_t get_size(const T& obj) { - return detail::get_size_helper<T>::value(obj); -} - -namespace detail { - - template <class T> - class serialize_helper; - - template <class T> - void serializer(const T& obj, Blob::iterator&); - - template <class tuple_type> - inline void serialize_tuple(const tuple_type& obj, Blob::iterator& res, int_<0>) { - constexpr size_t idx = std::tuple_size<tuple_type>::value-1; - serializer(std::get<idx>(obj), res); - } - - template <class tuple_type, size_t pos> - inline void serialize_tuple(const tuple_type& obj, Blob::iterator& res, int_<pos>) { - constexpr size_t idx = std::tuple_size<tuple_type>::value-pos-1; - serializer(std::get<idx>(obj), res); - - // recur - serialize_tuple(obj, res, int_<pos-1>()); - } - - template <class... T> - struct serialize_helper<std::tuple<T...>> { - static void apply(const std::tuple<T...>& obj, Blob::iterator& res) { - detail::serialize_tuple(obj, res, detail::int_<sizeof...(T)-1>()); - } - - }; - - template <> - struct serialize_helper<std::string> { - static void apply(const std::string& obj, Blob::iterator& res) { - // store the number of elements of this vector at the beginning - if (obj.length() > std::numeric_limits<serialized_size_t>::max()) - throw std::length_error("string is too long"); - serializer(static_cast<serialized_size_t>(obj.length()), res); - for(const auto& cur : obj) { serializer(cur, res); } - } - - }; - - template <class T> - struct serialize_helper<std::vector<T>> { - static void apply(const std::vector<T>& obj, Blob::iterator& res) { - // store the number of elements of this vector at the beginning - if (obj.size() > std::numeric_limits<serialized_size_t>::max()) - throw std::length_error("vector is too large"); - serializer(static_cast<serialized_size_t>(obj.size()), res); - for(const auto& cur : obj) { serializer(cur, res); } - } - - }; - - template <class T> - struct serialize_helper<std::chrono::time_point<T>> { - static void apply(const std::chrono::time_point<T>& obj, Blob::iterator& res) { - serializer(obj.time_since_epoch().count(), res); - } - }; - - template <class T> - struct serialize_helper { - static void apply(const T& obj, Blob::iterator& res) { - const uint8_t* ptr = reinterpret_cast<const uint8_t*>(&obj); - std::copy(ptr,ptr+sizeof(T),res); - res+=sizeof(T); - } - - }; - - template <class T> - inline void serializer(const T& obj, Blob::iterator& res) { - serialize_helper<T>::apply(obj,res); - } - -} // end detail namespace - -template <class T> -inline void serialize(const T& obj, Blob& res) { - - size_t offset = res.size(); - size_t size = get_size(obj); - res.resize(res.size() + size); - - Blob::iterator it = res.begin()+offset; - detail::serializer(obj,it); - if (res.begin() + offset + size != it) - throw std::logic_error("error serializing object"); -} - -namespace detail { - - template <class T> - struct deserialize_helper; - - template <class T> - struct deserialize_helper { - static T apply(Blob::const_iterator& begin, - Blob::const_iterator end) { - if (begin+sizeof(T)>end) - throw std::length_error("error deserializing object"); - T val; - std::copy(begin, begin+sizeof(T), reinterpret_cast<uint8_t*>(&val)); - begin+=sizeof(T); - return val; - } - }; - - template <class T> - struct deserialize_helper<std::chrono::time_point<T>> { - static std::chrono::time_point<T> apply(Blob::const_iterator& begin, - Blob::const_iterator end) { - return std::chrono::time_point<T>(typename T::duration(deserialize_helper<typename T::rep>::apply(begin,end))); - } - }; - - template <class T> - struct deserialize_helper<std::vector<T>> { - static std::vector<T> apply(Blob::const_iterator& begin, - Blob::const_iterator end) - { - // retrieve the number of elements - serialized_size_t size = deserialize_helper<serialized_size_t>::apply(begin,end); - - std::vector<T> vect(size); - for(size_t i=0; i<size; ++i) { - vect[i] = std::move(deserialize_helper<T>::apply(begin,end)); - } - return vect; - } - }; - - template <> - struct deserialize_helper<std::string> { - static std::string apply(Blob::const_iterator& begin, - Blob::const_iterator end) - { - // retrieve the number of elements - serialized_size_t size = deserialize_helper<serialized_size_t>::apply(begin,end); - - if (size == 0u) return std::string(); - std::string str(size,'\0'); - for(size_t i=0; i<size; ++i) { - str.at(i) = deserialize_helper<uint8_t>::apply(begin,end); - } - return str; - } - }; - - template <class tuple_type> - inline void deserialize_tuple(tuple_type& obj, - Blob::const_iterator& begin, - Blob::const_iterator end, int_<0>) { - constexpr size_t idx = std::tuple_size<tuple_type>::value-1; - typedef typename std::tuple_element<idx,tuple_type>::type T; - - std::get<idx>(obj) = std::move(deserialize_helper<T>::apply(begin, end)); - } - - template <class tuple_type, size_t pos> - inline void deserialize_tuple(tuple_type& obj, - Blob::const_iterator& begin, - Blob::const_iterator end, int_<pos>) { - constexpr size_t idx = std::tuple_size<tuple_type>::value-pos-1; - typedef typename std::tuple_element<idx,tuple_type>::type T; - std::get<idx>(obj) = std::move(deserialize_helper<T>::apply(begin, end)); - - // recur - deserialize_tuple(obj, begin, end, int_<pos-1>()); - } - - template <class... T> - struct deserialize_helper<std::tuple<T...>> { - static std::tuple<T...> apply(Blob::const_iterator& begin, - Blob::const_iterator end) - { - //return std::make_tuple(deserialize(begin,begin+sizeof(T),T())...); - std::tuple<T...> ret; - deserialize_tuple(ret, begin, end, int_<sizeof...(T)-1>()); - return ret; - } - - }; - -} - -template <class T> -inline T deserialize(Blob::const_iterator& begin, const Blob::const_iterator& end) { - return detail::deserialize_helper<T>::apply(begin, end); -} - -template <class T> -inline T deserialize(const Blob& res) { - Blob::const_iterator it = res.begin(); - return deserialize<T>(it, res.end()); -} - -namespace dht { - - struct Serializable { - /** - * Append serialized object to res. - */ - virtual void pack(Blob& res) const = 0; - Blob getPacked() const { - Blob ret; - pack(ret); - return ret; - } - - /** - * Read serialized object from {begin, end}. - */ - virtual void unpack(Blob::const_iterator& begin, Blob::const_iterator& end) = 0; - void unpackBlob(const Blob& data) { - auto cib = data.cbegin(), cie = data.cend(); - unpack(cib, cie); - } - - virtual ~Serializable() = default; -}; - -} diff --git a/include/opendht/value.h b/include/opendht/value.h index 3a0b0b8e70bf4a4125ea69dccd59d0fc40faa58b..30933dab3dab2de7c2fd6d9b14413ceec0d3fa0b 100644 --- a/include/opendht/value.h +++ b/include/opendht/value.h @@ -32,7 +32,7 @@ #include "infohash.h" #include "crypto.h" -#include "serialize.h" +#include <msgpack.hpp> #ifndef _WIN32 #include <netinet/in.h> @@ -153,12 +153,23 @@ struct ValueType { EditPolicy editPolicy {DEFAULT_EDIT_POLICY}; }; -struct ValueSerializable : public Serializable -{ - virtual const ValueType& getType() const = 0; - virtual void unpackValue(const Value& v); - virtual Value packValue() const; -}; +template <typename Type> +Blob +packMsg(const Type& t) { + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack(t); + return {buffer.data(), buffer.data()+buffer.size()}; +} + +template <typename Type> +Type +unpackMsg(Blob b) { + msgpack::unpacked msg_res = msgpack::unpack((const char*)b.data(), b.size()); + return msg_res.get().as<Type>(); +} + +msgpack::unpacked unpackMsg(Blob b); /** * A "value" is data potentially stored on the Dht, with some metadata. @@ -169,7 +180,7 @@ struct ValueSerializable : public Serializable * Values are stored at a given InfoHash in the Dht, but also have a * unique ID to distinguish between values stored at the same location. */ -struct Value : public Serializable +struct Value { typedef uint64_t Id; static const Id INVALID_ID {0}; @@ -222,31 +233,57 @@ struct Value : public Serializable }; } - /** - * Hold information about how the data is signed/encrypted. - * Class is final because bitset have no virtual destructor. - */ - class ValueFlags final : public std::bitset<3> { - public: - using std::bitset<3>::bitset; - ValueFlags() {} - ValueFlags(bool sign, bool encrypted, bool have_recipient = false) : bitset<3>((sign ? 1:0) | (encrypted ? 2:0) | (have_recipient ? 4:0)) {} - bool isSigned() const { - return (*this)[0]; - } - bool isEncrypted() const { - return (*this)[1]; + template <typename T> + struct Serializable + { + virtual const ValueType& getType() const = 0; + virtual void unpackValue(const Value& v) { + auto msg = msgpack::unpack((const char*)v.data.data(), v.data.size()); + msgpack::object obj = msg.get(); + obj.convert(static_cast<T*>(this)); } - bool haveRecipient() const { - return (*this)[2]; + + virtual Value packValue() const { + return Value {getType(), static_cast<const T&>(*this)}; } + virtual ~Serializable() = default; }; + template <typename T, + typename std::enable_if<std::is_base_of<Serializable<T>, T>::value, T>::type* = nullptr> + static Value pack(const T& obj) + { + return obj.packValue(); + } + + template <typename T, + typename std::enable_if<!std::is_base_of<Serializable<T>, T>::value, T>::type* = nullptr> + static Value pack(const T& obj) + { + return {ValueType::USER_DATA.id, packMsg<T>(obj)}; + } + + template <typename T, + typename std::enable_if<std::is_base_of<Serializable<T>, T>::value, T>::type* = nullptr> + static T unpack(const Value& v) + { + T msg; + msg.unpackValue(v); + return msg; + } + + template <typename T, + typename std::enable_if<!std::is_base_of<Serializable<T>, T>::value, T>::type* = nullptr> + static T unpack(const Value& v) + { + return unpackMsg<T>(v.data); + } + bool isEncrypted() const { - return flags.isEncrypted(); + return not cypher.empty(); } bool isSigned() const { - return flags.isSigned(); + return isEncrypted() or not signature.empty(); } Value() {} @@ -260,12 +297,14 @@ struct Value : public Serializable : id(id), type(t), data(std::move(data)) {} Value(ValueType::Id t, const uint8_t* dat_ptr, size_t dat_len, Id id = INVALID_ID) : id(id), type(t), data(dat_ptr, dat_ptr+dat_len) {} - Value(ValueType::Id t, const Serializable& d, Id id = INVALID_ID) - : id(id), type(t), data(d.getPacked()) {} - Value(const ValueType& t, const Serializable& d, Id id = INVALID_ID) - : id(id), type(t.id), data(d.getPacked()) {} - Value(const ValueSerializable& d, Id id = INVALID_ID) - : id(id), type(d.getType().id), data(d.getPacked()) {} + + template <typename Type> + Value(ValueType::Id t, const Type& d, Id id = INVALID_ID) + : id(id), type(t), data(packMsg(d)) {} + + template <typename Type> + Value(const ValueType& t, const Type& d, Id id = INVALID_ID) + : id(id), type(t.id), data(packMsg(d)) {} /** Custom user data constructor */ Value(const Blob& userdata) : data(userdata) {} @@ -273,41 +312,50 @@ struct Value : public Serializable Value(const uint8_t* dat_ptr, size_t dat_len) : data(dat_ptr, dat_ptr+dat_len) {} Value(Value&& o) noexcept - : id(o.id), flags(o.flags), owner(std::move(o.owner)), recipient(o.recipient), + : id(o.id), owner(std::move(o.owner)), recipient(o.recipient), type(o.type), data(std::move(o.data)), seq(o.seq), signature(std::move(o.signature)), cypher(std::move(o.cypher)) {} + template <typename Type> + Value(const Type& vs) + : Value(pack<Type>(vs)) {} + + Value(const msgpack::object& o) { + msgpack_unpack(o); + } + inline bool operator== (const Value& o) { return id == o.id && - (flags.isEncrypted() ? cypher == o.cypher : + (isEncrypted() ? cypher == o.cypher : (owner == o.owner && type == o.type && data == o.data && signature == o.signature)); } void setRecipient(const InfoHash& r) { recipient = r; - flags[2] = true; } void setCypher(Blob&& c) { cypher = std::move(c); - flags = {true, true, true}; } /** - * Pack part of the data to be signed + * Pack part of the data to be signed (must always be done the same way) */ - void packToSign(Blob& res) const; - Blob getToSign() const; + Blob getToSign() const { + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + msgpack_pack_to_sign(pk); + return {buffer.data(), buffer.data()+buffer.size()}; + } /** * Pack part of the data to be encrypted */ - void packToEncrypt(Blob& res) const; - Blob getToEncrypt() const; - - void pack(Blob& res) const; - - void unpackBody(Blob::const_iterator& begin, Blob::const_iterator& end); - virtual void unpack(Blob::const_iterator& begin, Blob::const_iterator& end); + Blob getToEncrypt() const { + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + msgpack_pack_to_encrypt(pk); + return {buffer.data(), buffer.data()+buffer.size()}; + } /** print value for debugging */ friend std::ostream& operator<< (std::ostream& s, const Value& v); @@ -318,11 +366,53 @@ struct Value : public Serializable return ss.str(); } - Id id {INVALID_ID}; + template <typename Packer> + void msgpack_pack_to_sign(Packer& pk) const + { + pk.pack_map((user_type.empty()?0:1) + (owner?(recipient == InfoHash() ? 4 : 5):2)); + if (owner) { // isSigned + pk.pack(std::string("seq")); pk.pack(seq); + pk.pack(std::string("owner")); owner.msgpack_pack(pk); + if (recipient != InfoHash()) { + pk.pack(std::string("to")); pk.pack(recipient); + } + } + pk.pack(std::string("type")); pk.pack(type); + pk.pack(std::string("data")); pk.pack_bin(data.size()); + pk.pack_bin_body((const char*)data.data(), data.size()); + if (not user_type.empty()) { + pk.pack(std::string("utype")); pk.pack(user_type); + } + } - // data (part that is signed / encrypted) + template <typename Packer> + void msgpack_pack_to_encrypt(Packer& pk) const + { + if (isEncrypted()) { + pk.pack_bin(cypher.size()); + pk.pack_bin_body((const char*)cypher.data(), cypher.size()); + } else { + pk.pack_map(isSigned() ? 2 : 1); + pk.pack(std::string("body")); msgpack_pack_to_sign(pk); + if (isSigned()) { + pk.pack(std::string("sig")); pk.pack_bin(signature.size()); + pk.pack_bin_body((const char*)signature.data(), signature.size()); + } + } + } - ValueFlags flags {}; + template <typename Packer> + void msgpack_pack(Packer& pk) const + { + pk.pack_map(2); + pk.pack(std::string("id")); pk.pack(id); + pk.pack(std::string("dat")); msgpack_pack_to_encrypt(pk); + } + + void msgpack_unpack(msgpack::object o); + void msgpack_unpack_body(const msgpack::object& o); + + Id id {INVALID_ID}; /** * Public key of the signer. @@ -342,6 +432,11 @@ struct Value : public Serializable ValueType::Id type {ValueType::USER_DATA.id}; Blob data {}; + /** + * Custom user-defined type + */ + std::string user_type {}; + /** * Sequence number to avoid replay attacks */ @@ -358,6 +453,45 @@ struct Value : public Serializable Blob cypher {}; }; +template <typename T, + typename std::enable_if<std::is_base_of<Value::Serializable<T>, T>::value, T>::type* = nullptr> +Value::Filter +getFilterSet(Value::Filter f) +{ + return Value::Filter::chain({ + Value::TypeFilter(T::TYPE), + T::getFilter(), + f + }); +} + +template <typename T, + typename std::enable_if<!std::is_base_of<Value::Serializable<T>, T>::value, T>::type* = nullptr> +Value::Filter +getFilterSet(Value::Filter f) +{ + return f; +} + +template <typename T, + typename std::enable_if<std::is_base_of<Value::Serializable<T>, T>::value, T>::type* = nullptr> +Value::Filter +getFilterSet() +{ + return Value::Filter::chain({ + Value::TypeFilter(T::TYPE), + T::getFilter() + }); +} + +template <typename T, + typename std::enable_if<!std::is_base_of<Value::Serializable<T>, T>::value, T>::type* = nullptr> +Value::Filter +getFilterSet() +{ + return Value::AllFilter(); +} + template <class T> std::vector<T> unpackVector(const std::vector<std::shared_ptr<Value>>& vals) { @@ -365,9 +499,7 @@ unpackVector(const std::vector<std::shared_ptr<Value>>& vals) { ret.reserve(vals.size()); for (const auto& v : vals) { try { - T msg; - msg.unpackValue(*v); - ret.emplace_back(std::move(msg)); + ret.emplace_back(Value::unpack<T>(*v)); } catch (const std::exception&) {} } return ret; diff --git a/src/Makefile.am b/src/Makefile.am index a7842c68375908cb6e3f5e58877cb6baf3e623e6..727c4b50ba33405e98e92c2954d16e86f6fd5dab 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -2,6 +2,7 @@ lib_LTLIBRARIES = libopendht.la AM_CPPFLAGS = -I../include/opendht libopendht_la_CXXFLAGS = @CXXFLAGS@ +libopendht_la_LDFLAGS = @LDFLAGS@ @GNUTLS_LIBS@ @nettle_LIBS@ libopendht_la_SOURCES = \ dht.cpp \ @@ -24,6 +25,5 @@ nobase_include_HEADERS = \ ../include/opendht/crypto.h \ ../include/opendht/securedht.h \ ../include/opendht/dhtrunner.h \ - ../include/opendht/serialize.h \ ../include/opendht/default_types.h \ ../include/opendht/rng.h diff --git a/src/crypto.cpp b/src/crypto.cpp index 1164633932952bb3b3bed241ad1a073074df9ac0..c6bf16af0dcb6e7736c16ccfa89346533672705d 100644 --- a/src/crypto.cpp +++ b/src/crypto.cpp @@ -35,14 +35,17 @@ extern "C" { #include <gnutls/gnutls.h> #include <gnutls/abstract.h> #include <gnutls/x509.h> +#include <nettle/gcm.h> +#include <nettle/aes.h> } #include <random> #include <sstream> -#include <random> #include <stdexcept> #include <cassert> +static std::uniform_int_distribution<uint8_t> rand_byte; + static gnutls_digest_algorithm_t get_dig_for_pub(gnutls_pubkey_t pubkey) { gnutls_digest_algorithm_t dig; @@ -82,6 +85,87 @@ static gnutls_digest_algorithm_t get_dig(gnutls_x509_crt_t crt) namespace dht { namespace crypto { +static constexpr std::array<size_t, 3> AES_LENGTHS {128/8, 192/8, 256/8}; + +size_t aesKeySize(size_t max) +{ + unsigned aes_key_len = 0; + for (size_t s = 0; s < AES_LENGTHS.size(); s++) { + if (AES_LENGTHS[s] <= max) + aes_key_len = AES_LENGTHS[s]; + else break; + } + return aes_key_len; +} + +bool aesKeySizeGood(size_t key_size) +{ + for (auto& i : AES_LENGTHS) + if (key_size == i) + return true; + return false; +} + +#ifndef GCM_DIGEST_SIZE +#define GCM_DIGEST_SIZE GCM_BLOCK_SIZE +#endif + +Blob +aesEncrypt(const Blob& data, const Blob& key) +{ + std::array<uint8_t, GCM_IV_SIZE> iv; + { + crypto::random_device rdev; + std::generate_n(iv.begin(), iv.size(), std::bind(rand_byte, std::ref(rdev))); + } + struct gcm_aes_ctx aes; + gcm_aes_set_key(&aes, key.size(), key.data()); + gcm_aes_set_iv(&aes, iv.size(), iv.data()); + gcm_aes_update(&aes, data.size(), data.data()); + + Blob ret(data.size() + GCM_IV_SIZE + GCM_DIGEST_SIZE); + std::copy(iv.begin(), iv.end(), ret.begin()); + gcm_aes_encrypt(&aes, data.size(), ret.data() + GCM_IV_SIZE, data.data()); + gcm_aes_digest(&aes, GCM_DIGEST_SIZE, ret.data() + GCM_IV_SIZE + data.size()); + return ret; +} + +Blob +aesDecrypt(const Blob& data, const Blob& key) +{ + if (not aesKeySizeGood(key.size())) + throw DecryptError("Wrong key size"); + + if (data.size() <= GCM_IV_SIZE + GCM_DIGEST_SIZE) + throw DecryptError("Wrong data size"); + + std::array<uint8_t, GCM_DIGEST_SIZE> digest; + + struct gcm_aes_ctx aes; + gcm_aes_set_key(&aes, key.size(), key.data()); + gcm_aes_set_iv(&aes, GCM_IV_SIZE, data.data()); + + size_t data_sz = data.size() - GCM_IV_SIZE - GCM_DIGEST_SIZE; + Blob ret(data_sz); + //gcm_aes_update(&aes, data_sz, data.data() + GCM_IV_SIZE); + gcm_aes_decrypt(&aes, data_sz, ret.data(), data.data() + GCM_IV_SIZE); + //gcm_aes_digest(aes, GCM_DIGEST_SIZE, digest.data()); + + // TODO compute the proper digest directly from the decryption pass + Blob ret_tmp(data_sz); + struct gcm_aes_ctx aes_d; + gcm_aes_set_key(&aes_d, key.size(), key.data()); + gcm_aes_set_iv(&aes_d, GCM_IV_SIZE, data.data()); + gcm_aes_update(&aes_d, ret.size() , ret.data()); + gcm_aes_encrypt(&aes_d, ret.size(), ret_tmp.data(), ret.data()); + gcm_aes_digest(&aes_d, GCM_DIGEST_SIZE, digest.data()); + + if (not std::equal(digest.begin(), digest.end(), data.end() - GCM_DIGEST_SIZE)) + throw DecryptError("Can't decrypt data"); + + return ret; +} + PrivateKey::PrivateKey() { #if GNUTLS_VERSION_NUMBER < 0x030300 @@ -194,6 +278,19 @@ PrivateKey::sign(const Blob& data) const return ret; } +Blob +PrivateKey::decryptBloc(const uint8_t* src, size_t src_size) const +{ + const gnutls_datum_t dat {(uint8_t*)src, (unsigned)src_size}; + gnutls_datum_t out; + int err = gnutls_privkey_decrypt_data(key, 0, &dat, &out); + if (err != GNUTLS_E_SUCCESS) + throw DecryptError(std::string("Can't decrypt data: ") + gnutls_strerror(err)); + Blob ret {out.data, out.data+out.size}; + gnutls_free(out.data); + return ret; +} + Blob PrivateKey::decrypt(const Blob& cipher) const { @@ -208,20 +305,12 @@ PrivateKey::decrypt(const Blob& cipher) const throw CryptoException("Must be an RSA key"); unsigned cypher_block_sz = key_len / 8; - if (cipher.size() % cypher_block_sz) - throw CryptoException("Unexpected cipher length"); + if (cipher.size() < cypher_block_sz) + throw DecryptError("Unexpected cipher length"); + else if (cipher.size() == cypher_block_sz) + return decryptBloc(cipher.data(), cypher_block_sz); - Blob ret; - for (auto cb = cipher.cbegin(), ce = cipher.cend(); cb < ce; cb += cypher_block_sz) { - const gnutls_datum_t dat {(uint8_t*)(&(*cb)), cypher_block_sz}; - gnutls_datum_t out; - int err = gnutls_privkey_decrypt_data(key, 0, &dat, &out); - if (err != GNUTLS_E_SUCCESS) - throw DecryptError(std::string("Can't decrypt data: ") + gnutls_strerror(err)); - ret.insert(ret.end(), out.data, out.data+out.size); - gnutls_free(out.data); - } - return ret; + return aesDecrypt(Blob {cipher.begin() + cypher_block_sz, cipher.end()}, decryptBloc(cipher.data(), cypher_block_sz)); } Blob @@ -256,7 +345,7 @@ PrivateKey::getPublicKey() const PublicKey::PublicKey(const Blob& dat) : pk(nullptr) { - unpackBlob(dat); + unpack(dat.data(), dat.size()); } PublicKey::~PublicKey() @@ -286,17 +375,16 @@ PublicKey::pack(Blob& b) const if (err != GNUTLS_E_SUCCESS) throw CryptoException(std::string("Could not export public key: ") + gnutls_strerror(err)); tmp.resize(sz); - serialize<Blob>(tmp, b); + b.insert(b.end(), tmp.begin(), tmp.end()); } void -PublicKey::unpack(Blob::const_iterator& begin, Blob::const_iterator& end) +PublicKey::unpack(const uint8_t* data, size_t data_size) { - Blob tmp = deserialize<Blob>(begin, end); if (pk) gnutls_pubkey_deinit(pk); gnutls_pubkey_init(&pk); - const gnutls_datum_t dat {(uint8_t*)tmp.data(), (unsigned)tmp.size()}; + const gnutls_datum_t dat {(uint8_t*)data, (unsigned)data_size}; int err = gnutls_pubkey_import(pk, &dat, GNUTLS_X509_FMT_PEM); if (err != GNUTLS_E_SUCCESS) err = gnutls_pubkey_import(pk, &dat, GNUTLS_X509_FMT_DER); @@ -304,6 +392,14 @@ PublicKey::unpack(Blob::const_iterator& begin, Blob::const_iterator& end) throw CryptoException(std::string("Could not read public key: ") + gnutls_strerror(err)); } +void +PublicKey::msgpack_unpack(msgpack::object o) +{ + if (o.type != msgpack::type::BIN) + throw msgpack::type_error(); + unpack((const uint8_t*)o.via.bin.ptr, o.via.bin.size); +} + bool PublicKey::checkSignature(const Blob& data, const Blob& signature) const { if (!pk) @@ -314,6 +410,20 @@ PublicKey::checkSignature(const Blob& data, const Blob& signature) const { return rc >= 0; } +void +PublicKey::encryptBloc(const uint8_t* src, size_t src_size, uint8_t* dst, size_t dst_size) const +{ + const gnutls_datum_t key_dat {(uint8_t*)src, (unsigned)src_size}; + gnutls_datum_t encrypted; + auto err = gnutls_pubkey_encrypt_data(pk, 0, &key_dat, &encrypted); + if (err != GNUTLS_E_SUCCESS) + throw CryptoException(std::string("Can't encrypt data: ") + gnutls_strerror(err)); + if (encrypted.size != dst_size) + throw CryptoException("Unexpected cypherblock size"); + std::copy_n(encrypted.data, encrypted.size, dst); + gnutls_free(encrypted.data); +} + Blob PublicKey::encrypt(const Blob& data) const { @@ -327,27 +437,30 @@ PublicKey::encrypt(const Blob& data) const if (err != GNUTLS_PK_RSA) throw CryptoException("Must be an RSA key"); - unsigned max_block_sz = key_len / 8 - 11; - unsigned cypher_block_sz = key_len / 8; - unsigned block_num = data.empty() ? 1 : 1 + (data.size() - 1) / max_block_sz; + const unsigned max_block_sz = key_len / 8 - 11; + const unsigned cypher_block_sz = key_len / 8; + if (data.size() <= max_block_sz) { + Blob ret(cypher_block_sz); + encryptBloc(data.data(), data.size(), ret.data(), cypher_block_sz); + return ret; + } - Blob ret; - auto eb = data.cbegin(); - auto ee = data.cend(); - for (unsigned i=0; i<block_num; i++) { - auto blk_sz = std::min<unsigned>(ee - eb, max_block_sz); - const gnutls_datum_t dat {(uint8_t*)&(*eb), blk_sz}; - gnutls_datum_t encrypted; - err = gnutls_pubkey_encrypt_data(pk, 0, &dat, &encrypted); - if (err != GNUTLS_E_SUCCESS) - throw CryptoException(std::string("Can't encrypt data: ") + gnutls_strerror(err)); - if (encrypted.size != cypher_block_sz) - throw CryptoException("Unexpected cypherblock size"); - ret.insert(ret.end(), encrypted.data, encrypted.data+encrypted.size); - eb += blk_sz; - gnutls_free(encrypted.data); + unsigned aes_key_sz = aesKeySize(max_block_sz); + if (aes_key_sz == 0) + throw CryptoException("Key is not long enough for AES128"); + Blob key(aes_key_sz); + { + crypto::random_device rdev; + std::generate_n(key.begin(), key.size(), std::bind(rand_byte, std::ref(rdev))); } + auto data_encrypted = aesEncrypt(data, key); + + Blob ret; + ret.reserve(cypher_block_sz + data_encrypted.size()); + ret.resize(cypher_block_sz); + encryptBloc(key.data(), key.size(), ret.data(), cypher_block_sz); + ret.insert(ret.end(), data_encrypted.begin(), data_encrypted.end()); return ret; } @@ -363,7 +476,7 @@ PublicKey::getId() const Certificate::Certificate(const Blob& certData) : cert(nullptr) { - unpackBlob(certData); + unpack(certData.data(), certData.size()); } Certificate& @@ -378,7 +491,7 @@ Certificate::operator=(Certificate&& o) noexcept } void -Certificate::unpack(Blob::const_iterator& begin, Blob::const_iterator& end) +Certificate::unpack(const uint8_t* dat, size_t dat_size) { if (cert) { gnutls_x509_crt_deinit(cert); @@ -386,7 +499,7 @@ Certificate::unpack(Blob::const_iterator& begin, Blob::const_iterator& end) } gnutls_x509_crt_t* cert_list; unsigned cert_num; - const gnutls_datum_t crt_dt {(uint8_t*)&(*begin), (unsigned)(end-begin)}; + const gnutls_datum_t crt_dt {(uint8_t*)dat, (unsigned)dat_size}; int err = gnutls_x509_crt_list_import2(&cert_list, &cert_num, &crt_dt, GNUTLS_X509_FMT_PEM, GNUTLS_X509_CRT_LIST_FAIL_IF_UNSORTED); if (err != GNUTLS_E_SUCCESS) err = gnutls_x509_crt_list_import2(&cert_list, &cert_num, &crt_dt, GNUTLS_X509_FMT_DER, GNUTLS_X509_CRT_LIST_FAIL_IF_UNSORTED); @@ -405,6 +518,14 @@ Certificate::unpack(Blob::const_iterator& begin, Blob::const_iterator& end) gnutls_free(cert_list); } +void +Certificate::msgpack_unpack(msgpack::object o) +{ + if (o.type != msgpack::type::BIN) + throw msgpack::type_error(); + unpack((const uint8_t*)o.via.bin.ptr, o.via.bin.size); +} + void Certificate::pack(Blob& b) const { diff --git a/src/default_types.cpp b/src/default_types.cpp index 63b413f9a370e62865ee5e011d2234e13423c5e0..05a1d4184a433f8f792e1c65749e5e505e73bc9c 100644 --- a/src/default_types.cpp +++ b/src/default_types.cpp @@ -38,29 +38,14 @@ std::ostream& operator<< (std::ostream& s, const DhtMessage& v) return s; } -void -DhtMessage::pack(Blob& res) const -{ - serialize<std::string>(service, res); - serialize<Blob>(data, res); -} - -void -DhtMessage::unpack(Blob::const_iterator& begin, Blob::const_iterator& end) -{ - service = deserialize<std::string>(begin, end); - data = deserialize<Blob>(begin, end); -} - bool DhtMessage::storePolicy(InfoHash, std::shared_ptr<Value>& v, InfoHash, const sockaddr*, socklen_t) { - DhtMessage request; try { - request.unpackBlob(v->data); + auto msg = unpackMsg<DhtMessage>(v->data); + if (msg.service.empty()) + return false; } catch (const std::exception& e) {} - if (request.service.empty()) - return false; return true; } @@ -71,9 +56,7 @@ DhtMessage::ServiceFilter(std::string s) Value::TypeFilter(TYPE), [s](const Value& v) { try { - auto b = v.data.cbegin(), e = v.data.cend(); - auto service = deserialize<std::string>(b, e); - return service == s; + return unpackMsg<DhtMessage>(v.data).service == s; } catch (const std::exception& e) { return false; } @@ -95,51 +78,20 @@ std::ostream& operator<< (std::ostream& s, const IpServiceAnnouncement& v) return s; } -void -IpServiceAnnouncement::pack(Blob& res) const -{ - serialize<in_port_t>(getPort(), res); - if (ss.ss_family == AF_INET) { - auto sa4 = reinterpret_cast<const sockaddr_in*>(&ss); - serialize<in_addr>(sa4->sin_addr, res); - } else if (ss.ss_family == AF_INET6) { - auto sa6 = reinterpret_cast<const sockaddr_in6*>(&ss); - serialize<in6_addr>(sa6->sin6_addr, res); - } -} - -void -IpServiceAnnouncement::unpack(Blob::const_iterator& begin, Blob::const_iterator& end) -{ - setPort(deserialize<in_port_t>(begin, end)); - size_t addr_size = end - begin; - if (addr_size < sizeof(in_addr)) { - ss.ss_family = 0; - } else if (addr_size == sizeof(in_addr)) { - auto sa4 = reinterpret_cast<sockaddr_in*>(&ss); - sa4->sin_family = AF_INET; - sa4->sin_addr = deserialize<in_addr>(begin, end); - } else if (addr_size == sizeof(in6_addr)) { - auto sa6 = reinterpret_cast<sockaddr_in6*>(&ss); - sa6->sin6_family = AF_INET6; - sa6->sin6_addr = deserialize<in6_addr>(begin, end); - } else { - throw std::runtime_error("ServiceAnnouncement parse error."); - } -} - bool IpServiceAnnouncement::storePolicy(InfoHash, std::shared_ptr<Value>& v, InfoHash, const sockaddr* from, socklen_t fromlen) { - IpServiceAnnouncement request {}; - request.unpackBlob(v->data); - if (request.getPort() == 0) - return false; - IpServiceAnnouncement sa_addr {from, fromlen}; - sa_addr.setPort(request.getPort()); - // argument v is modified (not the value). - v = std::make_shared<Value>(IpServiceAnnouncement::TYPE, sa_addr, v->id); - return true; + try { + auto msg = unpackMsg<IpServiceAnnouncement>(v->data); + if (msg.getPort() == 0) + return false; + IpServiceAnnouncement sa_addr {from, fromlen}; + sa_addr.setPort(msg.getPort()); + // argument v is modified (not the value). + v = std::make_shared<Value>(IpServiceAnnouncement::TYPE, sa_addr, v->id); + return true; + } catch (const std::exception& e) {} + return false; } const ValueType DhtMessage::TYPE = {1, "DHT message", std::chrono::minutes(5), DhtMessage::storePolicy, ValueType::DEFAULT_EDIT_POLICY}; diff --git a/src/dht.cpp b/src/dht.cpp index 84a6a89e58b5fc0acf719f38c1066b2b1779a27c..e6b315e7e4505224bc8a22830a9bccecf4ee5f56 100644 --- a/src/dht.cpp +++ b/src/dht.cpp @@ -27,6 +27,7 @@ THE SOFTWARE. #include "dht.h" #include "rng.h" +#include <msgpack.hpp> extern "C" { #include <gnutls/gnutls.h> } @@ -140,11 +141,11 @@ namespace dht { const Dht::TransPrefix Dht::TransPrefix::PING = {"pn"}; const Dht::TransPrefix Dht::TransPrefix::FIND_NODE = {"fn"}; -const Dht::TransPrefix Dht::TransPrefix::GET_VALUES = {"gp"}; -const Dht::TransPrefix Dht::TransPrefix::ANNOUNCE_VALUES = {"ap"}; -const Dht::TransPrefix Dht::TransPrefix::LISTEN = {"ls"}; +const Dht::TransPrefix Dht::TransPrefix::GET_VALUES = {"gt"}; +const Dht::TransPrefix Dht::TransPrefix::ANNOUNCE_VALUES = {"pt"}; +const Dht::TransPrefix Dht::TransPrefix::LISTEN = {"lt"}; -const uint8_t Dht::my_v[9] = "1:v4:RNG"; +const std::string my_v = "RNG"; static constexpr InfoHash zeroes {}; static constexpr InfoHash ones = {std::array<uint8_t, HASH_LEN>{ @@ -968,6 +969,7 @@ Dht::searchStep(Search& sr) return sn.node->isExpired(now); }) == sr.nodes.size()) { + DHT_WARN("Search expired"); // no nodes or all expired nodes sr.expired = true; if (sr.announce.empty() && sr.listeners.empty()) { @@ -1286,7 +1288,7 @@ Dht::search(const InfoHash& id, sa_family_t af, GetCallback callback, DoneCallba } void -Dht::announce(const InfoHash& id, sa_family_t af, const std::shared_ptr<Value>& value, DoneCallback callback) +Dht::announce(const InfoHash& id, sa_family_t af, std::shared_ptr<Value> value, DoneCallback callback) { if (!value) { if (callback) @@ -1441,7 +1443,7 @@ Dht::cancelListen(const InfoHash& id, size_t token) } void -Dht::put(const InfoHash& id, const std::shared_ptr<Value>& val, DoneCallback callback) +Dht::put(const InfoHash& id, std::shared_ptr<Value> val, DoneCallback callback) { now = clock::now(); @@ -2180,26 +2182,6 @@ Dht::processMessage(const uint8_t *buf, size_t buflen, const sockaddr *from, soc if (buflen == 0) return; - //DHT_DEBUG("processMessage %p %lu %p %lu", buf, buflen, from, fromlen); - - MessageType message; - InfoHash id, info_hash, target; - TransId tid; - Blob token {}; - uint8_t nodes[26*16], nodes6[38*16]; - unsigned nodes_len = 26*16, nodes6_len = 38*16; - in_port_t port; - Value::Id value_id; - uint16_t error_code; - sockaddr_storage addr; - socklen_t addr_length = sizeof(sockaddr_storage); - - std::vector<std::shared_ptr<Value>> values; - - want_t want; - uint16_t ttid; - bool ring; - if (isMartian(from, fromlen)) return; @@ -2208,15 +2190,13 @@ Dht::processMessage(const uint8_t *buf, size_t buflen, const sockaddr *from, soc return; } - if (buf[buflen] != '\0') - throw DhtException("Unterminated message."); + //DHT_DEBUG("processMessage %p %lu %p %lu", buf, buflen, from, fromlen); + ParsedMessage msg; try { - message = parseMessage(buf, buflen, tid, id, info_hash, target, - port, token, value_id, - nodes, &nodes_len, nodes6, &nodes6_len, - values, &want, error_code, ring, (sockaddr*)&addr, addr_length); - if (message != MessageType::Error && id == zeroes) + msgpack::unpacked msg_res = msgpack::unpack((const char*)buf, buflen); + msg.msgpack_unpack(msg_res.get()); + if (msg.type != MessageType::Error && msg.id == zeroes) throw DhtException("no or invalid InfoHash"); } catch (const std::exception& e) { DHT_WARN("Can't process message of size %lu: %s.", buflen, e.what()); @@ -2224,16 +2204,12 @@ Dht::processMessage(const uint8_t *buf, size_t buflen, const sockaddr *from, soc return; } - // drop msg with unknown protocol - if (not ring) - return; - - if (id == myid) { + if (msg.id == myid) { DHT_DEBUG("Received message from self."); return; } - if (message > MessageType::Reply) { + if (msg.type > MessageType::Reply) { /* Rate limit requests. */ if (!rateLimit()) { DHT_WARN("Dropping request due to rate limiting."); @@ -2242,14 +2218,15 @@ Dht::processMessage(const uint8_t *buf, size_t buflen, const sockaddr *from, soc } //std::cout << "Message from " << id << " IPv" << (from->sa_family==AF_INET?'4':'6') << std::endl; + uint16_t ttid = 0; - switch (message) { + switch (msg.type) { case MessageType::Error: - if (tid.length != 4) return; - if (error_code == 401 && id != zeroes && (tid.matches(TransPrefix::ANNOUNCE_VALUES, &ttid) || tid.matches(TransPrefix::LISTEN, &ttid))) { + if (msg.tid.length != 4) return; + if (msg.error_code == 401 && msg.id != zeroes && (msg.tid.matches(TransPrefix::ANNOUNCE_VALUES, &ttid) || msg.tid.matches(TransPrefix::LISTEN, &ttid))) { auto esr = findSearch(ttid, from->sa_family); if (!esr) return; - auto ne = newNode(id, from, fromlen, 2); + auto ne = newNode(msg.id, from, fromlen, 2); unsigned cleared = 0; for (auto& sr : searches) { for (auto& n : sr.nodes) { @@ -2262,45 +2239,45 @@ Dht::processMessage(const uint8_t *buf, size_t buflen, const sockaddr *from, soc break; } } - DHT_WARN("Token flush for node %s (%d searches affected)", id.toString().c_str(), cleared); + DHT_WARN("Token flush for node %s (%d searches affected)",msg.id.toString().c_str(), cleared); } else { - DHT_WARN("Received unknown error message %u from %s:", error_code, id.toString().c_str()); + DHT_WARN("Received unknown error message %u from %s:", msg.error_code,msg.id.toString().c_str()); DHT_WARN.logPrintable(buf, buflen); } break; case MessageType::Reply: - if (tid.length != 4) { - DHT_ERROR("Broken node truncates transaction ids (len: %d): ", tid.length); + if (msg.tid.length != 4) { + DHT_ERROR("Broken node truncates transaction ids (len: %d): ", msg.tid.length); DHT_ERROR.logPrintable(buf, buflen); /* This is really annoying, as it means that we will time-out all our searches that go through this node. Kill it. */ - blacklistNode(&id, from, fromlen); + blacklistNode(&msg.id, from, fromlen); return; } - if (tid.matches(TransPrefix::PING)) { + if (msg.tid.matches(TransPrefix::PING)) { DHT_DEBUG("Pong!"); - newNode(id, from, fromlen, 2, (sockaddr*)&addr, addr_length); - } else if (tid.matches(TransPrefix::FIND_NODE) or tid.matches(TransPrefix::GET_VALUES)) { + newNode(msg.id, from, fromlen, 2, (sockaddr*)&msg.addr.first, msg.addr.second); + } else if (msg.tid.matches(TransPrefix::FIND_NODE) or msg.tid.matches(TransPrefix::GET_VALUES)) { bool gp = false; Search *sr = nullptr; std::shared_ptr<Node> n; - if (tid.matches(TransPrefix::GET_VALUES, &ttid)) { + if (msg.tid.matches(TransPrefix::GET_VALUES, &ttid)) { gp = true; sr = findSearch(ttid, from->sa_family); } - DHT_DEBUG("Nodes found (%u+%u)%s!", nodes_len/26, nodes6_len/38, gp ? " for get_values" : ""); - if (nodes_len % 26 != 0 || nodes6_len % 38 != 0) { + DHT_DEBUG("Nodes found (%u+%u)%s!", msg.nodes4.size()/26, msg.nodes6.size()/38, gp ? " for get_values" : ""); + if (msg.nodes4.size() % 26 != 0 || msg.nodes6.size() % 38 != 0) { DHT_WARN("Unexpected length for node info!"); - blacklistNode(&id, from, fromlen); + blacklistNode(&msg.id, from, fromlen); break; } else if (gp && sr == nullptr) { DHT_WARN("Unknown search with tid %u !", ttid); - n = newNode(id, from, fromlen, 1); + n = newNode(msg.id, from, fromlen, 1); } else { - n = newNode(id, from, fromlen, 2, (sockaddr*)&addr, addr_length); - for (unsigned i = 0; i < nodes_len / 26; i++) { - uint8_t *ni = nodes + i * 26; + n = newNode(msg.id, from, fromlen, 2, (sockaddr*)&msg.addr.first, msg.addr.second); + for (unsigned i = 0; i < msg.nodes4.size() / 26; i++) { + uint8_t *ni = msg.nodes4.data() + i * 26; const InfoHash& ni_id = *reinterpret_cast<InfoHash*>(ni); if (ni_id == myid) continue; @@ -2314,8 +2291,8 @@ Dht::processMessage(const uint8_t *buf, size_t buflen, const sockaddr *from, soc sr->insertNode(sn, now); } } - for (unsigned i = 0; i < nodes6_len / 38; i++) { - uint8_t *ni = nodes6 + i * 38; + for (unsigned i = 0; i < msg.nodes6.size() / 38; i++) { + uint8_t *ni = msg.nodes6.data() + i * 38; InfoHash* ni_id = reinterpret_cast<InfoHash*>(ni); if (*ni_id == myid) continue; @@ -2339,13 +2316,13 @@ Dht::processMessage(const uint8_t *buf, size_t buflen, const sockaddr *from, soc } } if (sr) { - sr->insertNode(n, now, token); - if (!values.empty()) { - DHT_DEBUG("Got %d values !", values.size()); + sr->insertNode(n, now, msg.token); + if (!msg.values.empty()) { + DHT_DEBUG("Got %d values !", msg.values.size()); for (auto& cb : sr->callbacks) { if (!cb.get_cb) continue; std::vector<std::shared_ptr<Value>> tmp; - std::copy_if(values.begin(), values.end(), std::back_inserter(tmp), [&](const std::shared_ptr<Value>& v) { + std::copy_if(msg.values.begin(), msg.values.end(), std::back_inserter(tmp), [&](const std::shared_ptr<Value>& v) { return not static_cast<bool>(cb.filter) or cb.filter(*v); }); if (not tmp.empty()) @@ -2355,7 +2332,7 @@ Dht::processMessage(const uint8_t *buf, size_t buflen, const sockaddr *from, soc for (auto& l : sr->listeners) { if (!l.second.get_cb) continue; std::vector<std::shared_ptr<Value>> tmp; - std::copy_if(values.begin(), values.end(), std::back_inserter(tmp), [&](const std::shared_ptr<Value>& v) { + std::copy_if(msg.values.begin(), msg.values.end(), std::back_inserter(tmp), [&](const std::shared_ptr<Value>& v) { return not static_cast<bool>(l.second.filter) or l.second.filter(*v); }); if (not tmp.empty()) @@ -2368,17 +2345,17 @@ Dht::processMessage(const uint8_t *buf, size_t buflen, const sockaddr *from, soc if (sr->isSynced(now)) search_time = now; } - } else if (tid.matches(TransPrefix::ANNOUNCE_VALUES, &ttid)) { + } else if (msg.tid.matches(TransPrefix::ANNOUNCE_VALUES, &ttid)) { DHT_DEBUG("Got reply to announce_values."); Search *sr = findSearch(ttid, from->sa_family); - if (!sr || value_id == Value::INVALID_ID) { + if (!sr || msg.value_id == Value::INVALID_ID) { DHT_DEBUG("Unknown search or announce!"); - newNode(id, from, fromlen, 1); + newNode(msg.id, from, fromlen, 1); } else { - auto n = newNode(id, from, fromlen, 2, (sockaddr*)&addr, addr_length); + auto n = newNode(msg.id, from, fromlen, 2, (sockaddr*)&msg.addr.first, msg.addr.second); for (auto& sn : sr->nodes) if (sn.node == n) { - auto it = sn.acked.emplace(value_id, SearchNode::RequestStatus{}); + auto it = sn.acked.emplace(msg.value_id, SearchNode::RequestStatus{}); it.first->second.reply_time = now; break; } @@ -2388,22 +2365,22 @@ Dht::processMessage(const uint8_t *buf, size_t buflen, const sockaddr *from, soc // If the value was just successfully announced, call the callback for (auto& a : sr->announce) { - if (!a.callback || !a.value || a.value->id != value_id) + if (!a.callback || !a.value || a.value->id != msg.value_id) continue; - if (sr->isAnnounced(value_id, getType(a.value->type), now)) { + if (sr->isAnnounced(msg.value_id, getType(a.value->type), now)) { a.callback(true, sr->getNodes()); a.callback = nullptr; } } } - } else if (tid.matches(TransPrefix::LISTEN, &ttid)) { + } else if (msg.tid.matches(TransPrefix::LISTEN, &ttid)) { DHT_DEBUG("Got reply to listen."); Search *sr = findSearch(ttid, from->sa_family); if (!sr) { DHT_DEBUG("Unknown search or announce!"); - newNode(id, from, fromlen, 1); + newNode(msg.id, from, fromlen, 1); } else { - auto n = newNode(id, from, fromlen, 2, (sockaddr*)&addr, addr_length); + auto n = newNode(msg.id, from, fromlen, 2, (sockaddr*)&msg.addr.first, msg.addr.second); for (auto& sn : sr->nodes) if (sn.node == n) { sn.listenStatus.reply_time = now; @@ -2419,73 +2396,72 @@ Dht::processMessage(const uint8_t *buf, size_t buflen, const sockaddr *from, soc } break; case MessageType::Ping: - //DHT_DEBUG("Got ping (%d)!", tid.length); - newNode(id, from, fromlen, 1); + newNode(msg.id, from, fromlen, 1); //DHT_DEBUG("Sending pong."); - sendPong(from, fromlen, tid); + sendPong(from, fromlen, msg.tid); break; case MessageType::FindNode: DHT_DEBUG("Got \"find node\" request"); - newNode(id, from, fromlen, 1); - DHT_DEBUG("Sending closest nodes (%d).", want); - sendClosestNodes(from, fromlen, tid, target, want); + newNode(msg.id, from, fromlen, 1); + DHT_DEBUG("Sending closest nodes (%d).", msg.want); + sendClosestNodes(from, fromlen, msg.tid, msg.target, msg.want); break; case MessageType::GetValues: DHT_DEBUG("Got \"get values\" request"); - newNode(id, from, fromlen, 1); - if (info_hash == zeroes) { - DHT_WARN("Eek! Got get_values with no info_hash from %s %s.", id.toString().c_str(), print_addr(from, fromlen).c_str()); - sendError(from, fromlen, tid, 203, "Get_values with no info_hash"); + newNode(msg.id, from, fromlen, 1); + if (msg.info_hash == zeroes) { + DHT_WARN("Eek! Got get_values with no info_hash from %s %s.", msg.id.toString().c_str(), print_addr(from, fromlen).c_str()); + sendError(from, fromlen, msg.tid, 203, "Get_values with no info_hash"); break; } else { - Storage* st = findStorage(info_hash); + Storage* st = findStorage(msg.info_hash); Blob ntoken = makeToken(from, false); if (st && st->values.size() > 0) { DHT_DEBUG("Sending found%s values.", from->sa_family == AF_INET6 ? " IPv6" : ""); - sendClosestNodes(from, fromlen, tid, info_hash, want, ntoken, st->values); + sendClosestNodes(from, fromlen, msg.tid, msg.info_hash, msg.want, ntoken, st->values); } else { DHT_DEBUG("Sending nodes for get_values."); - sendClosestNodes(from, fromlen, tid, info_hash, want, ntoken); + sendClosestNodes(from, fromlen, msg.tid, msg.info_hash, msg.want, ntoken); } } break; case MessageType::AnnounceValue: DHT_DEBUG("Got \"announce value\" request!"); - newNode(id, from, fromlen, 1); - if (info_hash == zeroes) { + newNode(msg.id, from, fromlen, 1); + if (msg.info_hash == zeroes) { DHT_WARN("Announce_value with no info_hash."); - sendError(from, fromlen, tid, 203, "Announce_value with no info_hash"); + sendError(from, fromlen, msg.tid, 203, "Announce_value with no info_hash"); break; } - if (!tokenMatch(token, from)) { - DHT_WARN("Incorrect token %s for announce_values.", to_hex(token.data(), token.size()).c_str()); - sendError(from, fromlen, tid, 401, "Announce_value with wrong token", true); + if (!tokenMatch(msg.token, from)) { + DHT_WARN("Incorrect token %s for announce_values.", to_hex(msg.token.data(), msg.token.size()).c_str()); + sendError(from, fromlen, msg.tid, 401, "Announce_value with wrong token", true); break; } - for (const auto& v : values) { + for (const auto& v : msg.values) { if (v->id == Value::INVALID_ID) { DHT_WARN("Incorrect value id "); - sendError(from, fromlen, tid, 203, "Announce_value with invalid id"); + sendError(from, fromlen, msg.tid, 203, "Announce_value with invalid id"); continue; } - auto lv = getLocalById(info_hash, v->id); + auto lv = getLocalById(msg.info_hash, v->id); std::shared_ptr<Value> vc = v; if (lv) { const auto& type = getType(lv->type); - if (type.editPolicy(info_hash, lv, vc, id, from, fromlen)) { - DHT_DEBUG("Editing value of type %s belonging to %s at %s.", type.name.c_str(), v->owner.getId().toString().c_str(), info_hash.toString().c_str()); - storageStore(info_hash, vc); + if (type.editPolicy(msg.info_hash, lv, vc, msg.id, from, fromlen)) { + DHT_DEBUG("Editing value of type %s belonging to %s at %s.", type.name.c_str(), v->owner.getId().toString().c_str(), msg.info_hash.toString().c_str()); + storageStore(msg.info_hash, vc); } else { - DHT_WARN("Rejecting edition of type %s belonging to %s at %s because of storage policy.", type.name.c_str(), v->owner.getId().toString().c_str(), info_hash.toString().c_str()); + DHT_WARN("Rejecting edition of type %s belonging to %s at %s because of storage policy.", type.name.c_str(), v->owner.getId().toString().c_str(), msg.info_hash.toString().c_str()); } } else { // Allow the value to be edited by the storage policy const auto& type = getType(vc->type); - if (type.storePolicy(info_hash, vc, id, from, fromlen)) { - DHT_DEBUG("Storing value of type %s belonging to %s at %s.", type.name.c_str(), v->owner.getId().toString().c_str(), info_hash.toString().c_str()); - storageStore(info_hash, vc); + if (type.storePolicy(msg.info_hash, vc, msg.id, from, fromlen)) { + DHT_DEBUG("Storing value of type %s belonging to %s at %s.", type.name.c_str(), v->owner.getId().toString().c_str(), msg.info_hash.toString().c_str()); + storageStore(msg.info_hash, vc); } else { - DHT_WARN("Rejecting storage of type %s belonging to %s at %s because of storage policy.", type.name.c_str(), v->owner.getId().toString().c_str(), info_hash.toString().c_str()); + DHT_WARN("Rejecting storage of type %s belonging to %s at %s because of storage policy.", type.name.c_str(), v->owner.getId().toString().c_str(), msg.info_hash.toString().c_str()); } } @@ -2493,26 +2469,26 @@ Dht::processMessage(const uint8_t *buf, size_t buflen, const sockaddr *from, soc This is to prevent them from backtracking, and hence polluting the DHT. */ DHT_DEBUG("Sending announceValue confirmation."); - sendValueAnnounced(from, fromlen, tid, v->id); + sendValueAnnounced(from, fromlen, msg.tid, v->id); } break; case MessageType::Listen: - if (info_hash == zeroes) { + if (msg.info_hash == zeroes) { DHT_WARN("Listen with no info_hash."); - sendError(from, fromlen, tid, 203, "Listen with no info_hash"); + sendError(from, fromlen, msg.tid, 203, "Listen with no info_hash"); break; } - if (!tokenMatch(token, from)) { - DHT_WARN("Incorrect token %s for announce_values.", to_hex(token.data(), token.size()).c_str()); - sendError(from, fromlen, tid, 401, "Listen with wrong token", true); + if (!tokenMatch(msg.token, from)) { + DHT_WARN("Incorrect token %s for announce_values.", to_hex(msg.token.data(), msg.token.size()).c_str()); + sendError(from, fromlen, msg.tid, 401, "Listen with wrong token", true); break; } - if (!tid.matches(TransPrefix::LISTEN, &ttid)) { + if (!msg.tid.matches(TransPrefix::LISTEN, &ttid)) { break; } - newNode(id, from, fromlen, 1); - storageAddListener(info_hash, id, from, fromlen, ttid); - sendListenConfirmation(from, fromlen, tid); + newNode(msg.id, from, fromlen, 1); + storageAddListener(msg.info_hash, msg.id, from, fromlen, ttid); + sendListenConfirmation(from, fromlen, msg.tid); break; } } @@ -2586,11 +2562,16 @@ Dht::exportValues() const for (const auto& h : store) { ValuesExport ve; ve.first = h.id; - serialize<uint16_t>(h.values.size(), ve.second); + + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack_array(h.values.size()); for (const auto& v : h.values) { - serialize<time_point>(v.time, ve.second); - v.data->pack(ve.second); + pk.pack_array(2); + pk.pack(v.time.time_since_epoch().count()); + v.data->msgpack_pack(pk); } + ve.second = {buffer.data(), buffer.data()+buffer.size()}; e.push_back(std::move(ve)); } return e; @@ -2602,16 +2583,22 @@ Dht::importValues(const std::vector<ValuesExport>& import) for (const auto& h : import) { if (h.second.empty()) continue; - auto b = h.second.begin(), - e = h.second.end(); + try { - const size_t n_vals = deserialize<uint16_t>(b, e); - for (unsigned i = 0; i < n_vals; i++) { + msgpack::unpacked msg; + msgpack::unpack(&msg, (const char*)h.second.data(), h.second.size()); + auto valarr = msg.get(); + if (valarr.type != msgpack::type::ARRAY) + throw msgpack::type_error(); + for (unsigned i = 0; i < valarr.via.array.size; i++) { + auto& valel = valarr.via.array.ptr[i]; + if (valel.via.array.size < 2) + throw msgpack::type_error(); time_point val_time; Value tmp_val; try { - val_time = deserialize<time_point>(b, e); - tmp_val.unpack(b, e); + val_time = time_point{time_point::duration{valel.via.array.ptr[0].as<time_point::duration::rep>()}}; + tmp_val.msgpack_unpack(valel.via.array.ptr[1]); } catch (const std::exception&) { DHT_ERROR("Error reading value at %s", h.first.toString().c_str()); continue; @@ -2679,24 +2666,22 @@ Dht::pingNode(const sockaddr *sa, socklen_t salen) return sendPing(sa, salen, TransId {TransPrefix::PING}); } -/* We could use a proper bencoding printer and parser, but the format of - DHT messages is fairly stylised, so this seemed simpler. */ - -#define CHECK(offset, delta, size) \ - if (offset + delta > size) throw std::length_error("Provided buffer is not large enough."); - -#define INC(offset, delta, size) \ - if (delta < 0) throw std::length_error("Provided buffer is not large enough."); \ - CHECK(offset, (size_t)delta, size); \ - offset += delta - -#define COPY(buf, offset, src, delta, size) \ - CHECK(offset, delta, size); \ - memcpy(buf + offset, src, delta); \ - offset += delta; +void +insertAddr(msgpack::packer<msgpack::sbuffer>& pk, const sockaddr *sa, socklen_t) +{ + size_t addr_len = (sa->sa_family == AF_INET) ? sizeof(in_addr) : sizeof(in6_addr); + void* addr_ptr = (sa->sa_family == AF_INET) ? (void*)&((sockaddr_in*)sa)->sin_addr + : (void*)&((sockaddr_in6*)sa)->sin6_addr; + pk.pack("sa"); + pk.pack_bin(addr_len); + pk.pack_bin_body((char*)addr_ptr, addr_len); +} -#define ADD_V(buf, offset, size) \ - COPY(buf, offset, my_v, sizeof(my_v), size); +void +insertV(msgpack::packer<msgpack::sbuffer>& pk) +{ + pk.pack("v"); pk.pack(my_v); +} int Dht::send(const char *buf, size_t len, int flags, const sockaddr *sa, socklen_t salen) @@ -2725,67 +2710,66 @@ Dht::send(const char *buf, size_t len, int flags, const sockaddr *sa, socklen_t int Dht::sendPing(const sockaddr *sa, socklen_t salen, TransId tid) { - char buf[512]; - int i = 0, rc; - rc = snprintf(buf + i, 512 - i, "d1:ad2:id20:"); INC(i, rc, 512); - COPY(buf, i, myid.data(), myid.size(), 512); - rc = snprintf(buf + i, 512 - i, "e1:q4:ping1:t%d:", tid.length); - INC(i, rc, 512); - COPY(buf, i, tid.data(), tid.length, 512); - ADD_V(buf, i, 512); - rc = snprintf(buf + i, 512 - i, "1:y1:qe"); INC(i, rc, 512); - return send(buf, i, 0, sa, salen); -} + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack_map(5); -void -insertAddr(char* buf, size_t buflen, size_t& p, const sockaddr *sa, socklen_t) -{ - size_t addr_len = (sa->sa_family == AF_INET) ? sizeof(in_addr) : sizeof(in6_addr); - void* addr_ptr = (sa->sa_family == AF_INET) ? (void*)&((sockaddr_in*)sa)->sin_addr - : (void*)&((sockaddr_in6*)sa)->sin6_addr; - int rc = snprintf(buf + p, buflen - p, "2:sa%lu:", addr_len); - INC(p, rc, buflen); - COPY(buf, p, addr_ptr, addr_len, buflen); + pk.pack(std::string("a")); pk.pack_map(1); + pk.pack(std::string("id")); pk.pack(myid); + + pk.pack(std::string("q")); pk.pack(std::string("ping")); + pk.pack(std::string("t")); pk.pack_bin(tid.size()); + pk.pack_bin_body((const char*)tid.data(), tid.size()); + pk.pack(std::string("y")); pk.pack(std::string("q")); + pk.pack(std::string("v")); pk.pack(std::string("RNG1")); + + return send(buffer.data(), buffer.size(), 0, sa, salen); } int Dht::sendPong(const sockaddr *sa, socklen_t salen, TransId tid) { - char buf[512]; - size_t i = 0; - auto rc = snprintf(buf + i, 512 - i, "d1:rd2:id20:"); INC(i, rc, 512); - COPY(buf, i, myid.data(), myid.size(), 512); - insertAddr(buf, 512, i, sa, salen); - rc = snprintf(buf + i, 512 - i, "e1:t%d:", tid.length); INC(i, rc, 512); - COPY(buf, i, tid.data(), tid.length, 512); - ADD_V(buf, i, 512); - rc = snprintf(buf + i, 512 - i, "1:y1:re"); INC(i, rc, 512); - return send(buf, i, 0, sa, salen); + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack_map(4); + + pk.pack(std::string("r")); pk.pack_map(2); + pk.pack(std::string("id")); pk.pack(myid); + insertAddr(pk, sa, salen); + + pk.pack(std::string("t")); pk.pack_bin(tid.size()); + pk.pack_bin_body((const char*)tid.data(), tid.size()); + pk.pack(std::string("y")); pk.pack(std::string("r")); + pk.pack(std::string("v")); pk.pack(std::string("RNG1")); + + return send(buffer.data(), buffer.size(), 0, sa, salen); } int Dht::sendFindNode(const sockaddr *sa, socklen_t salen, TransId tid, const InfoHash& target, want_t want, int confirm) { - constexpr const size_t BUF_SZ = 512; - char buf[BUF_SZ]; - int i = 0, rc; - rc = snprintf(buf + i, BUF_SZ - i, "d1:ad2:id20:"); INC(i, rc, BUF_SZ); - COPY(buf, i, myid.data(), myid.size(), BUF_SZ); - rc = snprintf(buf + i, BUF_SZ - i, "6:target20:"); INC(i, rc, BUF_SZ); - COPY(buf, i, target.data(), target.size(), BUF_SZ); + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack_map(5); + + pk.pack(std::string("a")); pk.pack_map(2 + (want>0?1:0)); + pk.pack(std::string("id")); pk.pack(myid); + pk.pack(std::string("target")); pk.pack(target); if (want > 0) { - rc = snprintf(buf + i, BUF_SZ - i, "4:wantl%s%se", - (want & WANT4) ? "2:n4" : "", - (want & WANT6) ? "2:n6" : ""); - INC(i, rc, BUF_SZ); + pk.pack(std::string("w")); + pk.pack_array(((want & WANT4)?1:0) + ((want & WANT6)?1:0)); + if (want & WANT4) pk.pack(AF_INET); + if (want & WANT6) pk.pack(AF_INET6); } - rc = snprintf(buf + i, BUF_SZ - i, "e1:q9:find_node1:t%d:", tid.length); - INC(i, rc, BUF_SZ); - COPY(buf, i, tid.data(), tid.length, BUF_SZ); - ADD_V(buf, i, BUF_SZ); - rc = snprintf(buf + i, BUF_SZ - i, "1:y1:qe"); INC(i, rc, BUF_SZ); - return send(buf, i, confirm ? MSG_CONFIRM : 0, sa, salen); + + pk.pack(std::string("q")); pk.pack(std::string("find")); + pk.pack(std::string("t")); pk.pack_bin(tid.size()); + pk.pack_bin_body((const char*)tid.data(), tid.size()); + pk.pack(std::string("y")); pk.pack(std::string("q")); + pk.pack(std::string("v")); pk.pack(std::string("RNG1")); + + return send(buffer.data(), buffer.size(), confirm ? 0 : MSG_CONFIRM, sa, salen); } int @@ -2794,56 +2778,51 @@ Dht::sendNodesValues(const sockaddr *sa, socklen_t salen, TransId tid, const uint8_t *nodes6, unsigned nodes6_len, const std::vector<ValueStorage>& st, const Blob& token) { - constexpr const size_t BUF_SZ = 2048 * 64; - char buf[BUF_SZ]; - size_t i = 0; + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack_map(4); - auto rc = snprintf(buf + i, BUF_SZ - i, "d1:rd2:id20:"); INC(i, rc, BUF_SZ); - COPY(buf, i, myid.data(), myid.size(), BUF_SZ); - insertAddr(buf, BUF_SZ, i, sa, salen); + pk.pack(std::string("r")); + pk.pack_map(2 + (not st.empty()?1:0) + (nodes_len>0?1:0) + (nodes6_len>0?1:0) + (not token.empty()?1:0)); + pk.pack(std::string("id")); pk.pack(myid); + insertAddr(pk, sa, salen); if (nodes_len > 0) { - rc = snprintf(buf + i, BUF_SZ - i, "5:nodes%u:", nodes_len); - INC(i, rc, BUF_SZ); - COPY(buf, i, nodes, nodes_len, BUF_SZ); + pk.pack(std::string("n4")); + pk.pack_bin(nodes_len); + pk.pack_bin_body((const char*)nodes, nodes_len); } if (nodes6_len > 0) { - rc = snprintf(buf + i, BUF_SZ - i, "6:nodes6%u:", nodes6_len); - INC(i, rc, BUF_SZ); - COPY(buf, i, nodes6, nodes6_len, BUF_SZ); + pk.pack(std::string("n6")); + pk.pack_bin(nodes6_len); + pk.pack_bin_body((const char*)nodes6, nodes6_len); } if (not token.empty()) { - rc = snprintf(buf + i, BUF_SZ - i, "5:token%lu:", token.size()); - INC(i, rc, BUF_SZ); - COPY(buf, i, token.data(), token.size(), BUF_SZ); + pk.pack(std::string("token")); pk.pack(token); } - - if (st.size() > 0) { - /* We treat the storage as a circular list, and serve a randomly - chosen slice. In order to make sure we fit, - we limit ourselves to 50 values. */ + if (not st.empty()) { + // We treat the storage as a circular list, and serve a randomly + // chosen slice. In order to make sure we fit, + // we limit ourselves to 50 values. std::uniform_int_distribution<> pos_dis(0, st.size()-1); unsigned j0 = pos_dis(rd); unsigned j = j0; unsigned k = 0; - rc = snprintf(buf + i, BUF_SZ - i, "6:valuesl"); INC(i, rc, BUF_SZ); + pk.pack(std::string("values")); + pk.pack_array(std::min<size_t>(st.size(), 50)); do { - Blob packed_value; - st[j].data->pack(packed_value); - rc = snprintf(buf + i, BUF_SZ - i, "%lu:", packed_value.size()); INC(i, rc, BUF_SZ); - COPY(buf, i, packed_value.data(), packed_value.size(), BUF_SZ); + pk.pack(*st[j].data); k++; j = (j + 1) % st.size(); } while (j != j0 && k < 50); - rc = snprintf(buf + i, BUF_SZ - i, "e"); INC(i, rc, BUF_SZ); } - rc = snprintf(buf + i, BUF_SZ - i, "e1:t%d:", tid.length); INC(i, rc, BUF_SZ); - COPY(buf, i, tid.data(), tid.length, BUF_SZ); - ADD_V(buf, i, BUF_SZ); - rc = snprintf(buf + i, BUF_SZ - i, "1:y1:re"); INC(i, rc, BUF_SZ); + pk.pack(std::string("t")); pk.pack_bin(tid.size()); + pk.pack_bin_body((const char*)tid.data(), tid.size()); + pk.pack(std::string("y")); pk.pack(std::string("r")); + pk.pack(std::string("v")); pk.pack(std::string("RNG1")); - return send(buf, i, 0, sa, salen); + return send(buffer.data(), buffer.size(), 0, sa, salen); } unsigned @@ -2951,69 +2930,68 @@ Dht::sendGetValues(const sockaddr *sa, socklen_t salen, TransId tid, const InfoHash& infohash, want_t want, int confirm) { - static constexpr const size_t BUF_SZ = 2048 * 4; - char buf[BUF_SZ]; - size_t i = 0; - int rc; + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack_map(5); - rc = snprintf(buf + i, BUF_SZ - i, "d1:ad2:id20:"); INC(i, rc, BUF_SZ); - COPY(buf, i, myid.data(), myid.size(), BUF_SZ); - rc = snprintf(buf + i, BUF_SZ - i, "9:info_hash20:"); INC(i, rc, BUF_SZ); - COPY(buf, i, infohash.data(), infohash.size(), BUF_SZ); + pk.pack(std::string("a")); pk.pack_map(2 + (want>0?1:0)); + pk.pack(std::string("id")); pk.pack(myid); + pk.pack(std::string("h")); pk.pack(infohash); if (want > 0) { - rc = snprintf(buf + i, BUF_SZ - i, "4:wantl%s%se", - (want & WANT4) ? "2:n4" : "", - (want & WANT6) ? "2:n6" : ""); - INC(i, rc, BUF_SZ); + pk.pack(std::string("w")); + pk.pack_array(((want & WANT4)?1:0) + ((want & WANT6)?1:0)); + if (want & WANT4) pk.pack(AF_INET); + if (want & WANT6) pk.pack(AF_INET6); } - rc = snprintf(buf + i, BUF_SZ - i, "e1:q9:get_peers1:t%d:", tid.length); - INC(i, rc, BUF_SZ); - COPY(buf, i, tid.data(), tid.length, BUF_SZ); - ADD_V(buf, i, BUF_SZ); - rc = snprintf(buf + i, BUF_SZ - i, "1:y1:qe"); INC(i, rc, BUF_SZ); - return send(buf, i, confirm ? MSG_CONFIRM : 0, sa, salen); + + pk.pack(std::string("q")); pk.pack(std::string("get")); + pk.pack(std::string("t")); pk.pack_bin(tid.size()); + pk.pack_bin_body((const char*)tid.data(), tid.size()); + pk.pack(std::string("y")); pk.pack(std::string("q")); + pk.pack(std::string("v")); pk.pack(std::string("RNG1")); + + return send(buffer.data(), buffer.size(), confirm ? 0 : MSG_CONFIRM, sa, salen); } int Dht::sendListen(const sockaddr* sa, socklen_t salen, TransId tid, const InfoHash& infohash, const Blob& token, int confirm) { - static constexpr const size_t BUF_SZ = 2048; - char buf[BUF_SZ]; - size_t i = 0; - int rc; + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack_map(5); - rc = snprintf(buf + i, BUF_SZ - i, "d1:ad2:id%lu:", myid.size()); INC(i, rc, BUF_SZ); - COPY(buf, i, myid.data(), myid.size(), BUF_SZ); - rc = snprintf(buf + i, BUF_SZ - i, "9:info_hash%lu:", infohash.size()); INC(i, rc, BUF_SZ); - COPY(buf, i, infohash.data(), infohash.size(), BUF_SZ); + pk.pack(std::string("a")); pk.pack_map(3); + pk.pack(std::string("id")); pk.pack(myid); + pk.pack(std::string("h")); pk.pack(infohash); + pk.pack(std::string("token")); pk.pack(token); - rc = snprintf(buf + i, BUF_SZ - i, "e5:token%lu:", token.size()); INC(i, rc, BUF_SZ); - COPY(buf, i, token.data(), token.size(), BUF_SZ); - rc = snprintf(buf + i, BUF_SZ - i, "e1:q6:listen1:t%u:", tid.length); INC(i, rc, BUF_SZ); - COPY(buf, i, tid.data(), tid.length, BUF_SZ); - ADD_V(buf, i, BUF_SZ); - rc = snprintf(buf + i, BUF_SZ - i, "1:y1:qe"); INC(i, rc, BUF_SZ); + pk.pack(std::string("q")); pk.pack(std::string("listen")); + pk.pack(std::string("t")); pk.pack_bin(tid.size()); + pk.pack_bin_body((const char*)tid.data(), tid.size()); + pk.pack(std::string("y")); pk.pack(std::string("q")); + pk.pack(std::string("v")); pk.pack(std::string("RNG1")); - return send(buf, i, confirm ? 0 : MSG_CONFIRM, sa, salen); + return send(buffer.data(), buffer.size(), confirm ? 0 : MSG_CONFIRM, sa, salen); } int Dht::sendListenConfirmation(const sockaddr* sa, socklen_t salen, TransId tid) { - static constexpr const size_t BUF_SZ = 512; - char buf[BUF_SZ]; - size_t i = 0; + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack_map(4); + + pk.pack(std::string("r")); pk.pack_map(2); + pk.pack(std::string("id")); pk.pack(myid); + insertAddr(pk, sa, salen); - auto rc = snprintf(buf + i, BUF_SZ - i, "d1:rd2:id20:"); INC(i, rc, BUF_SZ); - COPY(buf, i, myid.data(), myid.size(), BUF_SZ); - insertAddr(buf, BUF_SZ, i, sa, salen); + pk.pack(std::string("t")); pk.pack_bin(tid.size()); + pk.pack_bin_body((const char*)tid.data(), tid.size()); + pk.pack(std::string("y")); pk.pack(std::string("r")); + pk.pack(std::string("v")); pk.pack(std::string("RNG1")); - rc = snprintf(buf + i, BUF_SZ - i, "e1:t%u:", tid.length); INC(i, rc, BUF_SZ); - COPY(buf, i, tid.data(), tid.length, BUF_SZ); - ADD_V(buf, i, BUF_SZ); - rc = snprintf(buf + i, BUF_SZ - i, "1:y1:re"); INC(i, rc, BUF_SZ); - return send(buf, i, 0, sa, salen); + return send(buffer.data(), buffer.size(), 0, sa, salen); } int @@ -3021,292 +2999,210 @@ Dht::sendAnnounceValue(const sockaddr *sa, socklen_t salen, TransId tid, const InfoHash& infohash, const Value& value, const Blob& token, int confirm) { - constexpr const size_t BUF_SZ = 2048 * 4; - char buf[BUF_SZ]; - size_t i = 0; + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack_map(5); - int rc = snprintf(buf + i, BUF_SZ - i, "d1:ad2:id%lu:", myid.size()); INC(i, rc, BUF_SZ); - COPY(buf, i, myid.data(), myid.size(), BUF_SZ); - rc = snprintf(buf + i, BUF_SZ - i, "9:info_hash%lu:", infohash.size()); INC(i, rc, BUF_SZ); - COPY(buf, i, infohash.data(), infohash.size(), BUF_SZ); + pk.pack(std::string("a")); pk.pack_map(4); + pk.pack(std::string("id")); pk.pack(myid); + pk.pack(std::string("h")); pk.pack(infohash); + pk.pack(std::string("values")); pk.pack_array(1); pk.pack(value); + pk.pack(std::string("token")); pk.pack(token); - Blob packed_value; - value.pack(packed_value); - rc = snprintf(buf + i, BUF_SZ - i, "6:valuesl%lu:", packed_value.size()); INC(i, rc, BUF_SZ); - COPY(buf, i, packed_value.data(), packed_value.size(), BUF_SZ); - rc = snprintf(buf + i, BUF_SZ - i, "e5:token%lu:", token.size()); INC(i, rc, BUF_SZ); - COPY(buf, i, token.data(), token.size(), BUF_SZ); - rc = snprintf(buf + i, BUF_SZ - i, "e1:q13:announce_peer1:t%u:", tid.length); INC(i, rc, BUF_SZ); - COPY(buf, i, tid.data(), tid.length, BUF_SZ); - ADD_V(buf, i, BUF_SZ); - rc = snprintf(buf + i, BUF_SZ - i, "1:y1:qe"); INC(i, rc, BUF_SZ); + pk.pack(std::string("q")); pk.pack(std::string("put")); + pk.pack(std::string("t")); pk.pack_bin(tid.size()); + pk.pack_bin_body((const char*)tid.data(), tid.size()); + pk.pack(std::string("y")); pk.pack(std::string("q")); + pk.pack(std::string("v")); pk.pack(std::string("RNG1")); - return send(buf, i, confirm ? 0 : MSG_CONFIRM, sa, salen); + return send(buffer.data(), buffer.size(), confirm ? 0 : MSG_CONFIRM, sa, salen); } int Dht::sendValueAnnounced(const sockaddr *sa, socklen_t salen, TransId tid, Value::Id vid) { - char buf[512]; - size_t i = 0; + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack_map(4); + + pk.pack(std::string("r")); pk.pack_map(3); + pk.pack(std::string("id")); pk.pack(myid); + pk.pack(std::string("vid")); pk.pack(vid); + insertAddr(pk, sa, salen); - auto rc = snprintf(buf + i, 512 - i, "d1:rd2:id20:"); INC(i, rc, 512); - COPY(buf, i, myid.data(), myid.size(), 512); - insertAddr(buf, 512, i, sa, salen); + pk.pack(std::string("t")); pk.pack_bin(tid.size()); + pk.pack_bin_body((const char*)tid.data(), tid.size()); + pk.pack(std::string("y")); pk.pack(std::string("r")); + pk.pack(std::string("v")); pk.pack(std::string("RNG1")); - rc = snprintf(buf + i, 512 - i, "3:vid%lu:", sizeof(Value::Id)); INC(i, rc, 512); - COPY(buf, i, &vid, sizeof(Value::Id), 512); - rc = snprintf(buf + i, 512 - i, "e1:t%u:", tid.length); INC(i, rc, 512); - COPY(buf, i, tid.data(), tid.length, 512); - ADD_V(buf, i, 512); - rc = snprintf(buf + i, 512 - i, "1:y1:re"); INC(i, rc, 512); - return send(buf, i, 0, sa, salen); + return send(buffer.data(), buffer.size(), 0, sa, salen); } int Dht::sendError(const sockaddr *sa, socklen_t salen, TransId tid, uint16_t code, const char *message, bool include_id) { - constexpr const size_t BUF_SZ = 512; - char buf[BUF_SZ]; - int i = 0, rc; - - size_t msg_len = strlen(message); - rc = snprintf(buf + i, BUF_SZ - i, "d1:eli%ue%lu:", code, msg_len); - INC(i, rc, BUF_SZ); - COPY(buf, i, message, msg_len, BUF_SZ); - rc = snprintf(buf + i, BUF_SZ - i, "e1:t%d:", tid.length); INC(i, rc, BUF_SZ); - COPY(buf, i, tid.data(), tid.length, BUF_SZ); - ADD_V(buf, i, BUF_SZ); + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack_map(4 + (include_id?1:0)); + + pk.pack(std::string("e")); pk.pack_array(2); + pk.pack(code); + pk.pack_str(strlen(message)); + pk.pack_str_body(message, strlen(message)); + if (include_id) { - rc = snprintf(buf + i, BUF_SZ - i, "1:rd2:id20:"); INC(i, rc, BUF_SZ); - COPY(buf, i, myid.data(), myid.size(), BUF_SZ); - COPY(buf, i, "e", 1u, BUF_SZ); - } - rc = snprintf(buf + i, BUF_SZ - i, "1:y1:ee"); INC(i, rc, BUF_SZ); - return send(buf, i, 0, sa, salen); -} - -#undef CHECK -#undef INC -#undef COPY -#undef ADD_V - -Dht::MessageType -Dht::parseMessage(const uint8_t *buf, size_t buflen, - TransId& tid_return, - InfoHash& id_return, InfoHash& info_hash_return, - InfoHash& target_return, in_port_t& port_return, - Blob& token, Value::Id& value_id, - uint8_t *nodes_return, unsigned *nodes_len, - uint8_t *nodes6_return, unsigned *nodes6_len, - std::vector<std::shared_ptr<Value>>& values_return, - want_t* want_return, uint16_t& error_code, bool& ring, - sockaddr* addr_return, socklen_t& addr_length_return) -{ - const uint8_t *p; - - /* This code will happily crash if the buffer is not NUL-terminated. */ - if (buf[buflen] != '\0') - throw DhtException("Eek! parse_message with unterminated buffer."); - -#define CHECK(ptr, len) if (((uint8_t*)ptr) + (len) > (buf) + (buflen)) throw std::out_of_range("Truncated message."); - - p = (uint8_t*)dht_memmem(buf, buflen, "1:t", 3); - if (p) { - char *q; - size_t l = strtoul((char*)p + 3, &q, 10); - if (q && *q == ':') { - CHECK(q + 1, l); - tid_return = {q+1, l}; - } else - tid_return.length = 0; - } - - p = (uint8_t*)dht_memmem(buf, buflen, "2:id20:", 7); - if (p) { - CHECK(p + 7, HASH_LEN); - memcpy(id_return.data(), p + 7, HASH_LEN); - } else { - id_return = {}; - } - - if (addr_return and addr_length_return) { - p = (uint8_t*)dht_memmem(buf, buflen, "2:sa", 4); - if (p) { - char *q; - size_t l = strtoul((char*)p + 4, &q, 10); - if (q && *q == ':' && (l == sizeof(in_addr) or l == sizeof(in6_addr))) { - CHECK(q + 1, l); - if (l == sizeof(in_addr)) { - auto addr = (sockaddr_in*)addr_return; - std::fill_n((uint8_t*)addr, sizeof(sockaddr_in), 0); - addr->sin_family = AF_INET; - addr->sin_port = 0; - memcpy(&addr->sin_addr, q+1, l); - addr_length_return = sizeof(sockaddr_in); - } else if (l == sizeof(in6_addr)) { - auto addr = (sockaddr_in6*)addr_return; - std::fill_n((uint8_t*)addr, sizeof(sockaddr_in6), 0); - addr_return->sa_family = AF_INET6; - addr->sin6_port = 0; - memcpy(&addr->sin6_addr, q+1, l); - addr_length_return = sizeof(sockaddr_in6); - } - } else - addr_length_return = 0; - } else - addr_length_return = 0; + pk.pack(std::string("r")); pk.pack_map(1); + pk.pack(std::string("id")); pk.pack(myid); } - p = (uint8_t*)dht_memmem(buf, buflen, "9:info_hash20:", 14); - if (p) { - CHECK(p + 14, HASH_LEN); - memcpy(info_hash_return.data(), p + 14, HASH_LEN); - } else { - info_hash_return = {}; + pk.pack(std::string("t")); pk.pack_bin(tid.size()); + pk.pack_bin_body((const char*)tid.data(), tid.size()); + pk.pack(std::string("y")); pk.pack(std::string("e")); + pk.pack(std::string("v")); pk.pack(std::string("RNG1")); + + return send(buffer.data(), buffer.size(), 0, sa, salen); +} + +msgpack::object* +findMapValue(msgpack::object& map, const std::string& key) { + if (map.type != msgpack::type::MAP) throw msgpack::type_error(); + for (unsigned i = 0; i < map.via.map.size; i++) { + auto& o = map.via.map.ptr[i]; + if(o.key.type != msgpack::type::STR) + continue; + if (o.key.as<std::string>() == key) { + return &o.val; + } } + return nullptr; +} - p = (uint8_t*)dht_memmem(buf, buflen, "porti", 5); - if (p) { - char *q; - unsigned long l = strtoul((char*)p + 5, &q, 10); - if (q && *q == 'e' && l < 0x10000) - port_return = l; - else - port_return = 0; - } else - port_return = 0; +void +Dht::ParsedMessage::msgpack_unpack(msgpack::object msg) +{ + auto y = findMapValue(msg, "y"); + auto a = findMapValue(msg, "a"); + auto r = findMapValue(msg, "r"); + auto e = findMapValue(msg, "e"); - p = (uint8_t*)dht_memmem(buf, buflen, "6:target20:", 11); - if (p) { - CHECK(p + 11, HASH_LEN); - memcpy(target_return.data(), p + 11, HASH_LEN); - } else { - target_return = {}; + std::string query; + if (auto q = findMapValue(msg, "q")) { + if (q->type != msgpack::type::STR) + throw msgpack::type_error(); + query = q->as<std::string>(); } - p = (uint8_t*)dht_memmem(buf, buflen, "5:token", 7); - if (p) { - char *q; - size_t l = strtoul((char*)p + 7, &q, 10); - if (q && *q == ':' && l > 0 && l <= 128) { - CHECK(q + 1, l); - token.clear(); - token.insert(token.begin(), q + 1, q + 1 + l); - } + auto& req = a ? *a : (r ? *r : *e); + if (not &req) + throw msgpack::type_error(); + + if (e) { + if (e->type != msgpack::type::ARRAY) + throw msgpack::type_error(); + error_code = e->via.array.ptr[0].as<uint16_t>(); } - if (nodes_len) { - p = (uint8_t*)dht_memmem(buf, buflen, "5:nodes", 7); - if (p) { - char *q; - size_t l = strtoul((char*)p + 7, &q, 10); - if (q && *q == ':' && l > 0 && l <= *nodes_len) { - CHECK(q + 1, l); - memcpy(nodes_return, q + 1, l); - *nodes_len = l; - } else - *nodes_len = 0; - } else - *nodes_len = 0; - } - - if (nodes6_len) { - p = (uint8_t*)dht_memmem(buf, buflen, "6:nodes6", 8); - if (p) { - char *q; - size_t l = strtoul((char*)p + 8, &q, 10); - if (q && *q == ':' && l > 0 && l <= *nodes6_len) { - CHECK(q + 1, l); - memcpy(nodes6_return, q + 1, l); - *nodes6_len = l; - } else - *nodes6_len = 0; - } else - *nodes6_len = 0; - } - - p = (uint8_t*)dht_memmem(buf, buflen, "6:valuesl", 9); - if (p) { - unsigned i = p - buf + 9; - while (true) { - char *q; - size_t l = strtoul((char*)buf + i, &q, 10); - if (q && *q == ':' && l > 0) { - CHECK(q + 1, l); - i = q + 1 + l - (char*)buf; - Value v; - v.unpackBlob(Blob {q + 1, q + 1 + l}); - values_return.push_back(std::make_shared<Value>(std::move(v))); - } else - break; - } - if (i >= buflen || buf[i] != 'e') - DHT_DEBUG("eek... unexpected end for values."); + if (auto rid = findMapValue(req, "id")) + id = {*rid}; + + if (auto rh = findMapValue(req, "h")) + info_hash = {*rh}; + + if (auto rtarget = findMapValue(req, "target")) + target = {*rtarget}; + + if (auto otoken = findMapValue(req, "token")) + token = otoken->as<Blob>(); + + if (auto vid = findMapValue(req, "vid")) + value_id = vid->as<Value::Id>(); + + if (auto rnodes4 = findMapValue(req, "n4")) { + auto n4b = rnodes4->as<std::vector<char>>(); + nodes4 = {n4b.begin(), n4b.end()}; } - p = (uint8_t*)dht_memmem(buf, buflen, "3:vid8:", 7); - if (p) { - CHECK(p + 7, sizeof(value_id)); - memcpy(&value_id, p + 7, sizeof(value_id)); - } else { - value_id = Value::INVALID_ID; - } - - if (want_return) { - p = (uint8_t*)dht_memmem(buf, buflen, "4:wantl", 7); - if (p) { - unsigned i = p - buf + 7; - *want_return = 0; - while (buf[i] > '0' && buf[i] <= '9' && buf[i + 1] == ':' && - i + 2 + buf[i] - '0' < buflen) { - CHECK(buf + i + 2, buf[i] - '0'); - if (buf[i] == '2' && memcmp(buf + i + 2, "n4", 2) == 0) - *want_return |= WANT4; - else if (buf[i] == '2' && memcmp(buf + i + 2, "n6", 2) == 0) - *want_return |= WANT6; - else - DHT_DEBUG("eek... unexpected want flag (%c)", buf[i]); - i += 2 + buf[i] - '0'; - } - if (i >= buflen || buf[i] != 'e') - DHT_DEBUG("eek... unexpected end for want."); - } else { - *want_return = -1; + if (auto rnodes6 = findMapValue(req, "n6")) { + auto n6b = rnodes6->as<std::vector<char>>(); + nodes6 = {n6b.begin(), n6b.end()}; + } + + if (auto sa = findMapValue(req, "sa")) { + if (sa->type != msgpack::type::BIN) + throw msgpack::type_error(); + auto l = sa->via.bin.size; + if (l == sizeof(in_addr)) { + auto a = (sockaddr_in*)&addr.first; + std::fill_n((uint8_t*)a, sizeof(sockaddr_in), 0); + a->sin_family = AF_INET; + a->sin_port = 0; + std::copy_n(sa->via.bin.ptr, l, (char*)&a->sin_addr); + addr.second = sizeof(sockaddr_in); + } else if (l == sizeof(in6_addr)) { + auto a = (sockaddr_in6*)&addr.first; + std::fill_n((uint8_t*)a, sizeof(sockaddr_in6), 0); + a->sin6_family = AF_INET6; + a->sin6_port = 0; + std::copy_n(sa->via.bin.ptr, l, (char*)&a->sin6_addr); + addr.second = sizeof(sockaddr_in6); } + } else + addr.second = 0; + + if (auto rvalues = findMapValue(req, "values")) { + if (rvalues->type != msgpack::type::ARRAY) + throw msgpack::type_error(); + for (size_t i = 0; i < rvalues->via.array.size; i++) + try { + values.emplace_back(std::make_shared<Value>(rvalues->via.array.ptr[i])); + } catch (const std::exception& e) { + //DHT_WARN("Error reading value: %s", e.what()); + std::cout << "Error reading value: " << e.what() << std::endl; + } } - p = (uint8_t*)dht_memmem(buf, buflen, "1:eli", 5); - if (p) { - char *q; - unsigned long l = strtoul((char*)p + 5, &q, 10); - if (q && *q == 'e' && l < 0x10000) - error_code = l; + if (auto w = findMapValue(req, "w")) { + if (w->type != msgpack::type::ARRAY) + throw msgpack::type_error(); + want = 0; + for (unsigned i=0; i<w->via.array.size; i++) { + auto& val = w->via.array.ptr[i]; + try { + auto w = val.as<sa_family_t>(); + if (w == AF_INET) + want |= WANT4; + else if(w == AF_INET6) + want |= WANT6; + } catch (const std::exception& e) {}; + } } else { - error_code = 0; - } - -#undef CHECK - - ring = dht_memmem(buf, buflen, my_v, sizeof(my_v)); - - if (dht_memmem(buf, buflen, "1:y1:r", 6)) - return MessageType::Reply; - if (dht_memmem(buf, buflen, "1:y1:e", 6)) - return MessageType::Error; - if (!dht_memmem(buf, buflen, "1:y1:q", 6)) - throw DhtException("Parse error"); - if (dht_memmem(buf, buflen, "1:q4:ping", 9)) - return MessageType::Ping; - if (dht_memmem(buf, buflen, "1:q9:find_node", 14)) - return MessageType::FindNode; - if (dht_memmem(buf, buflen, "1:q9:get_peers", 14)) - return MessageType::GetValues; - if (dht_memmem(buf, buflen, "1:q13:announce_peer", 19)) - return MessageType::AnnounceValue; - if (dht_memmem(buf, buflen, "1:q6:listen", 11)) - return MessageType::Listen; - throw DhtException("Can't read message type."); + want = -1; + } + + if (auto t = findMapValue(msg, "t")) + tid = {t->as<std::array<char, 4>>()}; + + if (auto rv = findMapValue(msg, "v")) + ua = rv->as<std::string>(); + + if (r) + type = MessageType::Reply; + else if (e) + type = MessageType::Error; + else if (y and y->as<std::string>() != "q") + throw msgpack::type_error(); + else if (query == "ping") + type = MessageType::Ping; + else if (query == "find") + type = MessageType::FindNode; + else if (query == "get") + type = MessageType::GetValues; + else if (query == "listen") + type = MessageType::Listen; + else if (query == "put") + type = MessageType::AnnounceValue; + else + throw msgpack::type_error(); } #ifdef HAVE_MEMMEM diff --git a/src/dhtrunner.cpp b/src/dhtrunner.cpp index 39865a72576320dd91277ee0e9437be5e072d81c..c757187b58acf6c46f3350a8f39a72581ab7b863 100644 --- a/src/dhtrunner.cpp +++ b/src/dhtrunner.cpp @@ -107,7 +107,9 @@ DhtRunner::run(const sockaddr_in* local4, const sockaddr_in6* local6, DhtRunner: } { std::lock_guard<std::mutex> lck(storage_mtx); - if (not pending_ops.empty()) + if (not pending_ops_prio.empty()) + return true; + if (not pending_ops.empty() and getStatus() >= Dht::Status::Connecting) return true; } return false; @@ -125,6 +127,11 @@ DhtRunner::join() dht_thread.join(); if (rcv_thread.joinable()) rcv_thread.join(); + { + std::lock_guard<std::mutex> lck(storage_mtx); + pending_ops = decltype(pending_ops)(); + pending_ops_prio = decltype(pending_ops_prio)(); + } { std::lock_guard<std::mutex> lck(dht_mtx); dht_.reset(); @@ -142,12 +149,22 @@ DhtRunner::loop_() decltype(pending_ops) ops {}; { std::lock_guard<std::mutex> lck(storage_mtx); - ops = std::move(pending_ops); + ops = std::move(pending_ops_prio); } while (not ops.empty()) { ops.front()(*dht_); ops.pop(); } + if (getStatus() >= Dht::Status::Connecting) { + { + std::lock_guard<std::mutex> lck(storage_mtx); + ops = std::move(pending_ops); + } + while (not ops.empty()) { + ops.front()(*dht_); + ops.pop(); + } + } time_point wakeup {}; { @@ -164,14 +181,13 @@ DhtRunner::loop_() } } - if (statusCb) { - Dht::Status nstatus4 = dht_->getStatus(AF_INET); - Dht::Status nstatus6 = dht_->getStatus(AF_INET6); - if (nstatus4 != status4 || nstatus6 != status6) { - status4 = nstatus4; - status6 = nstatus6; + Dht::Status nstatus4 = dht_->getStatus(AF_INET); + Dht::Status nstatus6 = dht_->getStatus(AF_INET6); + if (nstatus4 != status4 || nstatus6 != status6) { + status4 = nstatus4; + status6 = nstatus6; + if (statusCb) statusCb(status4, status6); - } } return wakeup; @@ -193,7 +209,7 @@ DhtRunner::doRun(const sockaddr_in* sin4, const sockaddr_in6* sin6, SecureDht::C } } -#if 0 +#if 1 if (sin6) { s6 = socket(PF_INET6, SOCK_DGRAM, 0); if(s6 >= 0) { @@ -241,13 +257,12 @@ DhtRunner::doRun(const sockaddr_in* sin4, const sockaddr_in6* sin6, SecureDht::C if(rc > 0) { fromlen = sizeof(from); if(s4 >= 0 && FD_ISSET(s4, &readfds)) - rc = recvfrom(s4, (char*)buf, sizeof(buf) - 1, 0, (struct sockaddr*)&from, &fromlen); + rc = recvfrom(s4, (char*)buf, sizeof(buf), 0, (struct sockaddr*)&from, &fromlen); else if(s6 >= 0 && FD_ISSET(s6, &readfds)) - rc = recvfrom(s6, (char*)buf, sizeof(buf) - 1, 0, (struct sockaddr*)&from, &fromlen); + rc = recvfrom(s6, (char*)buf, sizeof(buf), 0, (struct sockaddr*)&from, &fromlen); else break; if (rc > 0) { - buf[rc] = 0; { std::lock_guard<std::mutex> lck(sock_mtx); rcv.emplace_back(Blob {buf, buf+rc+1}, std::make_pair(from, fromlen)); @@ -343,7 +358,7 @@ DhtRunner::put(InfoHash hash, const std::shared_ptr<Value>& value, Dht::DoneCall } void -DhtRunner::put(const std::string& key, Value&& value, Dht::DoneCallback cb) +DhtRunner::put(const std::string& key, Value&& value, Dht::DoneCallbackSimple cb) { put(InfoHash::get(key), std::forward<Value>(value), cb); } @@ -370,7 +385,7 @@ DhtRunner::putSigned(InfoHash hash, Value&& value, Dht::DoneCallback cb) } void -DhtRunner::putSigned(const std::string& key, Value&& value, Dht::DoneCallback cb) +DhtRunner::putSigned(const std::string& key, Value&& value, Dht::DoneCallbackSimple cb) { putSigned(InfoHash::get(key), std::forward<Value>(value), cb); } @@ -427,7 +442,7 @@ void DhtRunner::bootstrap(const std::vector<std::pair<sockaddr_storage, socklen_t>>& nodes) { std::lock_guard<std::mutex> lck(storage_mtx); - pending_ops.emplace([=](SecureDht& dht) { + pending_ops_prio.emplace([=](SecureDht& dht) { for (auto& node : nodes) dht.pingNode((sockaddr*)&node.first, node.second); }); @@ -438,7 +453,7 @@ void DhtRunner::bootstrap(const std::vector<NodeExport>& nodes) { std::lock_guard<std::mutex> lck(storage_mtx); - pending_ops.emplace([=](SecureDht& dht) { + pending_ops_prio.emplace([=](SecureDht& dht) { for (auto& node : nodes) dht.insertNode(node); }); diff --git a/src/securedht.cpp b/src/securedht.cpp index 54a84de808b223194858e73303851a0c11dcb675..efca552f34f1c0441637cfad274587d0f20cd2a0 100644 --- a/src/securedht.cpp +++ b/src/securedht.cpp @@ -238,12 +238,8 @@ SecureDht::getCallbackFilter(GetCallback cb, Value::Filter&& filter) try { Value decrypted_val (decrypt(*v)); if (decrypted_val.recipient == getId()) { - if (decrypted_val.owner.checkSignature(decrypted_val.getToSign(), decrypted_val.signature)) { - if (not filter or filter(decrypted_val)) - tmpvals.push_back(std::make_shared<Value>(std::move(decrypted_val))); - } - else - DHT_WARN("Signature verification failed for %s", v->toString().c_str()); + if (not filter or filter(decrypted_val)) + tmpvals.push_back(std::make_shared<Value>(std::move(decrypted_val))); } // Ignore values belonging to other people } catch (const std::exception& e) { @@ -284,7 +280,7 @@ SecureDht::listen(const InfoHash& id, GetCallback cb, Value::Filter&& f) } void -SecureDht::putSigned(const InfoHash& hash, const std::shared_ptr<Value>& val, DoneCallback callback) +SecureDht::putSigned(const InfoHash& hash, std::shared_ptr<Value> val, DoneCallback callback) { if (val->id == Value::INVALID_ID) { crypto::random_device rdev; @@ -343,17 +339,16 @@ SecureDht::putEncrypted(const InfoHash& hash, const InfoHash& to, std::shared_pt void SecureDht::sign(Value& v) const { - if (v.flags.isEncrypted()) + if (v.isEncrypted()) throw DhtException("Can't sign encrypted data."); v.owner = key_->getPublicKey(); - v.flags = Value::ValueFlags(true, false, v.flags[2]); v.signature = key_->sign(v.getToSign()); } Value SecureDht::encrypt(Value& v, const crypto::PublicKey& to) const { - if (v.flags.isEncrypted()) + if (v.isEncrypted()) throw DhtException("Data is already encrypted."); v.setRecipient(to.getId()); sign(v); @@ -365,12 +360,20 @@ SecureDht::encrypt(Value& v, const crypto::PublicKey& to) const Value SecureDht::decrypt(const Value& v) { - if (not v.flags.isEncrypted()) + if (not v.isEncrypted()) throw DhtException("Data is not encrypted."); + auto decrypted = key_->decrypt(v.cypher); + Value ret {v.id}; - auto pb = decrypted.cbegin(), pe = decrypted.cend(); - ret.unpackBody(pb, pe); + auto msg = msgpack::unpack((const char*)decrypted.data(), decrypted.size()); + ret.msgpack_unpack_body(msg.get()); + + if (ret.recipient != getId()) + throw crypto::DecryptError("Recipient mismatch"); + if (not ret.owner.checkSignature(ret.getToSign(), ret.signature)) + throw crypto::DecryptError("Signature mismatch"); + return ret; } diff --git a/src/value.cpp b/src/value.cpp index dc6c50f07a5a66d0328d846d58b650c5ed7f0b0d..d34c653d994b8341d675ff47d2fd03e7651b7fbc 100644 --- a/src/value.cpp +++ b/src/value.cpp @@ -38,11 +38,17 @@ namespace dht { std::ostream& operator<< (std::ostream& s, const Value& v) { s << "Value[id:" << std::hex << v.id << std::dec << " "; - if (v.flags.isSigned()) + if (v.isSigned()) s << "signed (v" << v.seq << ") "; - if (v.flags.isEncrypted()) + if (v.isEncrypted()) s << "encrypted "; - else { + else if (v.isSigned()) { + if (v.recipient == InfoHash()) + s << "signed (v" << v.seq << ") "; + else + s << "decrypted "; + } + if (not v.isEncrypted()) { if (v.type == IpServiceAnnouncement::TYPE.id) { s << IpServiceAnnouncement(v.data); } else if (v.type == CERTIFICATE_TYPE.id) { @@ -57,7 +63,7 @@ std::ostream& operator<< (std::ostream& s, const Value& v) s << "Data (type: " << v.type << " ): "; s << std::hex; for (size_t i=0; i<v.data.size(); i++) - s << std::setfill('0') << std::setw(2) << (unsigned)v.data[i]; + s << std::setfill('0') << std::setw(2) << (unsigned)v.data[i] << " "; s << std::dec; } } @@ -68,62 +74,45 @@ std::ostream& operator<< (std::ostream& s, const Value& v) const ValueType ValueType::USER_DATA = {0, "User Data"}; -void -Value::packToSign(Blob& res) const -{ - res.push_back(flags.to_ulong()); - if (flags.isEncrypted()) { - res.insert(res.end(), cypher.begin(), cypher.end()); - } else { - if (flags.isSigned()) { - serialize<decltype(seq)>(seq, res); - owner.pack(res); - if (flags.haveRecipient()) - res.insert(res.end(), recipient.begin(), recipient.end()); - } - serialize<ValueType::Id>(type, res); - serialize<Blob>(data, res); - } -} - -Blob -Value::getToSign() const -{ - Blob ret; - packToSign(ret); - return ret; -} - -/** - * Pack part of the data to be encrypted - */ -void -Value::packToEncrypt(Blob& res) const -{ - packToSign(res); - if (!flags.isEncrypted() && flags.isSigned()) - serialize<Blob>(signature, res); +msgpack::unpacked +unpackMsg(Blob b) { + return msgpack::unpack((const char*)b.data(), b.size()); } -Blob -Value::getToEncrypt() const -{ - Blob ret; - packToEncrypt(ret); - return ret; +msgpack::object* +findMapValue(const msgpack::object& map, const std::string& key) { + if (map.type != msgpack::type::MAP) throw msgpack::type_error(); + for (unsigned i = 0; i < map.via.map.size; i++) { + auto& o = map.via.map.ptr[i]; + if(o.key.type != msgpack::type::STR) + continue; + if (o.key.as<std::string>() == key) { + return &o.val; + } + } + return nullptr; } void -Value::pack(Blob& res) const +Value::msgpack_unpack(msgpack::object o) { - serialize<Id>(id, res); - packToEncrypt(res); + if (o.type != msgpack::type::MAP) throw msgpack::type_error(); + if (o.via.map.size < 2) throw msgpack::type_error(); + + if (auto rid = findMapValue(o, "id")) { + id = rid->as<Id>(); + } else + throw msgpack::type_error(); + + if (auto rdat = findMapValue(o, "dat")) { + msgpack_unpack_body(*rdat); + } else + throw msgpack::type_error(); } void -Value::unpackBody(Blob::const_iterator& begin, Blob::const_iterator& end) +Value::msgpack_unpack_body(const msgpack::object& o) { - // clear optional fields owner = {}; recipient = {}; cypher.clear(); @@ -131,39 +120,48 @@ Value::unpackBody(Blob::const_iterator& begin, Blob::const_iterator& end) data.clear(); type = 0; - flags = {deserialize<uint8_t>(begin, end)}; - if (flags.isEncrypted()) { - cypher = {begin, end}; - begin = end; + if (o.type == msgpack::type::BIN) { + auto dat = o.as<std::vector<char>>(); + cypher = {dat.begin(), dat.end()}; } else { - if(flags.isSigned()) { - seq = deserialize<decltype(seq)>(begin, end); - owner.unpack(begin, end); - if (flags.haveRecipient()) - recipient = deserialize<InfoHash>(begin, end); + if (o.type != msgpack::type::MAP) + throw msgpack::type_error(); + auto rbody = findMapValue(o, "body"); + if (not rbody) + throw msgpack::type_error(); + + if (auto rdata = findMapValue(*rbody, "data")) { + auto dat = rdata->as<std::vector<char>>(); + data = {dat.begin(), dat.end()}; + } else + throw msgpack::type_error(); + + if (auto rtype = findMapValue(*rbody, "type")) { + type = rtype->as<ValueType::Id>(); + } else + throw msgpack::type_error(); + + if (auto rutype = findMapValue(*rbody, "utype")) { + user_type = rutype->as<std::string>(); } - type = deserialize<ValueType::Id>(begin, end); - data = deserialize<Blob>(begin, end); - if (flags.isSigned()) - signature = deserialize<Blob>(begin, end); - } -} - -void -Value::unpack(Blob::const_iterator& begin, Blob::const_iterator& end) -{ - id = deserialize<Id>(begin, end); - unpackBody(begin, end); -} -void -ValueSerializable::unpackValue(const Value& v) { - unpackBlob(v.data); -} + if (auto rowner = findMapValue(*rbody, "owner")) { + if (auto rseq = findMapValue(*rbody, "seq")) + seq = rseq->as<decltype(seq)>(); + else + throw msgpack::type_error(); + owner.msgpack_unpack(*rowner); + if (auto rrecipient = findMapValue(*rbody, "to")) { + recipient = rrecipient->as<InfoHash>(); + } -Value -ValueSerializable::packValue() const { - return Value {getType(), *this}; + if (auto rsig = findMapValue(o, "sig")) { + auto dat = rsig->as<std::vector<char>>(); + signature = {dat.begin(), dat.end()}; + } else + throw msgpack::type_error(); + } + } } } diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a5122710f5c7d3643e18c46bdf6c4af68eef296d --- /dev/null +++ b/tools/CMakeLists.txt @@ -0,0 +1,8 @@ + +add_executable (dhtnode dhtnode.cpp tools_common.h) +add_executable (dhtscanner dhtscanner.cpp tools_common.h) +add_executable (dhtchat dhtchat.cpp tools_common.h) + +target_link_libraries (dhtnode LINK_PUBLIC opendht gnutls) +target_link_libraries (dhtscanner LINK_PUBLIC opendht gnutls) +target_link_libraries (dhtchat LINK_PUBLIC opendht gnutls) diff --git a/tools/dhtchat.cpp b/tools/dhtchat.cpp index 19ec7aacaa22eb6e91b3dd9d3f39e246d5e527e5..2698fc9d07708bccc44427a2c058ab4355aadf5b 100644 --- a/tools/dhtchat.cpp +++ b/tools/dhtchat.cpp @@ -57,8 +57,8 @@ main(int argc, char **argv) DhtRunner dht; dht.run(params.port, dht::crypto::generateIdentity("DHT Chat Node"), true); - if (not params.bootstrap.empty()) - dht.bootstrap(params.bootstrap[0].c_str(), params.bootstrap[1].c_str()); + if (not params.bootstrap.first.empty()) + dht.bootstrap(params.bootstrap.first.c_str(), params.bootstrap.second.c_str()); std::cout << "OpenDht node " << dht.getNodeId() << " running on port " << params.port << std::endl; std::cout << "Public key ID " << dht.getId() << std::endl; diff --git a/tools/dhtnode.cpp b/tools/dhtnode.cpp index 346b816ae4938aaba046d6aeb74eb05d386ebfa4..94903500dee91e567dba8808cc48843594f3c89f 100644 --- a/tools/dhtnode.cpp +++ b/tools/dhtnode.cpp @@ -59,194 +59,202 @@ void print_node_info(const DhtRunner& dht, const dht_params& params) { int main(int argc, char **argv) { - auto params = parseArgs(argc, argv); - if (params.help) { - print_usage(); - return 0; - } - - // TODO: remove with GnuTLS >= 3.3 - int rc = gnutls_global_init(); - if (rc != GNUTLS_E_SUCCESS) - throw std::runtime_error(std::string("Error initializing GnuTLS: ")+gnutls_strerror(rc)); - - dht::crypto::Identity crt {}; - if (params.generate_identity) { - auto ca_tmp = dht::crypto::generateIdentity("DHT Node CA"); - crt = dht::crypto::generateIdentity("DHT Node", ca_tmp); - } - DhtRunner dht; - dht.run(params.port, crt, true, params.is_bootstrap_node); + try { + auto params = parseArgs(argc, argv); + if (params.help) { + print_usage(); + return 0; + } - if (not params.bootstrap.empty()) { - std::cout << "Bootstrap: " << params.bootstrap[0] << ":" << params.bootstrap[1] << std::endl; - dht.bootstrap(params.bootstrap[0].c_str(), params.bootstrap[1].c_str()); - } + // TODO: remove with GnuTLS >= 3.3 + int rc = gnutls_global_init(); + if (rc != GNUTLS_E_SUCCESS) + throw std::runtime_error(std::string("Error initializing GnuTLS: ")+gnutls_strerror(rc)); - print_node_info(dht, params); - std::cout << " (type 'h' or 'help' for a list of possible commands)" << std::endl << std::endl; - - bool do_log {false}; - while (true) - { - std::cout << ">> "; - std::string line; - std::getline(std::cin, line); - std::istringstream iss(line); - std::string op, idstr, value; - iss >> op >> idstr; - - if (std::cin.eof() || op == "x" || op == "q" || op == "exit" || op == "quit") { - break; - } else if (op == "h" || op == "help") { - std::cout << "OpenDht command line interface (CLI)" << std::endl; - std::cout << "Possible commands:" << std::endl; - std::cout << " h, help Print this help message." << std::endl; - std::cout << " q, quit Quit the program." << std::endl; - std::cout << " log Print the full DHT log." << std::endl; - - std::cout << std::endl << "Node information:" << std::endl; - std::cout << " ll Print basic information and stats about the current node." << std::endl; - std::cout << " ls Print basic information about current searches." << std::endl; - std::cout << " ld Print basic information about currenty stored values on this node." << std::endl; - std::cout << " lr Print the full current routing table of this node" << std::endl; - - std::cout << std::endl << "Operations on the DHT:" << std::endl; - std::cout << " g [key] Get values at [key]." << std::endl; - std::cout << " l [key] Listen for value changes at [key]." << std::endl; - std::cout << " p [key] [str] Put string value at [key]." << std::endl; - std::cout << " s [key] [str] Put string value at [key], signed with our generated private key." << std::endl; - std::cout << " e [key] [dest] [str] Put string value at [key], encrypted for [dest] with its public key (if found)." << std::endl; - std::cout << std::endl; - continue; - } else if (op == "ll") { - print_node_info(dht, params); - unsigned good4, dubious4, cached4, incoming4; - unsigned good6, dubious6, cached6, incoming6; - dht.getNodesStats(AF_INET, &good4, &dubious4, &cached4, &incoming4); - dht.getNodesStats(AF_INET6, &good6, &dubious6, &cached6, &incoming6); - std::cout << "IPv4 nodes : " << good4 << " good, " << dubious4 << " dubious, " << incoming4 << " incoming." << std::endl; - std::cout << "IPv6 nodes : " << good6 << " good, " << dubious6 << " dubious, " << incoming6 << " incoming." << std::endl; - continue; - } else if (op == "lr") { - std::cout << "IPv4 routing table:" << std::endl; - std::cout << dht.getRoutingTablesLog(AF_INET) << std::endl; - std::cout << "IPv6 routing table:" << std::endl; - std::cout << dht.getRoutingTablesLog(AF_INET6) << std::endl; - continue; - } else if (op == "ld") { - std::cout << dht.getStorageLog() << std::endl; - continue; - } else if (op == "ls") { - std::cout << "Searches:" << std::endl; - std::cout << dht.getSearchesLog() << std::endl; - continue; - } else if (op == "la") { - std::cout << "Reported public addresses:" << std::endl; - auto addrs = dht.getPublicAddressStr(); - for (const auto& addr : addrs) - std::cout << addr << std::endl; - continue; - } else if (op == "log") { - do_log = !do_log; - if (do_log) - enableLogging(dht); - else - disableLogging(dht); - continue; + dht::crypto::Identity crt {}; + if (params.generate_identity) { + auto ca_tmp = dht::crypto::generateIdentity("DHT Node CA"); + crt = dht::crypto::generateIdentity("DHT Node", ca_tmp); } + + dht.run(params.port, crt, true, params.is_bootstrap_node); - if (op.empty()) - continue; + if (params.log) + enableLogging(dht); - dht::InfoHash id {idstr}; - static const std::set<std::string> VALID_OPS {"g", "l", "p", "s", "e", "a"}; - if (VALID_OPS.find(op) == VALID_OPS.cend()) { - std::cout << "Unknown command: " << op << std::endl; - std::cout << " (type 'h' or 'help' for a list of possible commands)" << std::endl; - continue; - } - static constexpr dht::InfoHash INVALID_ID {}; - if (id == INVALID_ID) { - std::cout << "Syntax error: invalid InfoHash." << std::endl; - continue; + if (not params.bootstrap.first.empty()) { + std::cout << "Bootstrap: " << params.bootstrap.first << ":" << params.bootstrap.second << std::endl; + dht.bootstrap(params.bootstrap.first.c_str(), params.bootstrap.second.c_str()); } - auto start = std::chrono::high_resolution_clock::now(); - if (op == "g") { - dht.get(id, [start](std::shared_ptr<Value> value) { - auto now = std::chrono::high_resolution_clock::now(); - std::cout << "Get: found value (after " << print_dt(now-start) << "s)" << std::endl; - std::cout << "\t" << *value << std::endl; - return true; - }, [start](bool ok) { - auto end = std::chrono::high_resolution_clock::now(); - std::cout << "Get: " << (ok ? "completed" : "failure") << " (took " << print_dt(end-start) << "s)" << std::endl; - }); - } - else if (op == "l") { - std::cout << id << std::endl; - dht.listen(id, [](std::shared_ptr<Value> value) { - std::cout << "Listen: found value:" << std::endl; - std::cout << "\t" << *value << std::endl; - return true; - }); - } - else if (op == "p") { - std::string v; - iss >> v; - dht.put(id, dht::Value { - dht::ValueType::USER_DATA.id, - std::vector<uint8_t> {v.begin(), v.end()} - }, [start](bool ok) { - auto end = std::chrono::high_resolution_clock::now(); - std::cout << "Put: " << (ok ? "success" : "failure") << " (took " << print_dt(end-start) << "s)" << std::endl; - }); - } - else if (op == "s") { - if (not params.generate_identity) { - print_id_req(); + print_node_info(dht, params); + std::cout << " (type 'h' or 'help' for a list of possible commands)" << std::endl << std::endl; + + while (true) + { + std::cout << ">> "; + std::string line; + std::getline(std::cin, line); + std::istringstream iss(line); + std::string op, idstr, value; + iss >> op >> idstr; + + if (std::cin.eof() || op == "x" || op == "q" || op == "exit" || op == "quit") { + break; + } else if (op == "h" || op == "help") { + std::cout << "OpenDht command line interface (CLI)" << std::endl; + std::cout << "Possible commands:" << std::endl; + std::cout << " h, help Print this help message." << std::endl; + std::cout << " q, quit Quit the program." << std::endl; + std::cout << " log Print the full DHT log." << std::endl; + + std::cout << std::endl << "Node information:" << std::endl; + std::cout << " ll Print basic information and stats about the current node." << std::endl; + std::cout << " ls Print basic information about current searches." << std::endl; + std::cout << " ld Print basic information about currenty stored values on this node." << std::endl; + std::cout << " lr Print the full current routing table of this node" << std::endl; + + std::cout << std::endl << "Operations on the DHT:" << std::endl; + std::cout << " g [key] Get values at [key]." << std::endl; + std::cout << " l [key] Listen for value changes at [key]." << std::endl; + std::cout << " p [key] [str] Put string value at [key]." << std::endl; + std::cout << " s [key] [str] Put string value at [key], signed with our generated private key." << std::endl; + std::cout << " e [key] [dest] [str] Put string value at [key], encrypted for [dest] with its public key (if found)." << std::endl; + std::cout << std::endl; + continue; + } else if (op == "ll") { + print_node_info(dht, params); + unsigned good4, dubious4, cached4, incoming4; + unsigned good6, dubious6, cached6, incoming6; + dht.getNodesStats(AF_INET, &good4, &dubious4, &cached4, &incoming4); + dht.getNodesStats(AF_INET6, &good6, &dubious6, &cached6, &incoming6); + std::cout << "IPv4 nodes : " << good4 << " good, " << dubious4 << " dubious, " << incoming4 << " incoming." << std::endl; + std::cout << "IPv6 nodes : " << good6 << " good, " << dubious6 << " dubious, " << incoming6 << " incoming." << std::endl; + continue; + } else if (op == "lr") { + std::cout << "IPv4 routing table:" << std::endl; + std::cout << dht.getRoutingTablesLog(AF_INET) << std::endl; + std::cout << "IPv6 routing table:" << std::endl; + std::cout << dht.getRoutingTablesLog(AF_INET6) << std::endl; + continue; + } else if (op == "ld") { + std::cout << dht.getStorageLog() << std::endl; + continue; + } else if (op == "ls") { + std::cout << "Searches:" << std::endl; + std::cout << dht.getSearchesLog() << std::endl; + continue; + } else if (op == "la") { + std::cout << "Reported public addresses:" << std::endl; + auto addrs = dht.getPublicAddressStr(); + for (const auto& addr : addrs) + std::cout << addr << std::endl; + continue; + } else if (op == "log") { + params.log = !params.log; + if (params.log) + enableLogging(dht); + else + disableLogging(dht); continue; } - std::string v; - iss >> v; - dht.putSigned(id, dht::Value { - dht::ValueType::USER_DATA.id, - std::vector<uint8_t> {v.begin(), v.end()} - }, [start](bool ok) { - auto end = std::chrono::high_resolution_clock::now(); - std::cout << "Put signed: " << (ok ? "success" : "failure") << " (took " << print_dt(end-start) << "s)" << std::endl; - }); - } - else if (op == "e") { - if (not params.generate_identity) { - print_id_req(); + + if (op.empty()) + continue; + + dht::InfoHash id {idstr}; + static const std::set<std::string> VALID_OPS {"g", "l", "p", "s", "e", "a"}; + if (VALID_OPS.find(op) == VALID_OPS.cend()) { + std::cout << "Unknown command: " << op << std::endl; + std::cout << " (type 'h' or 'help' for a list of possible commands)" << std::endl; continue; } - std::string tostr; - std::string v; - iss >> tostr >> v; - dht.putEncrypted(id, InfoHash(tostr), dht::Value { - dht::ValueType::USER_DATA.id, - std::vector<uint8_t> {v.begin(), v.end()} - }, [start](bool ok) { - auto end = std::chrono::high_resolution_clock::now(); - std::cout << "Put encrypted: " << (ok ? "success" : "failure") << " (took " << print_dt(end-start) << "s)" << std::endl; - }); - } - else if (op == "a") { - in_port_t port; - iss >> port; - dht.put(id, dht::Value {dht::IpServiceAnnouncement::TYPE.id, dht::IpServiceAnnouncement(port)}, [start](bool ok) { - auto end = std::chrono::high_resolution_clock::now(); - std::cout << "Announce: " << (ok ? "success" : "failure") << " (took " << print_dt(end-start) << "s)" << std::endl; - }); + static constexpr dht::InfoHash INVALID_ID {}; + if (id == INVALID_ID) { + std::cout << "Syntax error: invalid InfoHash." << std::endl; + continue; + } + + auto start = std::chrono::high_resolution_clock::now(); + if (op == "g") { + dht.get(id, [start](std::shared_ptr<Value> value) { + auto now = std::chrono::high_resolution_clock::now(); + std::cout << "Get: found value (after " << print_dt(now-start) << "s)" << std::endl; + std::cout << "\t" << *value << std::endl; + return true; + }, [start](bool ok) { + auto end = std::chrono::high_resolution_clock::now(); + std::cout << "Get: " << (ok ? "completed" : "failure") << " (took " << print_dt(end-start) << "s)" << std::endl; + }); + } + else if (op == "l") { + std::cout << id << std::endl; + dht.listen(id, [](std::shared_ptr<Value> value) { + std::cout << "Listen: found value:" << std::endl; + std::cout << "\t" << *value << std::endl; + return true; + }); + } + else if (op == "p") { + std::string v; + iss >> v; + dht.put(id, dht::Value { + dht::ValueType::USER_DATA.id, + std::vector<uint8_t> {v.begin(), v.end()} + }, [start](bool ok) { + auto end = std::chrono::high_resolution_clock::now(); + std::cout << "Put: " << (ok ? "success" : "failure") << " (took " << print_dt(end-start) << "s)" << std::endl; + }); + } + else if (op == "s") { + if (not params.generate_identity) { + print_id_req(); + continue; + } + std::string v; + iss >> v; + dht.putSigned(id, dht::Value { + dht::ValueType::USER_DATA.id, + std::vector<uint8_t> {v.begin(), v.end()} + }, [start](bool ok) { + auto end = std::chrono::high_resolution_clock::now(); + std::cout << "Put signed: " << (ok ? "success" : "failure") << " (took " << print_dt(end-start) << "s)" << std::endl; + }); + } + else if (op == "e") { + if (not params.generate_identity) { + print_id_req(); + continue; + } + std::string tostr; + std::string v; + iss >> tostr >> v; + dht.putEncrypted(id, InfoHash(tostr), dht::Value { + dht::ValueType::USER_DATA.id, + std::vector<uint8_t> {v.begin(), v.end()} + }, [start](bool ok) { + auto end = std::chrono::high_resolution_clock::now(); + std::cout << "Put encrypted: " << (ok ? "success" : "failure") << " (took " << print_dt(end-start) << "s)" << std::endl; + }); + } + else if (op == "a") { + in_port_t port; + iss >> port; + dht.put(id, dht::Value {dht::IpServiceAnnouncement::TYPE.id, dht::IpServiceAnnouncement(port)}, [start](bool ok) { + auto end = std::chrono::high_resolution_clock::now(); + std::cout << "Announce: " << (ok ? "success" : "failure") << " (took " << print_dt(end-start) << "s)" << std::endl; + }); + } } + + std::cout << std::endl << "Stopping node..." << std::endl; + } catch(const std::exception&e) { + std::cout << std::endl << e.what() << std::endl; } - std::cout << std::endl << "Stopping node..." << std::endl; dht.join(); gnutls_global_deinit(); + return 0; } diff --git a/tools/dhtscanner.cpp b/tools/dhtscanner.cpp index ee572bd00cbe3066a97bde7cbb4add4eb0a52024..f2c6ecd47dff077b7755d4a69eab40b1b3d52cd2 100644 --- a/tools/dhtscanner.cpp +++ b/tools/dhtscanner.cpp @@ -90,8 +90,8 @@ main(int argc, char **argv) DhtRunner dht; dht.run(params.port, crt_tmp, true, [](dht::Dht::Status /* ipv4 */, dht::Dht::Status /* ipv6 */) {}); - if (not params.bootstrap.empty()) - dht.bootstrap(params.bootstrap[0].c_str(), params.bootstrap[1].c_str()); + if (not params.bootstrap.first.empty()) + dht.bootstrap(params.bootstrap.first.c_str(), params.bootstrap.second.c_str()); std::cout << "OpenDht node " << dht.getNodeId() << " running on port " << params.port << std::endl; std::cout << "Scanning network..." << std::endl; diff --git a/tools/tools_common.h b/tools/tools_common.h index 83cb8b8e0ede2b3c1a2f2eeb885f3f514348a3df..4c7cc750180bc0ec598a52d751171b77693b307a 100644 --- a/tools/tools_common.h +++ b/tools/tools_common.h @@ -112,26 +112,37 @@ print_dt(DT d) { } /** - * Split string with delimiter + * Split "[host]:port" or "host:port" to pair<"host", "port">. */ -std::vector<std::string> -split(const std::string& s, char delim) { - std::vector<std::string> elems; - std::stringstream ss(s); - std::string item; - while (std::getline(ss, item, delim)) - elems.emplace_back(std::move(item)); - return elems; +std::pair<std::string, std::string> +splitPort(const std::string& s) { + if (s.empty()) + return {}; + if (s[0] == '[') { + std::size_t closure = s.find_first_of(']'); + std::size_t found = s.find_last_of(':'); + if (closure == std::string::npos) + return {s, ""}; + if (found == std::string::npos or found < closure) + return {s.substr(1,closure-1), ""}; + return {s.substr(1,closure-1), s.substr(found+1)}; + } + std::size_t found = s.find_last_of(':'); + std::size_t first = s.find_first_of(':'); + if (found == std::string::npos or found != first) + return {s, ""}; + return {s.substr(0,found), s.substr(found+1)}; } static const constexpr in_port_t DHT_DEFAULT_PORT = 4222; struct dht_params { bool help {false}; // print help and exit + bool log {false}; in_port_t port {DHT_DEFAULT_PORT}; bool is_bootstrap_node {false}; bool generate_identity {false}; - std::vector<std::string> bootstrap {}; + std::pair<std::string, std::string> bootstrap {}; }; static const struct option long_options[] = { @@ -139,6 +150,7 @@ static const struct option long_options[] = { {"port", required_argument, nullptr, 'p'}, {"bootstrap", optional_argument, nullptr, 'b'}, {"identity", no_argument , nullptr, 'i'}, + {"verbose", no_argument , nullptr, 'v'}, {nullptr, 0, nullptr, 0} }; @@ -146,7 +158,7 @@ dht_params parseArgs(int argc, char **argv) { dht_params params; int opt; - while ((opt = getopt_long(argc, argv, ":hip:b:", long_options, nullptr)) != -1) { + while ((opt = getopt_long(argc, argv, ":hivp:b:", long_options, nullptr)) != -1) { switch (opt) { case 'p': { int port_arg = atoi(optarg); @@ -158,9 +170,9 @@ parseArgs(int argc, char **argv) { break; case 'b': if (optarg) { - params.bootstrap = split((optarg[0] == '=') ? optarg+1 : optarg, ':'); - if (params.bootstrap.size() == 1) - params.bootstrap.emplace_back(std::to_string(DHT_DEFAULT_PORT)); + params.bootstrap = splitPort((optarg[0] == '=') ? optarg+1 : optarg); + if (not params.bootstrap.first.empty() and params.bootstrap.second.empty()) + params.bootstrap.second = std::to_string(DHT_DEFAULT_PORT); } else params.is_bootstrap_node = true; @@ -168,6 +180,9 @@ parseArgs(int argc, char **argv) { case 'h': params.help = true; break; + case 'v': + params.log = true; + break; case 'i': params.generate_identity = true; break;