diff --git a/include/opendht/network_engine.h b/include/opendht/network_engine.h index cb541b991171f34bb3a84e62d209a833ff4d161e..066c7d8aaff44031cb2433741aa2575e49c24434 100644 --- a/include/opendht/network_engine.h +++ b/include/opendht/network_engine.h @@ -217,6 +217,13 @@ public: and now - last_try <= Node::MAX_RESPONSE_TIME; } + void cancel() { + if (not completed) { + cancelled = true; + clear(); + } + } + Request() {} private: @@ -227,11 +234,17 @@ public: std::function<void(std::shared_ptr<Request> req_status, bool)> on_expired, bool persistent = false) : node(node), on_done(on_done), on_expired(on_expired), tid(tid), msg(std::move(msg)), persistent(persistent) { } + void clear() { + on_done = {}; + on_expired = {}; + msg.clear(); + } + std::function<void(std::shared_ptr<Request> req_status, ParsedMessage&&)> on_done {}; std::function<void(std::shared_ptr<Request> req_status, bool)> on_expired {}; const uint16_t tid {0}; /* the request id. */ - const Blob msg {}; /* the serialized message. */ + Blob msg {}; /* the serialized message. */ const bool persistent {false}; /* the request is not erased upon completion. */ }; @@ -241,9 +254,7 @@ public: */ void cancelRequest(std::shared_ptr<Request>& req) { if (req) { - req->cancelled = true; - req->on_done = {}; - req->on_expired = {}; + req->cancel(); requests.erase(req->tid); } } @@ -356,10 +367,8 @@ public: }; void clear() { - for (auto& req : requests) { - req.second->on_expired = {}; - req.second->on_done = {}; - } + for (auto& req : requests) + req.second->cancel(); requests.clear(); } diff --git a/src/dht.cpp b/src/dht.cpp index 390d321c8ccca5dd5d71a679d7ce23b1efb5be7b..aa2bf8e0602a0f6c95e4faf795759d20c7a6a3a5 100644 --- a/src/dht.cpp +++ b/src/dht.cpp @@ -2524,17 +2524,15 @@ Dht::pingNode(const sockaddr *sa, socklen_t salen) } void -Dht::onError(std::shared_ptr<NetworkEngine::Request> status, DhtProtocolException e) { +Dht::onError(std::shared_ptr<NetworkEngine::Request> req, DhtProtocolException e) { if (e.getCode() == DhtProtocolException::UNAUTHORIZED) { - //TODO - //auto esr = searches.find(status); - //if (esr == searches.end()) return; + network_engine.cancelRequest(req); unsigned cleared = 0; - for (auto& srp : status->node->ss.ss_family == AF_INET ? searches4 : searches6) { + for (auto& srp : req->node->ss.ss_family == AF_INET ? searches4 : searches6) { auto& sr = srp.second; for (auto& n : sr->nodes) { - if (n.node != status->node) continue; - n.getStatus = {}; + if (n.node != req->node) continue; + network_engine.cancelRequest(n.getStatus); n.last_get_reply = time_point::min(); cleared++; if (searchSendGetValues(sr)) @@ -2543,7 +2541,7 @@ Dht::onError(std::shared_ptr<NetworkEngine::Request> status, DhtProtocolExceptio } } DHT_LOG.WARN("[node %s %s] token flush (%d searches affected)", - status->node->id.toString().c_str(), print_addr((sockaddr*)&status->node->ss, status->node->sslen).c_str(), cleared); + req->node->id.toString().c_str(), print_addr((sockaddr*)&req->node->ss, req->node->sslen).c_str(), cleared); } } diff --git a/src/network_engine.cpp b/src/network_engine.cpp index 123c2e1b919a923a5867f65b51169a6e10b9935a..b3e24cb3ba26e02b7b906e201c0adf3b08bfdc72 100644 --- a/src/network_engine.cpp +++ b/src/network_engine.cpp @@ -137,6 +137,8 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const sockaddr if (reqp == requests.end()) throw DhtProtocolException {DhtProtocolException::UNKNOWN_TID, "Can't find transaction", msg.id}; auto req = reqp->second; + if (req->cancelled) + return; auto node = onNewNode(msg.id, from, fromlen, 2); onReportedAddr(msg.id, (sockaddr*)&msg.addr.first, msg.addr.second); @@ -158,15 +160,14 @@ NetworkEngine::processMessage(const uint8_t *buf, size_t buflen, const sockaddr break; } case MessageType::Reply: - if (not reqp->second->persistent or reqp->second->cancelled) + // erase before calling callback to make sure iterator is still valid + if (not req->persistent) requests.erase(reqp); req->reply_time = scheduler.time(); req->completed = true; req->on_done(req, std::move(msg)); - if (not req->persistent) { - req->on_done = {}; - req->on_expired = {}; - } + if (not req->persistent) + req->clear(); break; default: break;