diff --git a/include/opendht/network_engine.h b/include/opendht/network_engine.h index 13f30e55eddb7faa36e4d08b40f2db2895813149..b2a250c6b66c34b2739856112a6f63e318915c3e 100644 --- a/include/opendht/network_engine.h +++ b/include/opendht/network_engine.h @@ -93,7 +93,8 @@ struct ParsedMessage; * @param onListen callback for "listen" request. * @param onAnnounce callback for "announce" request. */ -class NetworkEngine final { +class NetworkEngine final +{ struct TransPrefix : public std::array<uint8_t, 2> { TransPrefix(const std::string& str) : std::array<uint8_t, 2>({{(uint8_t)str[0], (uint8_t)str[1]}}) {} static const TransPrefix PING; @@ -259,7 +260,7 @@ public: using RequestCb = std::function<void(const Request&, RequestAnswer&&)>; using RequestExpiredCb = std::function<void(const Request&, bool)>; - NetworkEngine(Logger& log, Scheduler& scheduler) : myid(zeroes), DHT_LOG(log), scheduler(scheduler) {} + NetworkEngine(Logger& log, Scheduler& scheduler); NetworkEngine(InfoHash& myid, NetId net, int s, int s6, Logger& log, Scheduler& scheduler, decltype(NetworkEngine::onError) onError, decltype(NetworkEngine::onNewNode) onNewNode, @@ -268,16 +269,9 @@ public: decltype(NetworkEngine::onFindNode) onFindNode, decltype(NetworkEngine::onGetValues) onGetValues, decltype(NetworkEngine::onListen) onListen, - decltype(NetworkEngine::onAnnounce) onAnnounce) : - onError(onError), onNewNode(onNewNode), onReportedAddr(onReportedAddr), onPing(onPing), onFindNode(onFindNode), - onGetValues(onGetValues), onListen(onListen), onAnnounce(onAnnounce), myid(myid), network(net), - dht_socket(s), dht_socket6(s6), DHT_LOG(log), scheduler(scheduler) - { - transaction_id = std::uniform_int_distribution<decltype(transaction_id)>{1}(rd_device); - } - virtual ~NetworkEngine() { - clear(); - }; + decltype(NetworkEngine::onAnnounce) onAnnounce); + + virtual ~NetworkEngine(); void clear(); @@ -333,7 +327,7 @@ public: std::shared_ptr<Request> sendAnnounceValue(std::shared_ptr<Node> n, const 
InfoHash& infohash, - const Value& v, + const std::shared_ptr<Value>& v, time_point created, const Blob& token, RequestCb on_done, @@ -370,6 +364,9 @@ public: } private: + + struct PartialMessage; + /*************** * Constants * ***************/ @@ -380,13 +377,22 @@ private: static const constexpr size_t NODE6_INFO_BUF_LEN {38}; /* TODO */ static constexpr std::chrono::seconds UDP_REPLY_TIME {15}; + + /* Max. time to receive a full fragmented packet */ + static constexpr std::chrono::seconds RX_MAX_PACKET_TIME {10}; + /* Max. time between packet fragments */ + static constexpr std::chrono::seconds RX_TIMEOUT {3}; /* The maximum number of nodes that we snub. There is probably little reason to increase this value. */ static constexpr unsigned BLACKLISTED_MAX {10}; + static constexpr size_t MTU {1280}; + static constexpr size_t MAX_PACKET_VALUE_SIZE {8 * 1024}; + static const std::string my_v; static std::mt19937 rd_device; + void process(std::unique_ptr<ParsedMessage>&&, const SockAddr& from); bool rateLimit(const SockAddr& addr); @@ -423,6 +429,10 @@ private: // basic wrapper for socket sendto function int send(const char *buf, size_t len, int flags, const SockAddr& addr); + void sendValueParts(TransId tid, const std::vector<Blob>& svals, const SockAddr& addr); + std::vector<Blob> packValueHeader(msgpack::sbuffer&, const std::vector<std::shared_ptr<Value>>&); + void maintainRxBuffer(const TransId& tid); + /************* * Answers * *************/ @@ -502,6 +512,7 @@ private: // requests handling uint16_t transaction_id {1}; std::map<uint16_t, std::shared_ptr<Request>> requests {}; + std::map<TransId, PartialMessage> partial_messages; MessageStats in_stats {}, out_stats {}; std::set<SockAddr> blacklist {}; diff --git a/include/opendht/sockaddr.h b/include/opendht/sockaddr.h index 110b224ae214c995a0d7ae5d5e440556cd5c14d5..f1ea58549ada8dc84fad1b27e4fd213f23d9ffa8 100644 --- a/include/opendht/sockaddr.h +++ b/include/opendht/sockaddr.h @@ -50,7 +50,7 @@ public: return 
std::memcmp((uint8_t*)&first, (uint8_t*)&o.first, second) < 0; } - bool operator==(const SockAddr& o) const { + bool equals(const SockAddr& o) const { return second == o.second && std::memcmp((uint8_t*)&first, (uint8_t*)&o.first, second) == 0; } @@ -66,6 +66,8 @@ public: sa_family_t getFamily() const { return second > sizeof(sa_family_t) ? first.ss_family : AF_UNSPEC; } }; +bool operator==(const SockAddr& a, const SockAddr& b); + std::string printAddr(const SockAddr& addr); } diff --git a/include/opendht/value.h b/include/opendht/value.h index 31e3ac418d426f3263db29eb75535e5ecd6d071d..f9b5f331e350dd5846bb91828934b842ea850862 100644 --- a/include/opendht/value.h +++ b/include/opendht/value.h @@ -70,7 +70,7 @@ using StorePolicy = std::function<bool(InfoHash key, std::shared_ptr<Value>& val */ using EditPolicy = std::function<bool(InfoHash key, const std::shared_ptr<Value>& old_val, std::shared_ptr<Value>& new_val, InfoHash from, const sockaddr* from_addr, socklen_t from_len)>; -static constexpr const size_t MAX_VALUE_SIZE {1024 * 16}; +static constexpr const size_t MAX_VALUE_SIZE {1024 * 64}; struct ValueType { typedef uint16_t Id; diff --git a/src/dht.cpp b/src/dht.cpp index 352cc878e5cb333bc609a4102f6bc9f71db32db6..3a7b22d8e4fb413733099a1cf3fb9705568d6781 100644 --- a/src/dht.cpp +++ b/src/dht.cpp @@ -1195,7 +1195,7 @@ void Dht::searchSendAnnounceValue(const std::shared_ptr<Search>& sr) { a.value->id); sn->acked[a.value->id] = network_engine.sendAnnounceValue(sn->node, sr->id, - *a.value, + a.value, a.created, sn->token, onDone, diff --git a/src/network_engine.cpp b/src/network_engine.cpp index 4ada3962fb26832344bfcb82898abff14b692a82..cb78968d47d832ca286482e82f4f42e911072466 100644 --- a/src/network_engine.cpp +++ b/src/network_engine.cpp @@ -35,6 +35,9 @@ const std::string DhtProtocolException::PUT_WRONG_TOKEN {"Put with wrong token"} const std::string DhtProtocolException::PUT_INVALID_ID {"Put with invalid id"}; constexpr std::chrono::seconds 
NetworkEngine::UDP_REPLY_TIME; +constexpr std::chrono::seconds NetworkEngine::RX_MAX_PACKET_TIME; +constexpr std::chrono::seconds NetworkEngine::RX_TIMEOUT; + const std::string NetworkEngine::my_v {"RNG1"}; const constexpr uint16_t NetworkEngine::TransId::INVALID; std::mt19937 NetworkEngine::rd_device {dht::crypto::random_device{}()}; @@ -59,35 +62,95 @@ enum class MessageType { FindNode, GetValues, AnnounceValue, - Listen + Listen, + ValueData }; struct ParsedMessage { MessageType type; - InfoHash id; /* the id of the sender */ - NetId network {0}; /* network id */ - InfoHash info_hash; /* hash for which values are requested */ - InfoHash target; /* target id around which to find nodes */ - NetworkEngine::TransId tid; /* transaction id */ - Blob token; /* security token */ - Value::Id value_id; /* the value id */ - time_point created { time_point::max() }; /* time when value was first created */ - Blob nodes4_raw, nodes6_raw; /* IPv4 nodes in response to a 'find' request */ + /* Node ID of the sender */ + InfoHash id; + /* Network id */ + NetId network {0}; + /* hash for which values are requested */ + InfoHash info_hash; + /* target id around which to find nodes */ + InfoHash target; + /* transaction id */ + NetworkEngine::TransId tid; + /* security token */ + Blob token; + /* the value id (announce confirmation) */ + Value::Id value_id; + /* time when value was first created */ + time_point created { time_point::max() }; + /* IPv4 nodes in response to a 'find' request */ + Blob nodes4_raw, nodes6_raw; std::vector<std::shared_ptr<Node>> nodes4, nodes6; - std::vector<std::shared_ptr<Value>> values; /* values for a 'get' request */ - std::vector<std::shared_ptr<FieldValueIndex>> fields; /* index for fields values */ - Query query; /* query describing a filter to apply on values. 
*/ - want_t want; /* states if ipv4 or ipv6 request */ - uint16_t error_code; /* error code in case of error */ + /* values to store, or values retrieved by a request */ + std::vector<std::shared_ptr<Value>> values; + /* index for fields values */ + std::vector<std::shared_ptr<FieldValueIndex>> fields; + /** When part of the message header: {index -> (total size, {})} + * When part of partial value data: {index -> (offset, part_data)} */ + std::map<unsigned, std::pair<unsigned, Blob>> value_parts; + /* query describing a filter to apply on values. */ + Query query; + /* states if ipv4 or ipv6 request */ + want_t want; + /* error code in case of error */ + uint16_t error_code; + /* reported address by the distant node */ std::string ua; - SockAddr addr; /* reported address by the distant node */ + SockAddr addr; void msgpack_unpack(msgpack::object o); + + bool append(const ParsedMessage& block); + bool complete(); +}; + +struct NetworkEngine::PartialMessage { + SockAddr from; + time_point start; + time_point last_part; + std::unique_ptr<ParsedMessage> msg; }; +std::vector<Blob> +serializeValues(const std::vector<std::shared_ptr<Value>>& st) +{ + std::vector<Blob> svals; + svals.reserve(st.size()); + for (const auto& v : st) + svals.emplace_back(packMsg(v)); + return svals; +} + NetworkEngine::RequestAnswer::RequestAnswer(ParsedMessage&& msg) : ntoken(std::move(msg.token)), values(std::move(msg.values)), fields(std::move(msg.fields)), nodes4(std::move(msg.nodes4)), nodes6(std::move(msg.nodes6)) {} +NetworkEngine::NetworkEngine(Logger& log, Scheduler& scheduler) : myid(zeroes), DHT_LOG(log), scheduler(scheduler) {} +NetworkEngine::NetworkEngine(InfoHash& myid, NetId net, int s, int s6, Logger& log, Scheduler& scheduler, + decltype(NetworkEngine::onError) onError, + decltype(NetworkEngine::onNewNode) onNewNode, + decltype(NetworkEngine::onReportedAddr) onReportedAddr, + decltype(NetworkEngine::onPing) onPing, + decltype(NetworkEngine::onFindNode) onFindNode, + 
decltype(NetworkEngine::onGetValues) onGetValues, + decltype(NetworkEngine::onListen) onListen, + decltype(NetworkEngine::onAnnounce) onAnnounce) : + onError(onError), onNewNode(onNewNode), onReportedAddr(onReportedAddr), onPing(onPing), onFindNode(onFindNode), + onGetValues(onGetValues), onListen(onListen), onAnnounce(onAnnounce), myid(myid), network(net), + dht_socket(s), dht_socket6(s6), DHT_LOG(log), scheduler(scheduler) +{ + transaction_id = std::uniform_int_distribution<decltype(transaction_id)>{1}(rd_device); +} + +NetworkEngine::~NetworkEngine() { + clear(); +} + void NetworkEngine::tellListener(std::shared_ptr<Node> node, uint16_t rid, const InfoHash& hash, want_t want, const Blob& ntoken, std::vector<std::shared_ptr<Node>>&& nodes, @@ -285,29 +348,56 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const SockAddr& return; } - ParsedMessage msg; + std::unique_ptr<ParsedMessage> msg {new ParsedMessage}; try { msgpack::unpacked msg_res = msgpack::unpack((const char*)buf, buflen); - msg.msgpack_unpack(msg_res.get()); - if (msg.type != MessageType::Error && msg.id == zeroes) - throw DhtException("no or invalid InfoHash"); + msg->msgpack_unpack(msg_res.get()); } catch (const std::exception& e) { DHT_LOG.WARN("Can't process message of size %lu: %s.", buflen, e.what()); DHT_LOG.DEBUG.logPrintable(buf, buflen); return; } - if (msg.network != network) { - DHT_LOG.DEBUG("Received message from other network %u.", msg.network); + if (msg->network != network) { + DHT_LOG.DEBUG("Received message from other network %u.", msg->network); return; } - if (msg.id == myid || msg.id == zeroes) { + const auto& now = scheduler.time(); + + // partial value data + if (msg->type == MessageType::ValueData) { + auto pmsg_it = partial_messages.find(msg->tid); + if (pmsg_it == partial_messages.end()) { + DHT_LOG.DEBUG("Can't find partial message"); + rateLimit(from); + return; + } + if (!pmsg_it->second.from.equals(from)) { + DHT_LOG.DEBUG("Received partial message 
data from unexpected IP address"); + rateLimit(from); + return; + } + // append data block + if (pmsg_it->second.msg->append(*msg)) { + pmsg_it->second.last_part = now; + // check data completion + if (pmsg_it->second.msg->complete()) { + // process the full message + process(std::move(pmsg_it->second.msg), from); + partial_messages.erase(pmsg_it); + } else + scheduler.add(now + RX_TIMEOUT, std::bind(&NetworkEngine::maintainRxBuffer, this, msg->tid)); + } + return; + } + + if (msg->id == myid || msg->id == zeroes) { DHT_LOG.DEBUG("Received message from self."); return; } - if (msg.type > MessageType::Reply) { + if (msg->type > MessageType::Reply) { /* Rate limit requests. */ if (!rateLimit(from)) { DHT_LOG.WARN("Dropping request due to rate limiting."); @@ -315,27 +405,39 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const SockAddr& } } - const auto& now = scheduler.time(); - - if (msg.tid.length != 4) { - DHT_LOG.ERR("Broken node truncates transaction ids (len: %d): ", msg.tid.length); - DHT_LOG.ERR.logPrintable(buf, buflen); - blacklistNode(cache.getNode(msg.id, from, now, true)); - return; + if (msg->value_parts.empty()) { + process(std::move(msg), from); + } else { + // starting partial message session + PartialMessage pmsg; + pmsg.from = from; + pmsg.msg = std::move(msg); + pmsg.start = now; + pmsg.last_part = now; + auto wmsg = partial_messages.emplace(pmsg.msg->tid, std::move(pmsg)); + if (wmsg.second) { + scheduler.add(now + RX_MAX_PACKET_TIME, std::bind(&NetworkEngine::maintainRxBuffer, this, wmsg.first->first)); + scheduler.add(now + RX_TIMEOUT, std::bind(&NetworkEngine::maintainRxBuffer, this, wmsg.first->first)); + } else + DHT_LOG.ERR("Partial message with given TID already exists."); } +} + +void +NetworkEngine::process(std::unique_ptr<ParsedMessage>&& msg, const SockAddr& from) +{ + const auto& now = scheduler.time(); - uint16_t ttid = 0; - if (msg.type == MessageType::Error or msg.type == MessageType::Reply) { - auto reqp = 
requests.find(msg.tid.getTid()); + if (msg->type == MessageType::Error or msg->type == MessageType::Reply) { + auto reqp = requests.find(msg->tid.getTid()); if (reqp == requests.end()) { - throw DhtProtocolException {DhtProtocolException::UNKNOWN_TID, "Can't find transaction", msg.id}; + throw DhtProtocolException {DhtProtocolException::UNKNOWN_TID, "Can't find transaction", msg->id}; } auto req = reqp->second; - auto node = req->node; - if (node->id != msg.id) { + if (node->id != msg->id) { bool unknown_node = node->id == zeroes; - node = cache.getNode(msg.id, from, now, true); + node = cache.getNode(msg->id, from, now, true); if (unknown_node) { // received reply to a message sent when we didn't know the node ID. req->node = node; @@ -351,7 +453,7 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const SockAddr& node->received(now, req); onNewNode(node, 2); - onReportedAddr(msg.id, msg.addr); + onReportedAddr(msg->id, msg->addr); if (req->cancelled() or req->expired() or (req->completed() and not req->persistent)) { DHT_LOG.WARN("[node %s] response to expired, cancelled or completed request", node->toString().c_str()); @@ -359,25 +461,24 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const SockAddr& return; } - switch (msg.type) { + switch (msg->type) { case MessageType::Error: { - if (msg.error_code == DhtProtocolException::UNAUTHORIZED - && msg.id != zeroes - && (msg.tid.matches(TransPrefix::ANNOUNCE_VALUES, &ttid) - || msg.tid.matches(TransPrefix::LISTEN, &ttid))) + if (msg->error_code == DhtProtocolException::UNAUTHORIZED + && msg->id != zeroes + && (msg->tid.matches(TransPrefix::ANNOUNCE_VALUES) + || msg->tid.matches(TransPrefix::LISTEN))) { req->last_try = TIME_INVALID; req->reply_time = TIME_INVALID; onError(req, DhtProtocolException {DhtProtocolException::UNAUTHORIZED}); } else { DHT_LOG.WARN("[node %s %s] received unknown error message %u", - msg.id.toString().c_str(), from.toString().c_str(), msg.error_code); - 
DHT_LOG.WARN.logPrintable(buf, buflen); + msg->id.toString().c_str(), from.toString().c_str(), msg->error_code); } break; } case MessageType::Reply: - if (msg.type == MessageType::AnnounceValue or msg.type == MessageType::Listen) + if (msg->type == MessageType::AnnounceValue or msg->type == MessageType::Listen) req->node->authSuccess(); // erase before calling callback to make sure iterator is still valid @@ -385,63 +486,63 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const SockAddr& requests.erase(reqp); req->reply_time = scheduler.time(); - deserializeNodes(msg); - req->setDone(std::move(msg)); + deserializeNodes(*msg); + req->setDone(std::move(*msg)); break; default: break; } } else { - auto node = cache.getNode(msg.id, from, now, true); + auto node = cache.getNode(msg->id, from, now, true); node->received(now, {}); onNewNode(node, 1); try { - switch (msg.type) { + switch (msg->type) { case MessageType::Ping: ++in_stats.ping; DHT_LOG.DEBUG("Sending pong."); onPing(node); - sendPong(from, msg.tid); + sendPong(from, msg->tid); break; case MessageType::FindNode: { DHT_LOG.DEBUG("[node %s %s] got 'find' request (%d).", - msg.id.toString().c_str(), from.toString().c_str(), msg.want); + msg->id.toString().c_str(), from.toString().c_str(), msg->want); ++in_stats.find; - RequestAnswer answer = onFindNode(node, msg.target, msg.want); - auto nnodes = bufferNodes(from.getFamily(), msg.target, msg.want, answer.nodes4, answer.nodes6); - sendNodesValues(from, msg.tid, nnodes.first, nnodes.second, {}, {}, answer.ntoken); + RequestAnswer answer = onFindNode(node, msg->target, msg->want); + auto nnodes = bufferNodes(from.getFamily(), msg->target, msg->want, answer.nodes4, answer.nodes6); + sendNodesValues(from, msg->tid, nnodes.first, nnodes.second, {}, {}, answer.ntoken); break; } case MessageType::GetValues: { DHT_LOG.DEBUG("[node %s %s] got 'get' request for %s.", - msg.id.toString().c_str(), from.toString().c_str(), msg.info_hash.toString().c_str()); + 
msg->id.toString().c_str(), from.toString().c_str(), msg->info_hash.toString().c_str()); ++in_stats.get; - RequestAnswer answer = onGetValues(node, msg.info_hash, msg.want, msg.query); - auto nnodes = bufferNodes(from.getFamily(), msg.info_hash, msg.want, answer.nodes4, answer.nodes6); - sendNodesValues(from, msg.tid, nnodes.first, nnodes.second, answer.values, msg.query, answer.ntoken); + RequestAnswer answer = onGetValues(node, msg->info_hash, msg->want, msg->query); + auto nnodes = bufferNodes(from.getFamily(), msg->info_hash, msg->want, answer.nodes4, answer.nodes6); + sendNodesValues(from, msg->tid, nnodes.first, nnodes.second, answer.values, msg->query, answer.ntoken); break; } case MessageType::AnnounceValue: { DHT_LOG.DEBUG("[node %s %s] got 'put' request for %s.", - msg.id.toString().c_str(), from.toString().c_str(), - msg.info_hash.toString().c_str()); + msg->id.toString().c_str(), from.toString().c_str(), + msg->info_hash.toString().c_str()); ++in_stats.put; - onAnnounce(node, msg.info_hash, msg.token, msg.values, msg.created); + onAnnounce(node, msg->info_hash, msg->token, msg->values, msg->created); /* Note that if storageStore failed, we lie to the requestor. This is to prevent them from backtracking, and hence polluting the DHT. 
*/ - for (auto& v : msg.values) { - sendValueAnnounced(from, msg.tid, v->id); + for (auto& v : msg->values) { + sendValueAnnounced(from, msg->tid, v->id); } break; } case MessageType::Listen: { DHT_LOG.DEBUG("[node %s %s] got 'listen' request for %s.", - msg.id.toString().c_str(), from.toString().c_str(), msg.info_hash.toString().c_str()); + msg->id.toString().c_str(), from.toString().c_str(), msg->info_hash.toString().c_str()); ++in_stats.listen; - RequestAnswer answer = onListen(node, msg.info_hash, msg.token, msg.tid.getTid(), std::move(msg.query)); - sendListenConfirmation(from, msg.tid); + RequestAnswer answer = onListen(node, msg->info_hash, msg->token, msg->tid.getTid(), std::move(msg->query)); + sendListenConfirmation(from, msg->tid); break; } default: @@ -450,7 +551,7 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const SockAddr& } catch (const std::overflow_error& e) { DHT_LOG.ERR("Can't send value: buffer not large enough !"); } catch (DhtProtocolException& e) { - sendError(from, msg.tid, e.getCode(), e.getMsg().c_str(), true); + sendError(from, msg->tid, e.getCode(), e.getMsg().c_str(), true); } } } @@ -689,6 +790,60 @@ NetworkEngine::deserializeNodes(ParsedMessage& msg) { } } +std::vector<Blob> +NetworkEngine::packValueHeader(msgpack::sbuffer& buffer, const std::vector<std::shared_ptr<Value>>& st) +{ + auto svals = serializeValues(st); + size_t total_size = 0; + for (const auto& v : svals) + total_size += v.size(); + + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack(std::string("values")); + pk.pack_array(svals.size()); + // try to put everything in a single UDP packet + if (svals.size() < 50 && total_size < MAX_PACKET_VALUE_SIZE) { + for (const auto& b : svals) + buffer.write((const char*)b.data(), b.size()); + DHT_LOG.DEBUG("sending %lu bytes of values", total_size); + svals.clear(); + } else { + for (const auto& b : svals) + pk.pack(b.size()); + } + return svals; +} + +void +NetworkEngine::sendValueParts(TransId tid, 
const std::vector<Blob>& svals, const SockAddr& addr) +{ + msgpack::sbuffer buffer; + unsigned i=0; + for (const auto& v: svals) { + size_t start {0}, end; + do { + end = std::min(start + MTU, v.size()); + buffer.clear(); + msgpack::packer<msgpack::sbuffer> pk(&buffer); + pk.pack_map(3+(network?1:0)); + if (network) { + pk.pack(std::string("n")); pk.pack(network); + } + pk.pack(std::string("y")); pk.pack(std::string("v")); + pk.pack(std::string("t")); pk.pack_bin(tid.size()); + pk.pack_bin_body((const char*)tid.data(), tid.size()); + pk.pack(std::string("p")); pk.pack_map(1); + pk.pack(i); pk.pack_map(2); + pk.pack(std::string("o")); pk.pack(start); + pk.pack(std::string("d")); pk.pack_bin(end-start); + pk.pack_bin_body((const char*)v.data()+start, end-start); + send(buffer.data(), buffer.size(), 0, addr); + start = end; + } while (start != v.size()); + i++; + } +} + void NetworkEngine::sendNodesValues(const SockAddr& addr, TransId tid, const Blob& nodes, const Blob& nodes6, const std::vector<std::shared_ptr<Value>>& st, const Query& query, const Blob& token) { @@ -713,34 +868,11 @@ NetworkEngine::sendNodesValues(const SockAddr& addr, TransId tid, const Blob& no if (not token.empty()) { pk.pack(std::string("token")); packToken(pk, token); } + std::vector<Blob> svals {}; if (not st.empty()) { /* pack complete values */ auto fields = query.select.getSelection(); - size_t total_size = 0; if (fields.empty()) { - // We treat the storage as a circular list, and serve a randomly - // chosen slice. In order to make sure we fit, - // we limit ourselves to 50 values. 
- std::uniform_int_distribution<> pos_dis(0, st.size()-1); - std::vector<Blob> subset {}; - subset.reserve(std::min<size_t>(st.size(), 50)); - - unsigned j0 = pos_dis(rd_device); - unsigned j = j0; - unsigned k = 0; - - do { - subset.emplace_back(packMsg(st[j])); - total_size += subset.back().size(); - ++k; - j = (j + 1) % st.size(); - } while (j != j0 && k < 50 && total_size < MAX_VALUE_SIZE); - - pk.pack(std::string("values")); - pk.pack_array(subset.size()); - for (const auto& b : subset) - buffer.write((const char*)b.data(), b.size()); - DHT_LOG.DEBUG("sending closest nodes (%d+%d nodes.), %lu bytes of values", - nodes.size(), nodes6.size(), total_size); + svals = packValueHeader(buffer, st); } else { /* pack fields */ pk.pack(std::string("fields")); pk.pack_map(2); @@ -754,8 +886,6 @@ NetworkEngine::sendNodesValues(const SockAddr& addr, TransId tid, const Blob& no } else DHT_LOG.DEBUG("sending closest nodes (%d+%d nodes.)", nodes.size(), nodes6.size()); - DHT_LOG.DEBUG("sending closest nodes (%d+%d nodes.)", nodes.size(), nodes6.size()); - pk.pack(std::string("t")); pk.pack_bin(tid.size()); pk.pack_bin_body((const char*)tid.data(), tid.size()); pk.pack(std::string("y")); pk.pack(std::string("r")); @@ -764,7 +894,12 @@ NetworkEngine::sendNodesValues(const SockAddr& addr, TransId tid, const Blob& no pk.pack(std::string("n")); pk.pack(network); } + // send response send(buffer.data(), buffer.size(), 0, addr); + + // send parts + if (not svals.empty()) + sendValueParts(tid, svals, addr); } Blob @@ -884,7 +1019,7 @@ NetworkEngine::sendListenConfirmation(const SockAddr& addr, TransId tid) { } std::shared_ptr<Request> -NetworkEngine::sendAnnounceValue(std::shared_ptr<Node> n, const InfoHash& infohash, const Value& value, time_point created, +NetworkEngine::sendAnnounceValue(std::shared_ptr<Node> n, const InfoHash& infohash, const std::shared_ptr<Value>& value, time_point created, const Blob& token, RequestCb on_done, RequestExpiredCb on_expired) { auto tid = TransId 
{TransPrefix::ANNOUNCE_VALUES, getNewTid()}; msgpack::sbuffer buffer; @@ -894,7 +1029,7 @@ NetworkEngine::sendAnnounceValue(std::shared_ptr<Node> n, const InfoHash& infoha pk.pack(std::string("a")); pk.pack_map((created < scheduler.time() ? 5 : 4)); pk.pack(std::string("id")); pk.pack(myid); pk.pack(std::string("h")); pk.pack(infohash); - pk.pack(std::string("values")); pk.pack_array(1); pk.pack(value); + auto v = packValueHeader(buffer, {value}); if (created < scheduler.time()) { pk.pack(std::string("c")); pk.pack(to_time_t(created)); @@ -930,6 +1065,8 @@ NetworkEngine::sendAnnounceValue(std::shared_ptr<Node> n, const InfoHash& infoha } }); sendRequest(req); + if (not v.empty()) + sendValueParts(tid, v, n->addr); ++out_stats.put; return req; } @@ -992,6 +1129,16 @@ ParsedMessage::msgpack_unpack(msgpack::object msg) auto y = findMapValue(msg, "y"); auto r = findMapValue(msg, "r"); auto e = findMapValue(msg, "e"); + auto v = findMapValue(msg, "p"); + + if (auto t = findMapValue(msg, "t")) + tid = {t->as<std::array<char, 4>>()}; + + if (auto rv = findMapValue(msg, "v")) + ua = rv->as<std::string>(); + + if (auto netid = findMapValue(msg, "n")) + network = netid->as<NetId>(); std::string q; if (auto rq = findMapValue(msg, "q")) { @@ -1004,6 +1151,8 @@ ParsedMessage::msgpack_unpack(msgpack::object msg) type = MessageType::Error; else if (r) type = MessageType::Reply; + else if (v) + type = MessageType::ValueData; else if (y and y->as<std::string>() != "q") throw msgpack::type_error(); else if (q == "ping") @@ -1019,6 +1168,20 @@ ParsedMessage::msgpack_unpack(msgpack::object msg) else throw msgpack::type_error(); + if (type == MessageType::ValueData) { + if (v->type != msgpack::type::MAP) + throw msgpack::type_error(); + for (size_t i = 0; i < v->via.map.size; ++i) { + auto& vdat = v->via.map.ptr[i]; + auto o = findMapValue(vdat.val, "o"); + auto d = findMapValue(vdat.val, "d"); + if (not o or not d) + continue; + value_parts.emplace(vdat.key.as<unsigned>(), 
std::pair<size_t, Blob>(o->as<size_t>(), unpackBlob(*d))); + } + return; + } + auto a = findMapValue(msg, "a"); if (!a && !r && !e) throw msgpack::type_error(); @@ -1030,9 +1193,6 @@ ParsedMessage::msgpack_unpack(msgpack::object msg) error_code = e->via.array.ptr[0].as<uint16_t>(); } - if (auto netid = findMapValue(msg, "n")) - network = netid->as<NetId>(); - if (auto rid = findMapValue(req, "id")) id = {*rid}; @@ -1085,12 +1245,21 @@ ParsedMessage::msgpack_unpack(msgpack::object msg) if (auto rvalues = findMapValue(req, "values")) { if (rvalues->type != msgpack::type::ARRAY) throw msgpack::type_error(); - for (size_t i = 0; i < rvalues->via.array.size; i++) - try { - values.emplace_back(std::make_shared<Value>(rvalues->via.array.ptr[i])); - } catch (const std::exception& e) { - //DHT_LOG.WARN("Error reading value: %s", e.what()); + for (size_t i = 0; i < rvalues->via.array.size; i++) { + auto& packed_v = rvalues->via.array.ptr[i]; + if (packed_v.type == msgpack::type::POSITIVE_INTEGER) { + // Skip oversize values with a small margin for header overhead + if (packed_v.via.u64 > MAX_VALUE_SIZE + 32) + continue; + value_parts.emplace(i, std::make_pair(packed_v.via.u64, Blob{})); + } else { + try { + values.emplace_back(std::make_shared<Value>(rvalues->via.array.ptr[i])); + } catch (const std::exception& e) { + //DHT_LOG.WARN("Error reading value: %s", e.what()); + } } + } } else if (auto raw_fields = findMapValue(req, "fields")) { if (auto rfields = findMapValue(*raw_fields, "f")) { auto fields_ = rfields->as<std::set<Value::Field>>(); @@ -1127,13 +1296,59 @@ ParsedMessage::msgpack_unpack(msgpack::object msg) } else { want = -1; } +} - if (auto t = findMapValue(msg, "t")) - tid = {t->as<std::array<char, 4>>()}; +void +NetworkEngine::maintainRxBuffer(const TransId& tid) +{ + const auto& now = scheduler.time(); + auto msg = partial_messages.find(tid); + if (msg != partial_messages.end()) { + if (msg->second.start + RX_MAX_PACKET_TIME < now + || msg->second.last_part + 
RX_TIMEOUT < now) { + DHT_LOG.WARN("Dropping expired partial message from %s", msg->second.from.toString().c_str()); + partial_messages.erase(msg); + } + } +} - if (auto rv = findMapValue(msg, "v")) - ua = rv->as<std::string>(); +bool +ParsedMessage::append(const ParsedMessage& block) +{ + bool ret(false); + for (const auto& ve : block.value_parts) { + auto part_val = value_parts.find(ve.first); + if (part_val == value_parts.end() + || part_val->second.second.size() >= part_val->second.first) + continue; + // TODO: handle out-of-order packets + if (ve.second.first != part_val->second.second.size()) { + //std::cout << "skipping out-of-order packet" << std::endl; + continue; + } + ret = true; + part_val->second.second.insert(part_val->second.second.end(), + ve.second.second.begin(), + ve.second.second.end()); + } + return ret; +} +bool +ParsedMessage::complete() +{ + for (auto& e : value_parts) { + //std::cout << "part " << e.first << ": " << e.second.second.size() << "/" << e.second.first << std::endl; + if (e.second.first > e.second.second.size()) + return false; + } + for (auto& e : value_parts) { + msgpack::unpacked msg; + msgpack::unpack(msg, (const char*)e.second.second.data(), e.second.second.size()); + values.emplace_back(std::make_shared<Value>(msg.get())); + } + return true; } + } diff --git a/src/utils.cpp b/src/utils.cpp index a75ff99890923f1699a19df137eb1a8fbf2274e3..3b68221c37047fb6d904836d93884e4ce3baaf4a 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -52,6 +52,10 @@ printAddr(const SockAddr& addr) { return print_addr((const sockaddr*)&addr.first, addr.second); } +bool operator==(const SockAddr& a, const SockAddr& b) { + return a.equals(b); +} + time_point from_time_t(std::time_t t) { return clock::now() + (std::chrono::system_clock::from_time_t(t) - std::chrono::system_clock::now()); } diff --git a/src/value.cpp b/src/value.cpp index a1f4b628254dc01bff919ef7951a687b9ee5b832..b49ce49cbf0edd91aa5354860d3e430cc5c29acf 100644 --- a/src/value.cpp +++ 
b/src/value.cpp @@ -65,7 +65,7 @@ const ValueType ValueType::USER_DATA = {0, "User Data"}; bool ValueType::DEFAULT_STORE_POLICY(InfoHash, std::shared_ptr<Value>& v, InfoHash, const sockaddr*, socklen_t) { - return v->data.size() <= MAX_VALUE_SIZE; + return v->size() <= MAX_VALUE_SIZE; } msgpack::object*