diff --git a/include/opendht/crypto.h b/include/opendht/crypto.h index 505f7fb16f33676f5815fdafd1dad423d038d909..6ce933cfe7221a6a142dd3607a4edaa3e4881fb1 100644 --- a/include/opendht/crypto.h +++ b/include/opendht/crypto.h @@ -655,7 +655,7 @@ private: using SecureBlob = secure_vector<uint8_t>; -class EcPublicKey : public std::array<uint8_t, crypto_box_PUBLICKEYBYTES> { +/*class EcPublicKey : public std::array<uint8_t, crypto_box_PUBLICKEYBYTES> { public: Blob encrypt(const Blob& message) const { Blob ret(crypto_box_SEALBYTES+message.size()); @@ -663,11 +663,16 @@ public: return ret; } -}; + const InfoHash& getId() const { + return *reinterpret_cast<const InfoHash*>(this); + } + +};*/ class EcSecretKey { public: using KeyData = secure_array<uint8_t, crypto_box_SECRETKEYBYTES>; + using EcPublicKey = InfoHash; static EcSecretKey generate() { EcSecretKey ret; @@ -679,34 +684,34 @@ public: return pk; } - Blob encrypt(const Blob& message, const EcPublicKey& pub) const { - Blob ret(crypto_box_NONCEBYTES+crypto_box_MACBYTES+message.size()); + Blob encrypt(const uint8_t* data, size_t size, const EcPublicKey& pub) const { + Blob ret(crypto_box_NONCEBYTES+crypto_box_MACBYTES+size); randombytes_buf(ret.data(), crypto_box_NONCEBYTES); - if (crypto_box_easy(ret.data()+crypto_box_NONCEBYTES, message.data(), message.size(), ret.data(), + if (crypto_box_easy(ret.data()+crypto_box_NONCEBYTES, data, size, ret.data(), pub.data(), key.data()) != 0) { throw CryptoException("Can't encrypt data"); } return ret; } - Blob decrypt(const Blob& cypher) const { - if (cypher.size() <= crypto_box_SEALBYTES) + Blob decrypt(const uint8_t* cypher, size_t cypher_size) const { + if (cypher_size <= crypto_box_SEALBYTES) throw DecryptError("Unexpected cipher length"); - Blob ret(cypher.size() - crypto_box_SEALBYTES); - if (crypto_box_seal_open(ret.data(), cypher.data(), cypher.size(), pk.data(), key.data()) != 0) { + Blob ret(cypher_size - crypto_box_SEALBYTES); + if (crypto_box_seal_open(ret.data(), cypher, cypher_size, pk.data(), key.data()) != 0) { throw DecryptError("Can't decrypt data"); } return ret; } - Blob decrypt(const Blob& cypher, const EcPublicKey& pub) const { - if (cypher.size() <= crypto_box_NONCEBYTES+crypto_box_MACBYTES) + Blob decrypt(const uint8_t* cypher, size_t cypher_size, const EcPublicKey& pub) const { + if (cypher_size <= crypto_box_NONCEBYTES+crypto_box_MACBYTES) throw DecryptError("Unexpected cipher length"); - Blob ret(cypher.size() - crypto_box_NONCEBYTES - crypto_box_MACBYTES); + Blob ret(cypher_size - crypto_box_NONCEBYTES - crypto_box_MACBYTES); if (crypto_box_open_easy(ret.data(), - cypher.data()+crypto_box_NONCEBYTES, - cypher.size()-crypto_box_NONCEBYTES, - cypher.data(), + cypher+crypto_box_NONCEBYTES, + cypher_size-crypto_box_NONCEBYTES, + cypher, pub.data(), key.data()) != 0) { throw DecryptError("Can't decrypt data"); } diff --git a/include/opendht/dht.h b/include/opendht/dht.h index 5b0dcbe32d80f32c6b03004564aa465a7eff58aa..d88ccf9bce9025cf7efc934e9688d348e59e163d 100644 --- a/include/opendht/dht.h +++ b/include/opendht/dht.h @@ -76,7 +76,7 @@ public: /** * Get the ID of the node. */ - inline const InfoHash& getNodeId() const { return myid; } + inline const InfoHash& getNodeId() const { return network_engine.getNodeId(); } /** * Get the current status of the node for the given family. diff --git a/include/opendht/infohash.h b/include/opendht/infohash.h index 48d40041a40909c97fb6b3500f0d1ea700d4b0bf..ba8eb6ae7aca5b66ffe31aae2534d93fb2b064bd 100644 --- a/include/opendht/infohash.h +++ b/include/opendht/infohash.h @@ -22,6 +22,8 @@ #include <msgpack.hpp> +#include <sodium.h> + #ifndef _WIN32 #include <netinet/in.h> #include <netdb.h> @@ -46,7 +48,7 @@ typedef uint16_t in_port_t; // bytes -#define HASH_LEN 32u +#define HASH_LEN crypto_box_PUBLICKEYBYTES namespace dht { @@ -147,6 +149,12 @@ public: std::copy_n(h, HASH_LEN, begin()); } + std::vector<uint8_t> encrypt(const uint8_t* msg, size_t msg_size) const { + std::vector<uint8_t> ret(crypto_box_SEALBYTES+msg_size); + crypto_box_seal(ret.data(), msg, msg_size, data()); + return ret; + } + /** * Constructor from an hexadecimal string (without "0x"). * hex must be at least 2.HASH_LEN characters long. diff --git a/include/opendht/network_engine.h b/include/opendht/network_engine.h index 4791f03e2451f58de69bcfba7353ef4037bf5070..ab29b97e44289d645df4badc2cc59622b1276af8 100644 --- a/include/opendht/network_engine.h +++ b/include/opendht/network_engine.h @@ -205,7 +205,7 @@ public: using RequestExpiredCb = std::function<void(const Request&, bool)>; NetworkEngine(uv_loop_t*, Logger& log, Scheduler& scheduler); - NetworkEngine(uv_loop_t*, const NetworkConfig& config, InfoHash& myid, NetId net, Logger& log, Scheduler& scheduler, + NetworkEngine(uv_loop_t*, const NetworkConfig& config, NetId net, Logger& log, Scheduler& scheduler, decltype(NetworkEngine::onError) onError, decltype(NetworkEngine::onNewNode) onNewNode, decltype(NetworkEngine::onReportedAddr) onReportedAddr, @@ -218,6 +218,10 @@ public: virtual ~NetworkEngine(); + const InfoHash& getNodeId() const { + return id_key.getPublicKey(); + } + void clear(); void close(OnClose cb) { sock->close(cb); @@ -386,17 +390,6 @@ public: RequestCb&& on_done, RequestExpiredCb&& on_expired); - /** - * Parses a message and calls appropriate callbacks. - * - * @param buf The buffer containing the binary message. - * @param buflen The length of the buffer. - * @param from The address info of the sender. - * @param fromlen The length of the corresponding sockaddr structure. - * @param now The time to adjust the clock in the network engine. - */ - void processMessage(const uint8_t *buf, size_t buflen, const SockAddr& addr); - Sp<Node> insertNode(const InfoHash& myid, const SockAddr& addr) { auto n = cache.getNode(myid, addr, {}, scheduler.time(), 0); onNewNode(n, 0); @@ -448,7 +441,20 @@ private: static const std::string my_v; - void process(std::unique_ptr<ParsedMessage>&&, const SockAddr& from, const Sp<TcpSocket>& s = {}); + /** + * Parses a message and calls appropriate callbacks. + * + * @param buf The buffer containing the binary message. + * @param buflen The length of the buffer. + * @param from The address info of the sender. + * @param fromlen The length of the corresponding sockaddr structure. + * @param now The time to adjust the clock in the network engine. + */ + void processMessage(const uint8_t* data, size_t size, const SockAddr& addr); + Sp<Node> processEncrypted(const msgpack::object& o, const SockAddr& from); + Sp<Node> process(const msgpack::object& o, const SockAddr& addr); + + Sp<Node> process(std::unique_ptr<ParsedMessage>&&, const SockAddr& from, const Sp<TcpSocket>& s = {}); bool rateLimit(const SockAddr& addr); @@ -478,7 +484,7 @@ private: int send(const Blob& msg, int flags, const Sp<Node>& node, UdpSocket::OnSent&& cb); - void startTcp(const Sp<TcpSocket>& sock, bool assigned = false); + void startTcp(const Sp<TcpSocket>& sock, const Sp<Node>& node = {}); void sendValueParts(TransId tid, const std::vector<Blob>& svals, const SockAddr& addr); std::vector<Blob> packValueHeader(msgpack::sbuffer&, const std::vector<Sp<Value>>&, bool stream); @@ -518,7 +524,7 @@ private: void deserializeNodes(ParsedMessage& msg, const SockAddr& from); /* DHT info */ - const InfoHash& myid; + //const InfoHash& myid; const NetId network {0}; const Logger& DHT_LOG; diff --git a/include/opendht/node.h b/include/opendht/node.h index cd6d8def3858df30ba1b08e29b387cb1ec0ac8d0..90b78caecbe147e3efffec2e73221f029cf38ee3 100644 --- a/include/opendht/node.h +++ b/include/opendht/node.h @@ -52,7 +52,7 @@ struct Node { Sp<TcpSocket> sock; - crypto::EcPublicKey last_known_pk; + InfoHash last_known_pk; //std::list<crypto::EcSecretKey> Node(const InfoHash& id, const SockAddr& addr, const Sp<TcpSocket>& s = {}, bool client=false); @@ -77,6 +77,10 @@ struct Node { const time_point& getReplyTime() const { return reply_time; } void setTime(const time_point& t) { time = t; } + void startTCP() { + + } + /** * Makes notice about an additionnal authentication error with this node. Up * to MAX_AUTH_ERRORS errors are accepted in order to let the node recover. diff --git a/src/dht.cpp b/src/dht.cpp index 463f0d7fa59d1f8b8c74b9bb84316df354b361d5..cdbb40c0d43cfca371795fa2fc393d436d6f26ff 100644 --- a/src/dht.cpp +++ b/src/dht.cpp @@ -535,7 +535,7 @@ Dht::searchStep(Sp<Search> sr) } // true if this node is part of the target nodes cluter. - /*bool in = sr->id.xorCmp(myid, sr->nodes.back().node->id) < 0; + /*bool in = sr->id.xorCmp(getNodeId(), sr->nodes.back().node->id) < 0; DHT_LOG_DEBUG("[search %s IPv%c] synced%s", sr->id.toString().c_str(), sr->af == AF_INET ? '4' : '6', in ? ", in" : "");*/ @@ -1410,7 +1410,7 @@ Dht::getNodesStats(sa_family_t af) const if (b.cached) stats.cached_nodes++; } - stats.table_depth = bcks.depth(bcks.findBucket(myid)); + stats.table_depth = bcks.depth(bcks.findBucket(getNodeId())); return stats; } @@ -1533,7 +1533,7 @@ void Dht::dumpTables() const { std::stringstream out; - out << "My id " << myid << std::endl; + out << "My id " << getNodeId() << std::endl; out << "Buckets IPv4 :" << std::endl; for (const auto& b : buckets4) @@ -1677,11 +1677,10 @@ Dht::~Dht() Dht::Dht() : store(), scheduler(nullptr, DHT_LOG), network_engine(nullptr, DHT_LOG, scheduler) {} Dht::Dht(uv_loop_t* loop, Config config) - : myid(config.node_id != zeroes ? config.node_id : InfoHash::getRandom()), - is_bootstrap(config.is_bootstrap), maintain_storage(config.maintain_storage), + : is_bootstrap(config.is_bootstrap), maintain_storage(config.maintain_storage), buckets4{Bucket {AF_INET}},buckets6{Bucket {AF_INET6}}, store(), store_quota(), scheduler(loop, DHT_LOG), - network_engine(loop, config.network_config, myid, config.network, DHT_LOG, scheduler, + network_engine(loop, config.network_config, config.network, DHT_LOG, scheduler, std::bind(&Dht::onError, this, _1, _2), std::bind(&Dht::onNewNode, this, _1, _2), std::bind(&Dht::onReportedAddr, this, _1, _2), @@ -1709,7 +1708,7 @@ Dht::Dht(uv_loop_t* loop, Config config) expire(); - DHT_LOG.d("DHT initialised with node ID %s", myid.toString().c_str()); + DHT_LOG.d("DHT initialised with node ID %s", getNodeId().toString().c_str()); } @@ -1717,11 +1716,11 @@ bool Dht::neighbourhoodMaintenance(RoutingTable& list) { //DHT_LOG_DEBUG("neighbourhoodMaintenance"); - auto b = list.findBucket(myid); + auto b = list.findBucket(getNodeId()); if (b == list.end()) return false; - InfoHash id = myid; + InfoHash id = getNodeId(); #ifdef _WIN32 std::uniform_int_distribution<int> rand_byte{ 0, std::numeric_limits<uint8_t>::max() }; #else @@ -1838,7 +1837,7 @@ Dht::maintainStorage(decltype(store)::value_type& storage, bool force, DoneCallb auto nodes = buckets4.findClosestNodes(storage.first, now); if (!nodes.empty()) { - if (force || storage.first.xorCmp(nodes.back()->id, myid) < 0) { + if (force || storage.first.xorCmp(nodes.back()->id, getNodeId()) < 0) { for (auto &value : storage.second.getValues()) { const auto& vt = getType(value.data->type); if (force || value.created + vt.expiration > now + MAX_STORAGE_MAINTENANCE_EXPIRE_TIME) { @@ -1853,7 +1852,7 @@ Dht::maintainStorage(decltype(store)::value_type& storage, bool force, DoneCallb auto nodes6 = buckets6.findClosestNodes(storage.first, now); if (!nodes6.empty()) { - if (force || storage.first.xorCmp(nodes6.back()->id, myid) < 0) { + if (force || storage.first.xorCmp(nodes6.back()->id, getNodeId()) < 0) { for (auto &value : storage.second.getValues()) { const auto& vt = getType(value.data->type); if (force || value.created + vt.expiration > now + MAX_STORAGE_MAINTENANCE_EXPIRE_TIME) { @@ -1900,7 +1899,7 @@ Dht::updateStatus() status6 = nstatus6; if (status4 == NodeStatus::Disconnected and status6 == NodeStatus::Disconnected) { // We have lost connection with the DHT. Try to recover using bootstrap nodes. - DHT_LOG.e(myid, "DHT disconnected", myid.toString().c_str()); + DHT_LOG.e(getNodeId(), "DHT disconnected", getNodeId().toString().c_str()); } else { //bootstrap_nodes.clear(); } @@ -1931,12 +1930,12 @@ Dht::confirmNodes() updateStatus(); if (searches4.empty() and status4 == NodeStatus::Connected) { - DHT_LOG.d(myid, "[confirm nodes] initial IPv4 'get' for my id (%s)", myid.toString().c_str()); - search(myid, AF_INET); + DHT_LOG.d(getNodeId(), "[confirm nodes] initial IPv4 'get' for my id (%s)", getNodeId().toString().c_str()); + search(getNodeId(), AF_INET); } if (searches6.empty() and status6 == NodeStatus::Connected) { - DHT_LOG.d(myid, "[confirm nodes] initial IPv6 'get' for my id (%s)", myid.toString().c_str()); - search(myid, AF_INET6); + DHT_LOG.d(getNodeId(), "[confirm nodes] initial IPv6 'get' for my id (%s)", getNodeId().toString().c_str()); + search(getNodeId(), AF_INET6); } soon |= bucketMaintenance(buckets4); @@ -2026,13 +2025,13 @@ Dht::exportNodes() { const auto& now = scheduler.time(); std::vector<NodeExport> nodes; - const auto b4 = buckets4.findBucket(myid); + const auto b4 = buckets4.findBucket(getNodeId()); if (b4 != buckets4.end()) { for (auto& n : b4->nodes) if (n->isGood(now)) nodes.push_back(n->exportNode()); } - const auto b6 = buckets6.findBucket(myid); + const auto b6 = buckets6.findBucket(getNodeId()); if (b6 != buckets6.end()) { for (auto& n : b6->nodes) if (n->isGood(now)) @@ -2285,7 +2284,7 @@ Dht::onAnnounce(Sp<Node> n, // We store a value only if we think we're part of the // SEARCH_NODES nodes around the target id. auto closest_nodes = buckets(node.getFamily()).findClosestNodes(hash, scheduler.time(), SEARCH_NODES); - if (closest_nodes.size() >= TARGET_NODES and hash.xorCmp(closest_nodes.back()->id, myid) < 0) { + if (closest_nodes.size() >= TARGET_NODES and hash.xorCmp(closest_nodes.back()->id, getNodeId()) < 0) { DHT_LOG.w(hash, node.id, "[node %s] announce too far from the target. Dropping value.", node.toString().c_str()); return {}; } diff --git a/src/network_engine.cpp b/src/network_engine.cpp index 70ad29eb70b829b29c20f32b3062699f2b16c364..44e1478101861b83262f00a2f92997a5644d01c6 100644 --- a/src/network_engine.cpp +++ b/src/network_engine.cpp @@ -90,8 +90,8 @@ RequestAnswer::RequestAnswer(ParsedMessage&& msg) : ntoken(std::move(msg.token)), values(std::move(msg.values)), fields(std::move(msg.fields)), nodes4(std::move(msg.nodes4)), nodes6(std::move(msg.nodes6)) {} -NetworkEngine::NetworkEngine(uv_loop_t*, Logger& log, Scheduler& scheduler) : myid(zeroes), DHT_LOG(log), scheduler(scheduler) {} -NetworkEngine::NetworkEngine(uv_loop_t* loop, const NetworkConfig& config, InfoHash& myid, NetId net, Logger& log, Scheduler& scheduler, +NetworkEngine::NetworkEngine(uv_loop_t*, Logger& log, Scheduler& scheduler) : DHT_LOG(log), scheduler(scheduler) {} +NetworkEngine::NetworkEngine(uv_loop_t* loop, const NetworkConfig& config, NetId net, Logger& log, Scheduler& scheduler, decltype(NetworkEngine::onError) onError, decltype(NetworkEngine::onNewNode) onNewNode, decltype(NetworkEngine::onReportedAddr) onReportedAddr, @@ -102,7 +102,7 @@ NetworkEngine::NetworkEngine(uv_loop_t* loop, const NetworkConfig& config, InfoH decltype(NetworkEngine::onAnnounce) onAnnounce, decltype(NetworkEngine::onRefresh) onRefresh) : onError(onError), onNewNode(onNewNode), onReportedAddr(onReportedAddr), onPing(onPing), onFindNode(onFindNode), - onGetValues(onGetValues), onListen(onListen), onAnnounce(onAnnounce), onRefresh(onRefresh), myid(myid), + onGetValues(onGetValues), onListen(onListen), onAnnounce(onAnnounce), onRefresh(onRefresh), network(net), DHT_LOG(log), sock(std::make_shared<UdpSocket>(loop)), tcp_sock(std::make_shared<TcpSocket>(loop)), id_key(crypto::EcSecretKey::generate()), scheduler(scheduler) { @@ -269,19 +269,19 @@ NetworkEngine::isNodeBlacklisted(const SockAddr& addr) const } void -NetworkEngine::startTcp(const Sp<TcpSocket>& sock, bool assigned) +NetworkEngine::startTcp(const Sp<TcpSocket>& sock, const Sp<Node>& node) { struct SockRxData { msgpack::unpacker unpacker; /** the socket have been assigned to a node */ - bool assigned; + Sp<Node> node; }; auto rx_data = std::make_shared<SockRxData>(); - rx_data->assigned = assigned; - if (not rx_data->assigned) + rx_data->node = node; + if (not rx_data->node) pending_connect.emplace(sock); - std::cout << "startTcp " << rx_data->assigned << std::endl; + //std::cout << "startTcp " << rx_data->assigned << std::endl; //sock->start(std::bind(&NetworkEngine::onReceiveData, this, _1, _2, n, sock)); sock->start([this,sock,rx_data](const uint8_t* data, size_t size){ std::cout << "Received message !! " << size << std::endl; @@ -291,32 +291,60 @@ NetworkEngine::startTcp(const Sp<TcpSocket>& sock, bool assigned) msgpack::object_handle result; // Message pack data loop while(rx_data->unpacker.next(result)) { - std::unique_ptr<ParsedMessage> msg {new ParsedMessage}; + //std::unique_ptr<ParsedMessage> msg {new ParsedMessage}; + SockAddr from = sock->getPeerAddr(); + Sp<Node> node_found; try { - msg->msgpack_unpack(result.get()); + const auto& o = result.get(); + if (o.type == msgpack::type::BIN) { + auto decrypted = id_key.decrypt((const uint8_t*)o.via.bin.ptr, o.via.bin.size); + msgpack::object_handle oh = msgpack::unpack((const char *)decrypted.data(), decrypted.size()); + //msg->msgpack_unpack(oh.get()); + node_found = processEncrypted(oh.get(), from); + } else { + //msg->msgpack_unpack(o); + node_found = process(o, from); + } } catch (const std::exception& e) { DHT_LOG.w("Can't process message: %s", e.what()); return; } - - if (msg->network != network) { + if (node_found) { + if (not rx_data->node) { + pending_connect.erase(sock); + rx_data->node = node_found; + } else if (rx_data->node != node_found) { + DHT_LOG.w("Changing node id on TCP socket !"); + } + } + /*if (msg->network != network) { DHT_LOG.d("Received message from other network %u", msg->network); return; } SockAddr from = sock->getPeerAddr(); if (from.isMappedIPv4()) from = from.getMappedIPv4(); - if (not rx_data->assigned) { - pending_connect.erase(sock); - rx_data->assigned = true; - } - process(std::move(msg), from, sock); + process(std::move(msg), from, sock);*/ } }); } +Sp<Node> +NetworkEngine::processEncrypted(const msgpack::object& o, const SockAddr& from) +{ + if (o.type != msgpack::type::ARRAY) + throw msgpack::type_error(); + InfoHash id {o.via.array.ptr[0]}; + const auto& data = o.via.array.ptr[1]; + if (data.type != msgpack::type::BIN) + throw msgpack::type_error(); + auto decrypted = id_key.decrypt((const uint8_t *)data.via.bin.ptr, data.via.bin.size, id); + msgpack::object_handle oh = msgpack::unpack((const char *)decrypted.data(), decrypted.size()); + return process(oh.get(), from); +} + void -NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const SockAddr& from_raw) +NetworkEngine::processMessage(const uint8_t* buf, size_t buflen, const SockAddr& from_raw) { SockAddr from = from_raw.isMappedIPv4() ? from_raw.getMappedIPv4() : from_raw; if (isMartian(from)) { @@ -329,19 +357,37 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const SockAddr& return; } - std::unique_ptr<ParsedMessage> msg {new ParsedMessage}; try { - msgpack::unpacked msg_res = msgpack::unpack((const char*)buf, buflen); - msg->msgpack_unpack(msg_res.get()); + auto unp = msgpack::unpack((const char*)buf, buflen); + const auto& o = unp.get(); + if (o.type == msgpack::type::BIN) { + auto decrypted = id_key.decrypt((const uint8_t*)o.via.bin.ptr, o.via.bin.size); + msgpack::object_handle oh = msgpack::unpack((const char *)decrypted.data(), decrypted.size()); + processEncrypted(oh.get(), from); + } else { + process(o, from); + } } catch (const std::exception& e) { DHT_LOG.w("Can't process message of size %lu: %s", buflen, e.what()); DHT_LOG.DEBUG.logPrintable(buf, buflen); return; } +} + +Sp<Node> +NetworkEngine::process(const msgpack::object& o, const SockAddr& from) +{ + std::unique_ptr<ParsedMessage> msg {new ParsedMessage}; + try { + msg->msgpack_unpack(o); + } catch (const std::exception& e) { + DHT_LOG.w("Can't process message from %s: %s", from.toString().c_str(), e.what()); + return {}; + } if (msg->network != network) { DHT_LOG.d("Received message from other network %u", msg->network); - return; + return {}; } const auto& now = scheduler.time(); @@ -352,12 +398,12 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const SockAddr& if (pmsg_it == partial_messages.end()) { DHT_LOG.d("Can't find partial message"); rateLimit(from); - return; + return {}; } if (!pmsg_it->second.from.equals(from)) { DHT_LOG.d("Received partial message data from unexpected IP address"); rateLimit(from); - return; + return {}; } // append data block if (pmsg_it->second.msg->append(*msg)) { @@ -370,24 +416,24 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const SockAddr& } else scheduler.add(RX_TIMEOUT, std::bind(&NetworkEngine::maintainRxBuffer, this, msg->tid)); } - return; + return {}; } - if (msg->id == myid || msg->id == zeroes) { + if (msg->id == getNodeId() || msg->id == zeroes) { DHT_LOG.d("Received message from self"); - return; + return {}; } if (msg->type > MessageType::Reply) { /* Rate limit requests. */ if (!rateLimit(from)) { DHT_LOG.w("Dropping request due to rate limiting"); - return; + return {}; } } if (msg->value_parts.empty()) { - process(std::move(msg), from); + return process(std::move(msg), from); } else { // starting partial message session PartialMessage pmsg; @@ -402,13 +448,14 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const SockAddr& } else DHT_LOG.e("Partial message with given TID already exists"); } + return {}; } -void +Sp<Node> NetworkEngine::process(std::unique_ptr<ParsedMessage>&& msg, const SockAddr& from, const Sp<TcpSocket>& sock) { const auto& now = scheduler.time(); - auto node = cache.getNode(msg->id, from, now, true, msg->is_client); + auto node = cache.getNode(msg->id, from, now, sock, true, msg->is_client); if (msg->type == MessageType::Error or msg->type == MessageType::Reply) { auto rsocket = node->getSocket(msg->tid.toInt()); @@ -439,7 +486,7 @@ NetworkEngine::process(std::unique_ptr<ParsedMessage>&& msg, const SockAddr& fro if (req and (req->cancelled() or req->expired() or req->completed())) { DHT_LOG.w(node->id, "[node %s] response to expired, cancelled or completed request", node->toString().c_str()); - return; + return {}; } switch (msg->type) { @@ -538,6 +585,8 @@ NetworkEngine::process(std::unique_ptr<ParsedMessage>&& msg, const SockAddr& fro sendError(node, msg->tid, e.getCode(), e.getMsg().c_str(), true); } } + + return node; } void @@ -570,12 +619,31 @@ NetworkEngine::sendUDP(msgpack::sbuffer& msg, const SockAddr& addr) int NetworkEngine::send(msgpack::sbuffer& msg, const Sp<Node>& node) { - /*if (addr.second == 0) - return -1; -*/ - // move data - size_t size = msg.size(); - uint8_t* data = (uint8_t*)msg.release(); + uint8_t* data; + size_t size; + if (node->id == InfoHash{}) { + size = msg.size(); + data = (uint8_t*)msg.release(); + } else { + Blob encrypted_out; + { + Blob encrypted = id_key.encrypt((const uint8_t *)msg.data(), msg.size(), node->id); + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack_array(2); + pk.pack(getNodeId()); + pk.pack_bin(encrypted.size()); + pk.pack_bin_body((const char *)encrypted.data(), encrypted.size()); + + encrypted_out = node->id.encrypt((const uint8_t *)buffer.data(), buffer.size()); + } + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack_bin(encrypted_out.size()); + pk.pack_bin_body((const char *)encrypted_out.data(), encrypted_out.size()); + size = buffer.size(); + data = (uint8_t*)buffer.release(); + } if (node->canStream()) { return node->sock->write(data, size); } else { @@ -586,14 +654,37 @@ NetworkEngine::send(msgpack::sbuffer& msg, const Sp<Node>& node) int NetworkEngine::send(const Blob& msg, int /*flags*/, const Sp<Node>& node, UdpSocket::OnSent&& cb) { - auto data = (uint8_t*)malloc(msg.size()); - memcpy(data, msg.data(), msg.size()); + uint8_t* data; + size_t size; + if (node->id == InfoHash{}) { + size = msg.size(); + data = (uint8_t*)malloc(size);//(uint8_t*)msg.release(); + std::memcpy(data, msg.data(), size); + } else { + Blob encrypted_out; + { + Blob encrypted = id_key.encrypt((const uint8_t *)msg.data(), msg.size(), node->id); + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack_array(2); + pk.pack(getNodeId()); + pk.pack_bin(encrypted.size()); + pk.pack_bin_body((const char *)encrypted.data(), encrypted.size()); + + encrypted_out = node->id.encrypt((const uint8_t *)buffer.data(), buffer.size()); + } + msgpack::sbuffer buffer; + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack_bin(encrypted_out.size()); + pk.pack_bin_body((const char *)encrypted_out.data(), encrypted_out.size()); + size = buffer.size(); + data = (uint8_t*)buffer.release(); + } if (node->canStream()) { - return node->sock->write(data, msg.size()); + return node->sock->write(data, size); } else { - return sock->send(data, msg.size(), node->addr, std::move(cb)); + return sock->send(data, size, node->addr, std::move(cb)); } - //return sock->send(data, msg.size(), addr); } Sp<Request> @@ -604,7 +695,7 @@ NetworkEngine::sendPing(const Sp<Node>& node, RequestCb&& on_done, RequestExpire pk.pack_map(5+(network?1:0)); pk.pack(std::string("a")); pk.pack_map(1); - pk.pack(std::string("id")); pk.pack(myid); + pk.pack(std::string("id")); pk.pack(getNodeId()); pk.pack(std::string("q")); pk.pack(std::string("ping")); pk.pack(std::string("t")); pk.pack_bin(tid.size()); @@ -641,7 +732,7 @@ NetworkEngine::sendPong(const Sp<Node>& node, TransId tid) { pk.pack_map(4+(network?1:0)); pk.pack(std::string("r")); pk.pack_map(2); - pk.pack(std::string("id")); pk.pack(myid); + pk.pack(std::string("id")); pk.pack(getNodeId()); insertAddr(pk, node->addr); pk.pack(std::string("t")); pk.pack_bin(tid.size()); @@ -664,7 +755,7 @@ NetworkEngine::sendFindNode(const Sp<Node>& n, const InfoHash& target, want_t wa pk.pack_map(5+(network?1:0)); pk.pack(std::string("a")); pk.pack_map(2 + (want>0?1:0)); - pk.pack(std::string("id")); pk.pack(myid); + pk.pack(std::string("id")); pk.pack(getNodeId()); pk.pack(std::string("target")); pk.pack(target); if (want > 0) { pk.pack(std::string("w")); @@ -712,7 +803,7 @@ NetworkEngine::sendGetValues(const Sp<Node>& n, const InfoHash& info_hash, const pk.pack(std::string("a")); pk.pack_map(2 + (query.where.getFilter() or not query.select.getSelection().empty() ? 1:0) + (want>0?1:0)); - pk.pack(std::string("id")); pk.pack(myid); + pk.pack(std::string("id")); pk.pack(getNodeId()); pk.pack(std::string("h")); pk.pack(info_hash); pk.pack(std::string("q")); pk.pack(query); if (want > 0) { @@ -776,7 +867,7 @@ NetworkEngine::deserializeNodes(ParsedMessage& msg, const SockAddr& from) { for (unsigned i = 0, n = msg.nodes4_raw.size() / NODE4_INFO_BUF_LEN; i < n; i++) { const uint8_t* ni = msg.nodes4_raw.data() + i * NODE4_INFO_BUF_LEN; const auto& ni_id = *reinterpret_cast<const InfoHash*>(ni); - if (ni_id == myid) + if (ni_id == getNodeId()) continue; SockAddr addr = deserializeIPv4(ni + ni_id.size()); if (addr.isLoopback() and from.getFamily() == AF_INET) { @@ -792,7 +883,7 @@ NetworkEngine::deserializeNodes(ParsedMessage& msg, const SockAddr& from) { for (unsigned i = 0, n = msg.nodes6_raw.size() / NODE6_INFO_BUF_LEN; i < n; i++) { const uint8_t* ni = msg.nodes6_raw.data() + i * NODE6_INFO_BUF_LEN; const auto& ni_id = *reinterpret_cast<const InfoHash*>(ni); - if (ni_id == myid) + if (ni_id == getNodeId()) continue; SockAddr addr = deserializeIPv6(ni + ni_id.size()); if (addr.isLoopback() and from.getFamily() == AF_INET6) { @@ -871,7 +962,7 @@ NetworkEngine::sendNodesValues(const Sp<Node>& node, TransId tid, const Blob& no pk.pack(std::string("r")); pk.pack_map(2 + (not st.empty()?1:0) + (nodes.size()>0?1:0) + (nodes6.size()>0?1:0) + (not token.empty()?1:0)); - pk.pack(std::string("id")); pk.pack(myid); + pk.pack(std::string("id")); pk.pack(getNodeId()); insertAddr(pk, node->addr); if (nodes.size() > 0) { pk.pack(std::string("n4")); @@ -979,16 +1070,15 @@ NetworkEngine::sendListen(const Sp<Node>& n, RequestExpiredCb&& on_expired, SocketCb&& socket_cb) { - + //n->startTcp(scheduler.getLoop()); if (not n->sock) { n->sock = std::make_shared<TcpSocket>(scheduler.getLoop()); n->sock->connect(n->addr.get(), [this,n](int status){ if (status == 0) - startTcp(n->sock, true); + startTcp(n->sock, n); }); } - uint32_t socket; auto tid = TransId { TransPrefix::LISTEN, n->getNewTid() }; if (previous and previous->node == n) { @@ -1011,7 +1101,7 @@ NetworkEngine::sendListen(const Sp<Node>& n, auto has_query = query.where.getFilter() or not query.select.getSelection().empty(); pk.pack(std::string("a")); pk.pack_map(4 + has_query); - pk.pack(std::string("id")); pk.pack(myid); + pk.pack(std::string("id")); pk.pack(getNodeId()); pk.pack(std::string("h")); pk.pack(hash); pk.pack(std::string("token")); packToken(pk, token); pk.pack(std::string("sid")); pk.pack_bin(sid.size()); @@ -1053,7 +1143,7 @@ NetworkEngine::sendListenConfirmation(const Sp<Node>& node, TransId tid) { pk.pack_map(4+(network?1:0)); pk.pack(std::string("r")); pk.pack_map(2); - pk.pack(std::string("id")); pk.pack(myid); + pk.pack(std::string("id")); pk.pack(getNodeId()); insertAddr(pk, node->addr); pk.pack(std::string("t")); pk.pack_bin(tid.size()); @@ -1082,7 +1172,7 @@ NetworkEngine::sendAnnounceValue(const Sp<Node>& n, pk.pack_map(5+(network?1:0)); pk.pack(std::string("a")); pk.pack_map((created < scheduler.time() ? 5 : 4)); - pk.pack(std::string("id")); pk.pack(myid); + pk.pack(std::string("id")); pk.pack(getNodeId()); pk.pack(std::string("h")); pk.pack(infohash); auto v = packValueHeader(buffer, {value}, n->canStream()); if (created < scheduler.time()) { @@ -1140,7 +1230,7 @@ NetworkEngine::sendRefreshValue(const Sp<Node>& n, pk.pack_map(5+(network?1:0)); pk.pack(std::string("a")); pk.pack_map(4); - pk.pack(std::string("id")); pk.pack(myid); + pk.pack(std::string("id")); pk.pack(getNodeId()); pk.pack(std::string("h")); pk.pack(infohash); pk.pack(std::string("vid")); pk.pack(vid); pk.pack(std::string("token")); pk.pack(token); @@ -1185,7 +1275,7 @@ NetworkEngine::sendValueAnnounced(const Sp<Node>& node, TransId tid, Value::Id v pk.pack_map(4+(network?1:0)); pk.pack(std::string("r")); pk.pack_map(3); - pk.pack(std::string("id")); pk.pack(myid); + pk.pack(std::string("id")); pk.pack(getNodeId()); pk.pack(std::string("vid")); pk.pack(vid); insertAddr(pk, node->addr); @@ -1217,7 +1307,7 @@ NetworkEngine::sendError(const Sp<Node>& node, if (include_id) { pk.pack(std::string("r")); pk.pack_map(1); - pk.pack(std::string("id")); pk.pack(myid); + pk.pack(std::string("id")); pk.pack(getNodeId()); } pk.pack(std::string("t")); pk.pack_bin(tid.size());