diff --git a/include/opendht/network_engine.h b/include/opendht/network_engine.h index 2bdea5ee2e6fa5b7b4666a7d57ef5716118b87a4..13bfae4a12993e07b0a5b55585d986334daca33a 100644 --- a/include/opendht/network_engine.h +++ b/include/opendht/network_engine.h @@ -211,6 +211,7 @@ private: public: using RequestCb = std::function<void(const Request&, RequestAnswer&&)>; + using RequestErrorCb = std::function<bool(const Request&, DhtProtocolException&&)>; using RequestExpiredCb = std::function<void(const Request&, bool)>; NetworkEngine(const Sp<Logger>& log, std::mt19937_64& rd, Scheduler& scheduler, std::unique_ptr<DatagramSocket>&& sock); @@ -403,6 +404,7 @@ public: const Value::Id& vid, const Blob& token, RequestCb&& on_done, + RequestErrorCb&& on_error, RequestExpiredCb&& on_expired); /** * Send a "update" request to a given node. Used for Listen operations diff --git a/src/dht.cpp b/src/dht.cpp index 547d9757aab2489efa944e333038bc027983c7e5..31dd27b290dac337d95e22a4823865539827f590 100644 --- a/src/dht.cpp +++ b/src/dht.cpp @@ -490,7 +490,29 @@ void Dht::searchSendAnnounceValue(const Sp<Search>& sr) { logger_->w(sr->id, sn->node->id, "[search %s] [node %s] sending 'refresh' (vid: %d)", sr->id.toString().c_str(), sn->node->toString().c_str(), a.value->id); sn->acked[a.value->id] = { - network_engine.sendRefreshValue(sn->node, sr->id, a.value->id, sn->token, onDone, onExpired), + network_engine.sendRefreshValue(sn->node, sr->id, a.value->id, sn->token, onDone, + [this, ws, node=sn->node, v=a.value, + onDone, + onExpired, + created = a.permanent ? time_point::max() : a.created, + next_refresh_time + ](const net::Request& req, net::DhtProtocolException&& e){ + if (e.getCode() == net::DhtProtocolException::NOT_FOUND) { + if (logger_) + logger_->e(node->id, "[node %s] returned error 404: storage not found", node->toString().c_str()); + if (auto sr = ws.lock()) { + if (auto sn = sr->getNode(node)) { + sn->acked[v->id] = { + network_engine.sendAnnounceValue(sn->node, sr->id, v, created, sn->token, onDone, onExpired), + next_refresh_time + }; + scheduler.edit(sr->nextSearchStep, scheduler.time()); + return true; + } + } + } + return false; + }, onExpired), next_refresh_time }; } else { diff --git a/src/network_engine.cpp b/src/network_engine.cpp index f83f26108e1f4388183e8e4710b176afc19bb458..e91e7d1e6a69f5468c42a0c88559bf82d2537bd8 100644 --- a/src/network_engine.cpp +++ b/src/network_engine.cpp @@ -560,8 +560,8 @@ NetworkEngine::process(std::unique_ptr<ParsedMessage>&& msg, const SockAddr& fro { req->last_try = time_point::min(); req->reply_time = time_point::min(); - req->setError(); - onError(req, DhtProtocolException {msg->error_code}); + if (not req->setError(DhtProtocolException {msg->error_code})) + onError(req, DhtProtocolException {msg->error_code}); } else { if (logIncoming_) if (logger_) @@ -1264,6 +1264,7 @@ NetworkEngine::sendRefreshValue(Sp<Node> n, const Value::Id& vid, const Blob& token, RequestCb&& on_done, + RequestErrorCb&& on_error, RequestExpiredCb&& on_expired) { Tid tid (n->getNewTid()); @@ -1299,6 +1300,7 @@ NetworkEngine::sendRefreshValue(Sp<Node> n, } } }, + on_error, [=](const Request& req_status, bool done) { /* on expired */ if (on_expired) { on_expired(req_status, done); diff --git a/src/request.h b/src/request.h index 5b862911e7c51cec27661f3961eec6d7994927af..92aa8dbeedf8a347f26d04ea2ff7206c1db1ee82 100644 --- a/src/request.h +++ b/src/request.h @@ -27,6 +27,7 @@ struct Node; namespace net { class NetworkEngine; +class DhtProtocolException; struct ParsedMessage; /*! @@ -75,11 +76,19 @@ struct Request { std::function<void(const Request&, bool)> on_expired, Tid socket = 0) : node(node), tid(tid), type(type), on_done(on_done), on_expired(on_expired), msg(std::move(msg)), socket(socket) { } + Request(MessageType type, Tid tid, + Sp<Node> node, + Blob&& msg, + std::function<void(const Request&, ParsedMessage&&)> on_done, + std::function<bool(const Request&, DhtProtocolException&&)> on_error, + std::function<void(const Request&, bool)> on_expired, + Tid socket = 0) : + node(node), tid(tid), type(type), on_done(on_done), on_error(on_error), on_expired(on_expired), msg(std::move(msg)), socket(socket) { } Tid getTid() const { return tid; } MessageType getType() const { return type; } - Tid getSocket() { return socket; } + Tid getSocket() const { return socket; } Tid closeSocket() { auto ret = socket; socket = 0; return ret; } void setExpired() { @@ -96,11 +105,14 @@ struct Request { clear(); } } - void setError() { + bool setError(DhtProtocolException&& e) { if (pending()) { - state_ = Request::State::COMPLETED; + state_ = Request::State::EXPIRED; + bool handled = on_error and on_error(*this, std::forward<DhtProtocolException>(e)); clear(); + return handled; } + return true; } void cancel() { @@ -119,6 +131,7 @@ private: void clear() { on_done = {}; + on_error = {}; on_expired = {}; msg = {}; parts = {}; @@ -134,6 +147,7 @@ private: time_point last_try {time_point::min()}; /* time of the last attempt to process the request. */ std::function<void(const Request&, ParsedMessage&&)> on_done {}; + std::function<bool(const Request&, DhtProtocolException&&)> on_error {}; std::function<void(const Request&, bool)> on_expired {}; Blob msg {}; /* the serialized message. */