diff --git a/include/opendht/network_engine.h b/include/opendht/network_engine.h index 230075be8ccca964f24d58a090cfe3bcd02c0644..c803bf3058b1f069ea5eb03fbcc7f6e1824951cd 100644 --- a/include/opendht/network_engine.h +++ b/include/opendht/network_engine.h @@ -378,10 +378,14 @@ private: } } - bool matches(const TransPrefix prefix, uint16_t *seqno_return = nullptr) const { + uint16_t getTid() const { + return *reinterpret_cast<const uint16_t*>(&(*this)[2]); + } + + bool matches(const TransPrefix prefix, uint16_t* tid = nullptr) const { if (std::equal(begin(), begin()+2, prefix.begin())) { - if (seqno_return) - *seqno_return = *reinterpret_cast<const uint16_t*>(&(*this)[2]); + if (tid) + *tid = getTid(); return true; } else return false; diff --git a/src/dht.cpp b/src/dht.cpp index ac6dd9f0ef05a0ecd81ad42729a784689b9f5e4e..44012575a79070d877d9132568646d1e264e8b01 100644 --- a/src/dht.cpp +++ b/src/dht.cpp @@ -2720,19 +2720,12 @@ Dht::onListen(std::shared_ptr<Node> node, InfoHash& hash, Blob& token, size_t ri } void -Dht::onListenDone(std::shared_ptr<NetworkEngine::RequestStatus> status, NetworkEngine::RequestAnswer&, std::shared_ptr<Search> sr) +Dht::onListenDone(std::shared_ptr<NetworkEngine::RequestStatus> status, NetworkEngine::RequestAnswer& answer, std::shared_ptr<Search> sr) { - const auto& now = scheduler.time(); DHT_LOG.DEBUG("Got reply to listen."); if (sr) { - for (auto& sn : sr->nodes) - if (sn.node == status->node) { - sn.listenStatus->reply_time = now; - break; - } + onGetValuesDone(status, answer, sr); /* See comment for gp above. */ - if (searchSendGetValues(sr)) - sr->get_step_time = now; } } diff --git a/src/network_engine.cpp b/src/network_engine.cpp index fcb67550114e3d15fd1f4a3f0e1f81ab45d76b98..ab0a7e656375176031999a0fa87800de508784ac 100644 --- a/src/network_engine.cpp +++ b/src/network_engine.cpp @@ -128,13 +128,10 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const sockaddr uint16_t ttid = 0; if (msg.type == MessageType::Error or msg.type == MessageType::Reply) { - Request* req = nullptr; - const auto& reqp = requests.find(msg.tid[2]); - if (reqp != requests.end()) - req = &(*reqp->second); - - if (not req) + auto reqp = requests.find(msg.tid.getTid()); + if (reqp == requests.end()) throw DhtProtocolException {DhtProtocolException::UNKNOWN_TID, "Can't find transaction", msg.id}; + auto req = reqp->second; auto node = onNewNode(msg.id, from, fromlen, 2); onReportedAddr(msg.id, (sockaddr*)&msg.addr.first, msg.addr.second); @@ -155,9 +152,10 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const sockaddr break; } case MessageType::Reply: + requests.erase(reqp); req->status->reply_time = scheduler.time(); + req->status->completed = true; req->on_done(req->status, std::move(msg)); - requests.erase(reqp); break; default: break; @@ -212,7 +210,7 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const sockaddr break; } ++in_stats.listen; - RequestAnswer answer = onListen(node, msg.info_hash, msg.token, msg.tid[2]); + RequestAnswer answer = onListen(node, msg.info_hash, msg.token, msg.tid.getTid()); sendListenConfirmation(from, fromlen, msg.tid); break; } @@ -281,7 +279,7 @@ NetworkEngine::sendPing(const sockaddr* sa, socklen_t salen, RequestCb on_done, pk.pack(std::string("v")); pk.pack(my_v); Blob b {buffer.data(), buffer.data() + buffer.size()}; - Request req {tid[2], std::make_shared<Node>(InfoHash {}, sa, salen), std::move(b), + Request req {tid.getTid(), std::make_shared<Node>(InfoHash {}, sa, salen), std::move(b), [=](std::shared_ptr<RequestStatus> req_status, ParsedMessage&&){ if (on_done) { on_done(req_status, {}); @@ -343,7 +341,7 @@ NetworkEngine::sendFindNode(std::shared_ptr<Node> n, const InfoHash& target, wan Blob b {buffer.data(), buffer.data() + buffer.size()}; - Request req {tid[2], n, std::move(b), + Request req {tid.getTid(), n, std::move(b), [=](std::shared_ptr<RequestStatus> req_status, ParsedMessage&& msg) { /* on done */ if (on_done) { on_done(req_status, deserializeNodesValues(msg)); @@ -387,7 +385,7 @@ NetworkEngine::sendGetValues(std::shared_ptr<Node> n, const InfoHash& info_hash, pk.pack(std::string("v")); pk.pack(my_v); Blob b {buffer.data(), buffer.data() + buffer.size()}; - Request req {tid[2], n, std::move(b), + Request req {tid.getTid(), n, std::move(b), [=](std::shared_ptr<RequestStatus> req_status, ParsedMessage&& msg) { /* on done */ if (on_done) { on_done(req_status, deserializeNodesValues(msg)); @@ -597,16 +595,22 @@ NetworkEngine::sendListen(std::shared_ptr<Node> n, const InfoHash& infohash, con Blob b {buffer.data(), buffer.data() + buffer.size()}; - Request req {tid[2], n, std::move(b), + Request req {tid.getTid(), n, std::move(b), [=](std::shared_ptr<RequestStatus> req_status, ParsedMessage&&) { /* on done */ - if (on_done) { + requests.emplace(tid.getTid(), std::make_shared<Request>( + tid.getTid(), req_status->node, Blob {}, + [=](std::shared_ptr<RequestStatus> req_status, ParsedMessage&& msg){ + DHT_LOG.DEBUG("[listen %s] got new values", infohash.toString().c_str()); + if (on_done) + on_done(req_status, deserializeNodesValues(msg)); + }, nullptr + )); + if (on_done) on_done(req_status, {}); - } }, [=](std::shared_ptr<RequestStatus> req_status, bool) { /* on expired */ - if (on_expired) { + if (on_expired) on_expired(req_status, {}); - } } }; auto req_status = req.status; @@ -658,7 +662,7 @@ NetworkEngine::sendAnnounceValue(std::shared_ptr<Node> n, const InfoHash& infoha pk.pack(std::string("v")); pk.pack(my_v); Blob b {buffer.data(), buffer.data() + buffer.size()}; - Request req {tid[2], n, std::move(b), + Request req {tid.getTid(), n, std::move(b), [=](std::shared_ptr<RequestStatus> req_status, ParsedMessage&& msg) { /* on done */ if (msg.value_id == Value::INVALID_ID) { DHT_LOG.DEBUG("Unknown search or announce!");