diff --git a/include/opendht/crypto.h b/include/opendht/crypto.h index 6470c4fe66db90fc5710fd48196c0c62fca72f7e..dc5c401fe23c1213d201f1c14f38fe60eb4e111e 100644 --- a/include/opendht/crypto.h +++ b/include/opendht/crypto.h @@ -651,7 +651,7 @@ private: using SecureBlob = secure_vector<uint8_t>; -class EcPublicKey : public std::array<uint8_t, crypto_box_PUBLICKEYBYTES> { +/*class EcPublicKey : public std::array<uint8_t, crypto_box_PUBLICKEYBYTES> { public: Blob encrypt(const Blob& message) const { Blob ret(crypto_box_SEALBYTES+message.size()); @@ -659,11 +659,16 @@ public: return ret; } -}; + const InfoHash& getId() const { + return *reinterpret_cast<const InfoHash*>(this); + } + +};*/ class EcSecretKey { public: using KeyData = secure_array<uint8_t, crypto_box_SECRETKEYBYTES>; + using EcPublicKey = InfoHash; static EcSecretKey generate() { EcSecretKey ret; @@ -675,34 +680,34 @@ public: return pk; } - Blob encrypt(const Blob& message, const EcPublicKey& pub) const { - Blob ret(crypto_box_NONCEBYTES+crypto_box_MACBYTES+message.size()); + Blob encrypt(const uint8_t* data, size_t size, const EcPublicKey& pub) const { + Blob ret(crypto_box_NONCEBYTES+crypto_box_MACBYTES+size); randombytes_buf(ret.data(), crypto_box_NONCEBYTES); - if (crypto_box_easy(ret.data()+crypto_box_NONCEBYTES, message.data(), message.size(), ret.data(), + if (crypto_box_easy(ret.data()+crypto_box_NONCEBYTES, data, size, ret.data(), pub.data(), key.data()) != 0) { throw CryptoException("Can't encrypt data"); } return ret; } - Blob decrypt(const Blob& cypher) const { - if (cypher.size() <= crypto_box_SEALBYTES) + Blob decrypt(const uint8_t* cypher, size_t cypher_size) const { + if (cypher_size <= crypto_box_SEALBYTES) throw DecryptError("Unexpected cipher length"); - Blob ret(cypher.size() - crypto_box_SEALBYTES); - if (crypto_box_seal_open(ret.data(), cypher.data(), cypher.size(), pk.data(), key.data()) != 0) { + Blob ret(cypher_size - crypto_box_SEALBYTES); + if (crypto_box_seal_open(ret.data(), cypher, cypher_size, pk.data(), key.data()) != 0) { throw DecryptError("Can't decrypt data"); } return ret; } - Blob decrypt(const Blob& cypher, const EcPublicKey& pub) const { - if (cypher.size() <= crypto_box_NONCEBYTES+crypto_box_MACBYTES) + Blob decrypt(const uint8_t* cypher, size_t cypher_size, const EcPublicKey& pub) const { + if (cypher_size <= crypto_box_NONCEBYTES+crypto_box_MACBYTES) throw DecryptError("Unexpected cipher length"); - Blob ret(cypher.size() - crypto_box_NONCEBYTES - crypto_box_MACBYTES); + Blob ret(cypher_size - crypto_box_NONCEBYTES - crypto_box_MACBYTES); if (crypto_box_open_easy(ret.data(), - cypher.data()+crypto_box_NONCEBYTES, - cypher.size()-crypto_box_NONCEBYTES, - cypher.data(), + cypher+crypto_box_NONCEBYTES, + cypher_size-crypto_box_NONCEBYTES, + cypher, pub.data(), key.data()) != 0) { throw DecryptError("Can't decrypt data"); } diff --git a/include/opendht/dht.h b/include/opendht/dht.h index b1ed3cd030534d39e541e190569ecd0acbf15a26..fbfc724b122044ed9f3d4641b569db7c61102284 100644 --- a/include/opendht/dht.h +++ b/include/opendht/dht.h @@ -76,7 +76,7 @@ public: /** * Get the ID of the node. */ - inline const InfoHash& getNodeId() const { return myid; } + inline const InfoHash& getNodeId() const { return network_engine.getNodeId(); } /** * Get the current status of the node for the given family. diff --git a/include/opendht/infohash.h b/include/opendht/infohash.h index a78c7e99272368da17c9149c7e8facd506197325..b889c64af6a4fe2402f45e5411eb7e4762151ef2 100644 --- a/include/opendht/infohash.h +++ b/include/opendht/infohash.h @@ -22,6 +22,8 @@ #include <msgpack.hpp> +#include <sodium.h> + #ifndef _WIN32 #include <netinet/in.h> #include <netdb.h> @@ -46,7 +48,7 @@ typedef uint16_t in_port_t; // bytes -#define HASH_LEN 32u +#define HASH_LEN crypto_box_PUBLICKEYBYTES namespace dht { @@ -147,6 +149,12 @@ public: std::copy_n(h, HASH_LEN, begin()); } + std::vector<uint8_t> encrypt(const uint8_t* msg, size_t msg_size) const { + std::vector<uint8_t> ret(crypto_box_SEALBYTES+msg_size); + crypto_box_seal(ret.data(), msg, msg_size, data()); + return ret; + } + /** * Constructor from an hexadecimal string (without "0x"). * hex must be at least 2.HASH_LEN characters long. diff --git a/include/opendht/network_engine.h b/include/opendht/network_engine.h index 40c903a976357fd3639062a8818dfeb1138d194d..9e54b8ee328d1920337a6efdefec398f2453bd77 100644 --- a/include/opendht/network_engine.h +++ b/include/opendht/network_engine.h @@ -213,7 +213,7 @@ public: using RequestExpiredCb = std::function<void(const Request&, bool)>; NetworkEngine(uv_loop_t*, Logger& log, Scheduler& scheduler); - NetworkEngine(uv_loop_t*, const NetworkConfig& config, InfoHash& myid, NetId net, Logger& log, Scheduler& scheduler, + NetworkEngine(uv_loop_t*, const NetworkConfig& config, NetId net, Logger& log, Scheduler& scheduler, decltype(NetworkEngine::onError) onError, decltype(NetworkEngine::onNewNode) onNewNode, decltype(NetworkEngine::onReportedAddr) onReportedAddr, @@ -226,6 +226,10 @@ public: virtual ~NetworkEngine(); + const InfoHash& getNodeId() const { + return id_key.getPublicKey(); + } + void clear(); void close(OnClose cb) { sock->close(cb); @@ -417,17 +421,6 @@ public: */ void closeSocket(Sp<Socket> socket); - /** - * Parses a message and calls appropriate callbacks. - * - * @param buf The buffer containing the binary message. - * @param buflen The length of the buffer. - * @param from The address info of the sender. - * @param fromlen The length of the corresponding sockaddr structure. - * @param now The time to adjust the clock in the network engine. - */ - void processMessage(const uint8_t *buf, size_t buflen, const SockAddr& addr); - Sp<Node> insertNode(const InfoHash& myid, const SockAddr& addr) { auto n = cache.getNode(myid, addr, {}, scheduler.time(), 0); onNewNode(n, 0); @@ -480,7 +473,20 @@ private: static const std::string my_v; static std::mt19937 rd_device; - void process(std::unique_ptr<ParsedMessage>&&, const SockAddr& from, const Sp<TcpSocket>& s = {}); + /** + * Parses a message and calls appropriate callbacks. + * + * @param buf The buffer containing the binary message. + * @param buflen The length of the buffer. + * @param from The address info of the sender. + * @param fromlen The length of the corresponding sockaddr structure. + * @param now The time to adjust the clock in the network engine. + */ + void processMessage(const uint8_t* data, size_t size, const SockAddr& addr); + Sp<Node> processEncrypted(const msgpack::object& o, const SockAddr& from); + Sp<Node> process(const msgpack::object& o, const SockAddr& addr); + + Sp<Node> process(std::unique_ptr<ParsedMessage>&&, const SockAddr& from, const Sp<TcpSocket>& s = {}); bool rateLimit(const SockAddr& addr); @@ -520,7 +526,7 @@ private: int send(const Blob& msg, int flags, const Sp<Node>& node); - void startTcp(const Sp<TcpSocket>& sock, bool assigned = false); + void startTcp(const Sp<TcpSocket>& sock, const Sp<Node>& node = {}); void sendValueParts(TransId tid, const std::vector<Blob>& svals, const SockAddr& addr); std::vector<Blob> packValueHeader(msgpack::sbuffer&, const std::vector<Sp<Value>>&, bool stream); @@ -560,7 +566,7 @@ private: void deserializeNodes(ParsedMessage& msg, const SockAddr& from); /* DHT info */ - const InfoHash& myid; + //const InfoHash& myid; const NetId network {0}; const Logger& DHT_LOG; diff --git a/include/opendht/node.h b/include/opendht/node.h index 751319bc9cb4733177b409165dfaf6f4564ae7d6..3af896b7e6e670fd2cdc4324a6ca70a6cf910a82 100644 --- a/include/opendht/node.h +++ b/include/opendht/node.h @@ -38,7 +38,7 @@ struct Node { SockAddr addr; Sp<TcpSocket> sock; - crypto::EcPublicKey last_known_pk; + InfoHash last_known_pk; //std::list<crypto::EcSecretKey> time_point time {time_point::min()}; /* last time eared about */ @@ -62,6 +62,10 @@ struct Node { return addr.toString(); } + void startTCP() { + + } + /** * Makes notice about an additionnal authentication error with this node. Up * to MAX_AUTH_ERRORS errors are accepted in order to let the node recover. diff --git a/src/dht.cpp b/src/dht.cpp index 80c7a621cb7a84a3b02b66d4ac005a89cdda261b..27a18e7521e62423dcbf8bdcf9e977804fe68365 100644 --- a/src/dht.cpp +++ b/src/dht.cpp @@ -203,7 +203,7 @@ Dht::onNewNode(const Sp<Node>& node, int confirm) trySearchInsert(node); const auto& now = scheduler.time(); - bool mybucket = list.contains(b, myid); + bool mybucket = list.contains(b, getNodeId()); if (mybucket) { if (node->getFamily() == AF_INET) mybucket_grow_time = now; @@ -606,7 +606,7 @@ Dht::searchStep(Sp<Search> sr) } // true if this node is part of the target nodes cluter. - /*bool in = sr->id.xorCmp(myid, sr->nodes.back().node->id) < 0; + /*bool in = sr->id.xorCmp(getNodeId(), sr->nodes.back().node->id) < 0; DHT_LOG_DEBUG("[search %s IPv%c] synced%s", sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6', in ? ", in" : "");*/ @@ -1508,7 +1508,7 @@ Dht::getNodesStats(sa_family_t af) const if (b.cached) stats.cached_nodes++; } - stats.table_depth = bcks.depth(bcks.findBucket(myid)); + stats.table_depth = bcks.depth(bcks.findBucket(getNodeId())); return stats; } @@ -1629,7 +1629,7 @@ void Dht::dumpTables() const { std::stringstream out; - out << "My id " << myid << std::endl; + out << "My id " << getNodeId() << std::endl; out << "Buckets IPv4 :" << std::endl; for (const auto& b : buckets4) @@ -1773,11 +1773,10 @@ Dht::~Dht() Dht::Dht() : store(), scheduler(nullptr, DHT_LOG), network_engine(nullptr, DHT_LOG, scheduler) {} Dht::Dht(uv_loop_t* loop, Config config) - : myid(config.node_id != zeroes ? config.node_id : InfoHash::getRandom()), - is_bootstrap(config.is_bootstrap), maintain_storage(config.maintain_storage), + : is_bootstrap(config.is_bootstrap), maintain_storage(config.maintain_storage), buckets4{Bucket {AF_INET}},buckets6{Bucket {AF_INET6}}, store(), store_quota(), scheduler(loop, DHT_LOG), - network_engine(loop, config.network_config, myid, config.network, DHT_LOG, scheduler, + network_engine(loop, config.network_config, config.network, DHT_LOG, scheduler, std::bind(&Dht::onError, this, _1, _2), std::bind(&Dht::onNewNode, this, _1, _2), std::bind(&Dht::onReportedAddr, this, _1, _2), @@ -1803,7 +1802,7 @@ Dht::Dht(uv_loop_t* loop, Config config) expire(); - DHT_LOG.d("DHT initialised with node ID %s", myid.toString().c_str()); + DHT_LOG.d("DHT initialised with node ID %s", getNodeId().toString().c_str()); } @@ -1811,11 +1810,11 @@ bool Dht::neighbourhoodMaintenance(RoutingTable& list) { //DHT_LOG_DEBUG("neighbourhoodMaintenance"); - auto b = list.findBucket(myid); + auto b = list.findBucket(getNodeId()); if (b == list.end()) return false; - InfoHash id = myid; + InfoHash id = getNodeId(); #ifdef _WIN32 std::uniform_int_distribution<int> rand_byte{ 0, std::numeric_limits<uint8_t>::max() }; #else @@ -1932,7 +1931,7 @@ Dht::maintainStorage(decltype(store)::value_type& storage, bool force, DoneCallb auto nodes = buckets4.findClosestNodes(storage.first, now); if (!nodes.empty()) { - if (force || storage.first.xorCmp(nodes.back()->id, myid) < 0) { + if (force || storage.first.xorCmp(nodes.back()->id, getNodeId()) < 0) { for (auto &value : storage.second.getValues()) { const auto& vt = getType(value.data->type); if (force || value.created + vt.expiration > now + MAX_STORAGE_MAINTENANCE_EXPIRE_TIME) { @@ -1947,7 +1946,7 @@ Dht::maintainStorage(decltype(store)::value_type& storage, bool force, DoneCallb auto nodes6 = buckets6.findClosestNodes(storage.first, now); if (!nodes6.empty()) { - if (force || storage.first.xorCmp(nodes6.back()->id, myid) < 0) { + if (force || storage.first.xorCmp(nodes6.back()->id, getNodeId()) < 0) { for (auto &value : storage.second.getValues()) { const auto& vt = getType(value.data->type); if (force || value.created + vt.expiration > now + MAX_STORAGE_MAINTENANCE_EXPIRE_TIME) { @@ -1980,7 +1979,7 @@ Dht::updateStatus() status6 = nstatus6; if (status4 == NodeStatus::Disconnected and status6 == NodeStatus::Disconnected) { // We have lost connection with the DHT. Try to recover using bootstrap nodes. - DHT_LOG.e(myid, "DHT disconnected", myid.toString().c_str()); + DHT_LOG.e(getNodeId(), "DHT disconnected", getNodeId().toString().c_str()); } else { //bootstrap_nodes.clear(); } @@ -2011,12 +2010,12 @@ Dht::confirmNodes() updateStatus(); if (searches4.empty() and status4 == NodeStatus::Connected) { - DHT_LOG.d(myid, "[confirm nodes] initial IPv4 'get' for my id (%s)", myid.toString().c_str()); - search(myid, AF_INET); + DHT_LOG.d(getNodeId(), "[confirm nodes] initial IPv4 'get' for my id (%s)", getNodeId().toString().c_str()); + search(getNodeId(), AF_INET); } if (searches6.empty() and status6 == NodeStatus::Connected) { - DHT_LOG.d(myid, "[confirm nodes] initial IPv6 'get' for my id (%s)", myid.toString().c_str()); - search(myid, AF_INET6); + DHT_LOG.d(getNodeId(), "[confirm nodes] initial IPv6 'get' for my id (%s)", getNodeId().toString().c_str()); + search(getNodeId(), AF_INET6); } soon |= bucketMaintenance(buckets4); @@ -2106,13 +2105,13 @@ Dht::exportNodes() { const auto& now = scheduler.time(); std::vector<NodeExport> nodes; - const auto b4 = buckets4.findBucket(myid); + const auto b4 = buckets4.findBucket(getNodeId()); if (b4 != buckets4.end()) { for (auto& n : b4->nodes) if (n->isGood(now)) nodes.push_back(n->exportNode()); } - const auto b6 = buckets6.findBucket(myid); + const auto b6 = buckets6.findBucket(getNodeId()); if (b6 != buckets6.end()) { for (auto& n : b6->nodes) if (n->isGood(now)) @@ -2365,7 +2364,7 @@ Dht::onAnnounce(Sp<Node> node, // We store a value only if we think we're part of the // SEARCH_NODES nodes around the target id. auto closest_nodes = buckets(node->getFamily()).findClosestNodes(hash, scheduler.time(), SEARCH_NODES); - if (closest_nodes.size() >= TARGET_NODES and hash.xorCmp(closest_nodes.back()->id, myid) < 0) { + if (closest_nodes.size() >= TARGET_NODES and hash.xorCmp(closest_nodes.back()->id, getNodeId()) < 0) { DHT_LOG.w(hash, node->id, "[node %s] announce too far from the target. Dropping value.", node->toString().c_str()); return {}; } diff --git a/src/network_engine.cpp b/src/network_engine.cpp index 47d37d92b52010675c1a2868a4c26ffe01fa5489..74e9ceeb34dc4c96481058da4f60833c9cee7feb 100644 --- a/src/network_engine.cpp +++ b/src/network_engine.cpp @@ -173,8 +173,8 @@ NetworkEngine::RequestAnswer::RequestAnswer(ParsedMessage&& msg) : ntoken(std::move(msg.token)), values(std::move(msg.values)), fields(std::move(msg.fields)), nodes4(std::move(msg.nodes4)), nodes6(std::move(msg.nodes6)) {} -NetworkEngine::NetworkEngine(uv_loop_t*, Logger& log, Scheduler& scheduler) : myid(zeroes), DHT_LOG(log), scheduler(scheduler) {} -NetworkEngine::NetworkEngine(uv_loop_t* loop, const NetworkConfig& config, InfoHash& myid, NetId net, Logger& log, Scheduler& scheduler, +NetworkEngine::NetworkEngine(uv_loop_t*, Logger& log, Scheduler& scheduler) : DHT_LOG(log), scheduler(scheduler) {} +NetworkEngine::NetworkEngine(uv_loop_t* loop, const NetworkConfig& config, NetId net, Logger& log, Scheduler& scheduler, decltype(NetworkEngine::onError) onError, decltype(NetworkEngine::onNewNode) onNewNode, decltype(NetworkEngine::onReportedAddr) onReportedAddr, @@ -185,7 +185,7 @@ NetworkEngine::NetworkEngine(uv_loop_t* loop, const NetworkConfig& config, InfoH decltype(NetworkEngine::onAnnounce) onAnnounce, decltype(NetworkEngine::onRefresh) onRefresh) : onError(onError), onNewNode(onNewNode), onReportedAddr(onReportedAddr), onPing(onPing), onFindNode(onFindNode), - onGetValues(onGetValues), onListen(onListen), onAnnounce(onAnnounce), onRefresh(onRefresh), myid(myid), + onGetValues(onGetValues), onListen(onListen), onAnnounce(onAnnounce), onRefresh(onRefresh), network(net), DHT_LOG(log), sock(std::make_shared<UdpSocket>(loop)), tcp_sock(std::make_shared<TcpSocket>(loop)), id_key(crypto::EcSecretKey::generate()), scheduler(scheduler) { @@ -406,19 +406,19 @@ NetworkEngine::isNodeBlacklisted(const SockAddr& addr) const } void -NetworkEngine::startTcp(const Sp<TcpSocket>& sock, bool assigned) +NetworkEngine::startTcp(const Sp<TcpSocket>& sock, const Sp<Node>& node) { struct SockRxData { msgpack::unpacker unpacker; /** the socket have been assigned to a node */ - bool assigned; + Sp<Node> node; }; auto rx_data = std::make_shared<SockRxData>(); - rx_data->assigned = assigned; - if (not rx_data->assigned) + rx_data->node = node; + if (not rx_data->node) pending_connect.emplace(sock); - std::cout << "startTcp " << rx_data->assigned << std::endl; + //std::cout << "startTcp " << rx_data->assigned << std::endl; //sock->start(std::bind(&NetworkEngine::onReceiveData, this, _1, _2, n, sock)); sock->start([this,sock,rx_data](const uint8_t* data, size_t size){ std::cout << "Received message !! " << size << std::endl; @@ -428,32 +428,60 @@ NetworkEngine::startTcp(const Sp<TcpSocket>& sock, bool assigned) msgpack::object_handle result; // Message pack data loop while(rx_data->unpacker.next(result)) { - std::unique_ptr<ParsedMessage> msg {new ParsedMessage}; + //std::unique_ptr<ParsedMessage> msg {new ParsedMessage}; + SockAddr from = sock->getPeerAddr(); + Sp<Node> node_found; try { - msg->msgpack_unpack(result.get()); + const auto& o = result.get(); + if (o.type == msgpack::type::BIN) { + auto decrypted = id_key.decrypt((const uint8_t*)o.via.bin.ptr, o.via.bin.size); + msgpack::object_handle oh = msgpack::unpack((const char *)decrypted.data(), decrypted.size()); + //msg->msgpack_unpack(oh.get()); + node_found = processEncrypted(oh.get(), from); + } else { + //msg->msgpack_unpack(o); + node_found = process(o, from); + } } catch (const std::exception& e) { DHT_LOG.w("Can't process message: %s", e.what()); return; } - - if (msg->network != network) { + if (node_found) { + if (not rx_data->node) { + pending_connect.erase(sock); + rx_data->node = node_found; + } else if (rx_data->node != node_found) { + DHT_LOG.w("Changing node id on TCP socket !"); + } + } + /*if (msg->network != network) { DHT_LOG.d("Received message from other network %u", msg->network); return; } SockAddr from = sock->getPeerAddr(); if (from.isMappedIPv4()) from = from.getMappedIPv4(); - if (not rx_data->assigned) { - pending_connect.erase(sock); - rx_data->assigned = true; - } - process(std::move(msg), from, sock); + process(std::move(msg), from, sock);*/ } }); } +Sp<Node> +NetworkEngine::processEncrypted(const msgpack::object& o, const SockAddr& from) +{ + if (o.type != msgpack::type::ARRAY) + throw msgpack::type_error(); + InfoHash id {o.via.array.ptr[0]}; + const auto& data = o.via.array.ptr[1]; + if (data.type != msgpack::type::BIN) + throw msgpack::type_error(); + auto decrypted = id_key.decrypt((const uint8_t *)data.via.bin.ptr, data.via.bin.size, id); + msgpack::object_handle oh = msgpack::unpack((const char *)decrypted.data(), decrypted.size()); + return process(oh.get(), from); +} + void -NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const SockAddr& from_raw) +NetworkEngine::processMessage(const uint8_t* buf, size_t buflen, const SockAddr& from_raw) { SockAddr from = from_raw.isMappedIPv4() ? from_raw.getMappedIPv4() : from_raw; if (isMartian(from)) { @@ -466,19 +494,37 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const SockAddr& return; } - std::unique_ptr<ParsedMessage> msg {new ParsedMessage}; try { - msgpack::unpacked msg_res = msgpack::unpack((const char*)buf, buflen); - msg->msgpack_unpack(msg_res.get()); + auto unp = msgpack::unpack((const char*)buf, buflen); + const auto& o = unp.get(); + if (o.type == msgpack::type::BIN) { + auto decrypted = id_key.decrypt((const uint8_t*)o.via.bin.ptr, o.via.bin.size); + msgpack::object_handle oh = msgpack::unpack((const char *)decrypted.data(), decrypted.size()); + processEncrypted(oh.get(), from); + } else { + process(o, from); + } } catch (const std::exception& e) { DHT_LOG.w("Can't process message of size %lu: %s", buflen, e.what()); DHT_LOG.DEBUG.logPrintable(buf, buflen); return; } +} + +Sp<Node> +NetworkEngine::process(const msgpack::object& o, const SockAddr& from) +{ + std::unique_ptr<ParsedMessage> msg {new ParsedMessage}; + try { + msg->msgpack_unpack(o); + } catch (const std::exception& e) { + DHT_LOG.w("Can't process message from %s: %s", from.toString().c_str(), e.what()); + return {}; + } if (msg->network != network) { DHT_LOG.d("Received message from other network %u", msg->network); - return; + return {}; } const auto& now = scheduler.time(); @@ -489,12 +535,12 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const SockAddr& if (pmsg_it == partial_messages.end()) { DHT_LOG.d("Can't find partial message"); rateLimit(from); - return; + return {}; } if (!pmsg_it->second.from.equals(from)) { DHT_LOG.d("Received partial message data from unexpected IP address"); rateLimit(from); - return; + return {}; } // append data block if (pmsg_it->second.msg->append(*msg)) { @@ -507,24 +553,24 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const SockAddr& } else scheduler.add(RX_TIMEOUT, std::bind(&NetworkEngine::maintainRxBuffer, this, msg->tid)); } - return; + return {}; } - if (msg->id == myid || msg->id == zeroes) { + if (msg->id == getNodeId() || msg->id == zeroes) { DHT_LOG.d("Received message from self"); - return; + return {}; } if (msg->type > MessageType::Reply) { /* Rate limit requests. */ if (!rateLimit(from)) { DHT_LOG.w("Dropping request due to rate limiting"); - return; + return {}; } } if (msg->value_parts.empty()) { - process(std::move(msg), from); + return process(std::move(msg), from); } else { // starting partial message session PartialMessage pmsg; @@ -539,12 +585,14 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const SockAddr& } else DHT_LOG.e("Partial message with given TID already exists"); } + return {}; } -void +Sp<Node> NetworkEngine::process(std::unique_ptr<ParsedMessage>&& msg, const SockAddr& from, const Sp<TcpSocket>& sock) { const auto& now = scheduler.time(); + Sp<Node> node; if (msg->type == MessageType::Error or msg->type == MessageType::Reply) { /* either response for a request or data for an opened socket */ @@ -554,13 +602,13 @@ NetworkEngine::process(std::unique_ptr<ParsedMessage>&& msg, const SockAddr& fro rsocket_it = opened_sockets.find(msg->tid); if (req_it == requests.end() and rsocket_it == opened_sockets.end()) { DHT_LOG.e("Can't find transaction."); - return; + return {}; } auto req = req_it != requests.end() ? req_it->second : nullptr; auto rsocket = rsocket_it != opened_sockets.end() ? rsocket_it->second : nullptr; - auto& node = req ? req->node : rsocket->node; + node = req ? req->node : rsocket->node; if (node->id != msg->id) { if (node->id == zeroes) // received reply to a message sent when we didn't know the node ID. node = cache.getNode(msg->id, from, sock, now, true); @@ -569,7 +617,7 @@ NetworkEngine::process(std::unique_ptr<ParsedMessage>&& msg, const SockAddr& fro node->received(now, req); onNewNode(node, 2); DHT_LOG.w(node->id, "[node %s] message received from unexpected node", node->toString().c_str()); - return; + return node; } } else node->update(from, sock); @@ -582,7 +630,7 @@ NetworkEngine::process(std::unique_ptr<ParsedMessage>&& msg, const SockAddr& fro if (req and (req->cancelled() or req->expired() or req->completed())) { DHT_LOG.w(node->id, "[node %s] response to expired, cancelled or completed request", node->toString().c_str()); requests.erase(req_it); - return; + return {}; } switch (msg->type) { @@ -621,7 +669,7 @@ NetworkEngine::process(std::unique_ptr<ParsedMessage>&& msg, const SockAddr& fro break; } } else { - auto node = cache.getNode(msg->id, from, sock, now, true); + node = cache.getNode(msg->id, from, sock, now, true); node->received(now, {}); onNewNode(node, 1); try { @@ -685,6 +733,8 @@ NetworkEngine::process(std::unique_ptr<ParsedMessage>&& msg, const SockAddr& fro sendError(node, msg->tid, e.getCode(), e.getMsg().c_str(), true); } } + + return node; } void @@ -717,12 +767,31 @@ NetworkEngine::sendUDP(msgpack::sbuffer& msg, const SockAddr& addr) int NetworkEngine::send(msgpack::sbuffer& msg, const Sp<Node>& node) { - /*if (addr.second == 0) - return -1; -*/ - // move data - size_t size = msg.size(); - uint8_t* data = (uint8_t*)msg.release(); + uint8_t* data; + size_t size; + if (node->id == InfoHash{}) { + size = msg.size(); + data = (uint8_t*)msg.release(); + } else { + Blob encrypted_out; + { + Blob encrypted = id_key.encrypt((const uint8_t *)msg.data(), msg.size(), node->id); + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack_array(2); + pk.pack(getNodeId()); + pk.pack_bin(encrypted.size()); + pk.pack_bin_body((const char *)encrypted.data(), encrypted.size()); + + encrypted_out = node->id.encrypt((const uint8_t *)buffer.data(), buffer.size()); + } + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack_bin(encrypted_out.size()); + pk.pack_bin_body((const char *)encrypted_out.data(), encrypted_out.size()); + size = buffer.size(); + data = (uint8_t*)buffer.release(); + } if (node->canStream()) { return node->sock->write(data, size); } else { @@ -733,14 +802,37 @@ NetworkEngine::send(msgpack::sbuffer& msg, const Sp<Node>& node) int NetworkEngine::send(const Blob& msg, int /*flags*/, const Sp<Node>& node) { - auto data = (uint8_t*)malloc(msg.size()); - memcpy(data, msg.data(), msg.size()); + uint8_t* data; + size_t size; + if (node->id == InfoHash{}) { + size = msg.size(); + data = (uint8_t*)malloc(size);//(uint8_t*)msg.release(); + std::memcpy(data, msg.data(), size); + } else { + Blob encrypted_out; + { + Blob encrypted = id_key.encrypt((const uint8_t *)msg.data(), msg.size(), node->id); + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack_array(2); + pk.pack(getNodeId()); + pk.pack_bin(encrypted.size()); + pk.pack_bin_body((const char *)encrypted.data(), encrypted.size()); + + encrypted_out = node->id.encrypt((const uint8_t *)buffer.data(), buffer.size()); + } + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack_bin(encrypted_out.size()); + pk.pack_bin_body((const char *)encrypted_out.data(), encrypted_out.size()); + size = buffer.size(); + data = (uint8_t*)buffer.release(); + } if (node->canStream()) { - return node->sock->write(data, msg.size()); + return node->sock->write(data, size); } else { - return sock->send(data, msg.size(), node->addr); + return sock->send(data, size, node->addr); } - //return sock->send(data, msg.size(), addr); } Sp<Request> @@ -751,7 +843,7 @@ NetworkEngine::sendPing(const Sp<Node>& node, RequestCb&& on_done, RequestExpire pk.pack_map(5+(network?1:0)); pk.pack(std::string("a")); pk.pack_map(1); - pk.pack(std::string("id")); pk.pack(myid); + pk.pack(std::string("id")); pk.pack(getNodeId()); pk.pack(std::string("q")); pk.pack(std::string("ping")); pk.pack(std::string("t")); pk.pack_bin(tid.size()); @@ -788,7 +880,7 @@ NetworkEngine::sendPong(const Sp<Node>& node, TransId tid) { pk.pack_map(4+(network?1:0)); pk.pack(std::string("r")); pk.pack_map(2); - pk.pack(std::string("id")); pk.pack(myid); + pk.pack(std::string("id")); pk.pack(getNodeId()); insertAddr(pk, node->addr); pk.pack(std::string("t")); pk.pack_bin(tid.size()); @@ -811,7 +903,7 @@ NetworkEngine::sendFindNode(const Sp<Node>& n, const InfoHash& target, want_t wa pk.pack_map(5+(network?1:0)); pk.pack(std::string("a")); pk.pack_map(2 + (want>0?1:0)); - pk.pack(std::string("id")); pk.pack(myid); + pk.pack(std::string("id")); pk.pack(getNodeId()); pk.pack(std::string("target")); pk.pack(target); if (want > 0) { pk.pack(std::string("w")); @@ -859,7 +951,7 @@ NetworkEngine::sendGetValues(const Sp<Node>& n, const InfoHash& info_hash, const pk.pack(std::string("a")); pk.pack_map(2 + (query.where.getFilter() or not query.select.getSelection().empty() ? 1:0) + (want>0?1:0)); - pk.pack(std::string("id")); pk.pack(myid); + pk.pack(std::string("id")); pk.pack(getNodeId()); pk.pack(std::string("h")); pk.pack(info_hash); pk.pack(std::string("q")); pk.pack(query); if (want > 0) { @@ -925,7 +1017,7 @@ NetworkEngine::deserializeNodes(ParsedMessage& msg, const SockAddr& from) { for (unsigned i = 0; i < msg.nodes4_raw.size() / NODE4_INFO_BUF_LEN; i++) { uint8_t *ni = msg.nodes4_raw.data() + i * NODE4_INFO_BUF_LEN; const InfoHash& ni_id = *reinterpret_cast<InfoHash*>(ni); - if (ni_id == myid) + if (ni_id == getNodeId()) continue; SockAddr addr = deserializeIPv4(ni + ni_id.size()); if (addr.isLoopback() and from.getFamily() == AF_INET) { @@ -941,7 +1033,7 @@ NetworkEngine::deserializeNodes(ParsedMessage& msg, const SockAddr& from) { for (unsigned i = 0; i < msg.nodes6_raw.size() / NODE6_INFO_BUF_LEN; i++) { uint8_t *ni = msg.nodes6_raw.data() + i * NODE6_INFO_BUF_LEN; const InfoHash& ni_id = *reinterpret_cast<InfoHash*>(ni); - if (ni_id == myid) + if (ni_id == getNodeId()) continue; SockAddr addr = deserializeIPv6(ni + ni_id.size()); if (addr.isLoopback() and from.getFamily() == AF_INET6) { @@ -1021,7 +1113,7 @@ NetworkEngine::sendNodesValues(const Sp<Node>& node, TransId tid, const Blob& no pk.pack(std::string("r")); pk.pack_map(2 + (not st.empty()?1:0) + (nodes.size()>0?1:0) + (nodes6.size()>0?1:0) + (not token.empty()?1:0)); - pk.pack(std::string("id")); pk.pack(myid); + pk.pack(std::string("id")); pk.pack(getNodeId()); insertAddr(pk, node->addr); if (nodes.size() > 0) { pk.pack(std::string("n4")); @@ -1129,16 +1221,15 @@ NetworkEngine::sendListen(const Sp<Node>& n, RequestExpiredCb&& on_expired, SocketCb&& socket_cb) { - + //n->startTcp(scheduler.getLoop()); if (not n->sock) { n->sock = std::make_shared<TcpSocket>(scheduler.getLoop()); n->sock->connect(n->addr.get(), [this,n](int status){ if (status == 0) - startTcp(n->sock, true); + startTcp(n->sock, n); }); } - Sp<Socket> socket; auto tid = TransId { TransPrefix::LISTEN, previous ? previous->tid.getTid() : getNewTid() }; if (previous and previous->node == n) { @@ -1165,7 +1256,7 @@ NetworkEngine::sendListen(const Sp<Node>& n, auto has_query = query.where.getFilter() or not query.select.getSelection().empty(); pk.pack(std::string("a")); pk.pack_map(4 + has_query); - pk.pack(std::string("id")); pk.pack(myid); + pk.pack(std::string("id")); pk.pack(getNodeId()); pk.pack(std::string("h")); pk.pack(hash); pk.pack(std::string("token")); packToken(pk, token); pk.pack(std::string("sid")); pk.pack_bin(socket->id.size()); @@ -1207,7 +1298,7 @@ NetworkEngine::sendListenConfirmation(const Sp<Node>& node, TransId tid) { pk.pack_map(4+(network?1:0)); pk.pack(std::string("r")); pk.pack_map(2); - pk.pack(std::string("id")); pk.pack(myid); + pk.pack(std::string("id")); pk.pack(getNodeId()); insertAddr(pk, node->addr); pk.pack(std::string("t")); pk.pack_bin(tid.size()); @@ -1236,7 +1327,7 @@ NetworkEngine::sendAnnounceValue(const Sp<Node>& n, pk.pack_map(5+(network?1:0)); pk.pack(std::string("a")); pk.pack_map((created < scheduler.time() ? 5 : 4)); - pk.pack(std::string("id")); pk.pack(myid); + pk.pack(std::string("id")); pk.pack(getNodeId()); pk.pack(std::string("h")); pk.pack(infohash); auto v = packValueHeader(buffer, {value}, n->canStream()); if (created < scheduler.time()) { @@ -1294,7 +1385,7 @@ NetworkEngine::sendRefreshValue(const Sp<Node>& n, pk.pack_map(5+(network?1:0)); pk.pack(std::string("a")); pk.pack_map(4); - pk.pack(std::string("id")); pk.pack(myid); + pk.pack(std::string("id")); pk.pack(getNodeId()); pk.pack(std::string("h")); pk.pack(infohash); pk.pack(std::string("vid")); pk.pack(vid); pk.pack(std::string("token")); pk.pack(token); @@ -1339,7 +1430,7 @@ NetworkEngine::sendValueAnnounced(const Sp<Node>& node, TransId tid, Value::Id v pk.pack_map(4+(network?1:0)); pk.pack(std::string("r")); pk.pack_map(3); - pk.pack(std::string("id")); pk.pack(myid); + pk.pack(std::string("id")); pk.pack(getNodeId()); pk.pack(std::string("vid")); pk.pack(vid); insertAddr(pk, node->addr); @@ -1371,7 +1462,7 @@ NetworkEngine::sendError(const Sp<Node>& node, if (include_id) { pk.pack(std::string("r")); pk.pack_map(1); - pk.pack(std::string("id")); pk.pack(myid); + pk.pack(std::string("id")); pk.pack(getNodeId()); } pk.pack(std::string("t")); pk.pack_bin(tid.size());