diff --git a/include/opendht/dhtrunner.h b/include/opendht/dhtrunner.h index 33b76a7c395eca3b3d81ff817538f1a61adbd3b5..65ecdba808071ac9d77b145f0a967365ae349b9e 100644 --- a/include/opendht/dhtrunner.h +++ b/include/opendht/dhtrunner.h @@ -340,7 +340,7 @@ public: void importValues(const std::vector<ValuesExport>& values); bool isRunning() const { - return running; + return running != State::Idle; } NodeStats getNodesStats(sa_family_t af) const; @@ -358,7 +358,7 @@ public: // securedht methods - void findCertificate(InfoHash hash, std::function<void(const std::shared_ptr<crypto::Certificate>)>); + void findCertificate(InfoHash hash, std::function<void(const std::shared_ptr<crypto::Certificate>&)>); void registerCertificate(std::shared_ptr<crypto::Certificate> cert); void setLocalCertificateStore(CertificateStoreQuery&& query_method); @@ -411,7 +411,7 @@ public: /** * Gracefuly disconnect from network. */ - void shutdown(ShutdownCallback cb); + void shutdown(ShutdownCallback cb = {}); /** * Quit and wait for all threads to terminate. @@ -449,6 +449,12 @@ public: private: static constexpr std::chrono::seconds BOOTSTRAP_PERIOD {10}; + enum class State { + Idle, + Running, + Stopping + }; + /** * Will try to resolve the list of hostnames `bootstrap_nodes` on seperate * thread and then queue ping requests. This list should contain reliable @@ -463,6 +469,11 @@ private: return std::max(status4, status6); } + bool checkShutdown(); + void opEnded(); + DoneCallback bindOpDoneCallback(DoneCallback&& cb); + DoneCallbackSimple bindOpDoneCallback(DoneCallbackSimple&& cb); + /** Local DHT instance */ std::unique_ptr<SecureDht> dht_; @@ -510,7 +521,9 @@ private: std::queue<std::function<void(SecureDht&)>> pending_ops {}; std::mutex storage_mtx {}; - std::atomic_bool running {false}; + std::atomic<State> running {State::Idle}; + std::atomic_uint ongoing_ops {0}; + std::vector<ShutdownCallback> shutdownCallbacks_; NodeStatus status4 {NodeStatus::Disconnected}, status6 {NodeStatus::Disconnected}; diff --git a/src/dhtrunner.cpp b/src/dhtrunner.cpp index e97d92a06bbb099de7a014278282d48e09cb16e3..26732405d2111c60488aabc79ebcea308968d0ee 100644 --- a/src/dhtrunner.cpp +++ b/src/dhtrunner.cpp @@ -100,7 +100,7 @@ DhtRunner::run(const char* ip4, const char* ip6, const char* service, const Conf void DhtRunner::run(const SockAddr& local4, const SockAddr& local6, const Config& config, Context&& context) { - if (not running) { + if (running == State::Idle) { if (not context.sock) context.sock.reset(new net::UdpSocket(local4, local6, context.logger ? *context.logger : Logger{})); run(config, std::move(context)); @@ -111,11 +111,14 @@ void DhtRunner::run(const Config& config, Context&& context) { std::lock_guard<std::mutex> lck(dht_mtx); - if (running) + auto expected = State::Idle; + if (not running.compare_exchange_strong(expected, State::Running)) return; - if (context.logger) + if (context.logger) { logger_ = context.logger; + logger_->d("[runner %p] state changed to Running", this); + } context.sock->setOnReceive([&] (std::unique_ptr<net::ReceivedPacket>&& pkt) { { @@ -148,16 +151,15 @@ DhtRunner::run(const Config& config, Context&& context) dht_via_proxy_->setLocalCertificateStore(std::move(context.certificateStore)); } - running = true; if (not config.threaded) return; dht_thread = std::thread([this]() { - while (running) { + while (running != State::Idle) { std::unique_lock<std::mutex> lk(dht_mtx); time_point wakeup = loop_(); auto hasJobToDo = [this]() { - if (not running) + if (running == State::Idle) return true; { std::lock_guard<std::mutex> lck(sock_mtx); @@ -219,36 +221,84 @@ DhtRunner::run(const Config& config, Context&& context) void DhtRunner::shutdown(ShutdownCallback cb) { - if (not running) { - cb(); + auto expected = State::Running; + if (not running.compare_exchange_strong(expected, State::Stopping)) { + if (expected == State::Stopping and ongoing_ops) { + std::lock_guard<std::mutex> lck(storage_mtx); + shutdownCallbacks_.emplace_back(std::move(cb)); + } + else if (cb) cb(); return; } + if (logger_) + logger_->d("[runner %p] state changed to Stopping", this); std::lock_guard<std::mutex> lck(storage_mtx); + ongoing_ops++; + shutdownCallbacks_.emplace_back(std::move(cb)); pending_ops_prio.emplace([=](SecureDht&) mutable { + auto onShutdown = [this]{ opEnded(); }; #ifdef OPENDHT_PROXY_CLIENT if (dht_via_proxy_) - dht_via_proxy_->shutdown(cb); + dht_via_proxy_->shutdown(onShutdown); #endif if (dht_) - dht_->shutdown(cb); + dht_->shutdown(onShutdown); }); cv.notify_all(); } +void +DhtRunner::opEnded() { + if (--ongoing_ops == 0) + checkShutdown(); +} + +DoneCallback +DhtRunner::bindOpDoneCallback(DoneCallback&& cb) { + return [this, cb = std::move(cb)](bool ok, const std::vector<std::shared_ptr<Node>>& nodes){ + if (cb) cb(ok, nodes); + opEnded(); + }; +} + +DoneCallbackSimple +DhtRunner::bindOpDoneCallback(DoneCallbackSimple&& cb) { + return [this, cb = std::move(cb)](bool ok){ + if (cb) cb(ok); + opEnded(); + }; +} + +bool +DhtRunner::checkShutdown() { + if (running != State::Stopping or ongoing_ops) + return false; + decltype(shutdownCallbacks_) cbs; + { + std::lock_guard<std::mutex> lck(storage_mtx); + cbs = std::move(shutdownCallbacks_); + } + for (auto& cb : cbs) + if (cb) cb(); + return true; +} + void DhtRunner::join() { - if (peerDiscovery_) - peerDiscovery_->stop(); - { std::lock_guard<std::mutex> lck(dht_mtx); - running = false; + if (running.exchange(State::Idle) == State::Idle) + return; cv.notify_all(); bootstrap_cv.notify_all(); + if (peerDiscovery_) + peerDiscovery_->stop(); if (dht_) if (auto sock = dht_->getSocket()) sock->stop(); + if (logger_) + logger_->d("[runner %p] state changed to Idle", this); } if (dht_thread.joinable()) @@ -257,9 +307,8 @@ DhtRunner::join() if (bootstrap_thread.joinable()) bootstrap_thread.join(); - if (peerDiscovery_) { + if (peerDiscovery_) peerDiscovery_->join(); - } { std::lock_guard<std::mutex> lck(storage_mtx); @@ -572,12 +621,15 @@ DhtRunner::loop_() void DhtRunner::get(InfoHash hash, GetCallback vcb, DoneCallback dcb, Value::Filter f, Where w) { - { - std::lock_guard<std::mutex> lck(storage_mtx); - pending_ops.emplace([=](SecureDht& dht) mutable { - dht.get(hash, std::move(vcb), std::move(dcb), std::move(f), std::move(w)); - }); + if (running != State::Running) { + if (dcb) dcb(false, {}); + return; } + std::lock_guard<std::mutex> lck(storage_mtx); + ongoing_ops++; + pending_ops.emplace([=](SecureDht& dht) mutable { + dht.get(hash, std::move(vcb), bindOpDoneCallback(std::move(dcb)), std::move(f), std::move(w)); + }); cv.notify_all(); } @@ -588,12 +640,15 @@ DhtRunner::get(const std::string& key, GetCallback vcb, DoneCallbackSimple dcb, } void DhtRunner::query(const InfoHash& hash, QueryCallback cb, DoneCallback done_cb, Query q) { - { - std::lock_guard<std::mutex> lck(storage_mtx); - pending_ops.emplace([=](SecureDht& dht) mutable { - dht.query(hash, std::move(cb), std::move(done_cb), std::move(q)); - }); + if (running != State::Running) { + if (done_cb) done_cb(false, {}); + return; } + std::lock_guard<std::mutex> lck(storage_mtx); + ongoing_ops++; + pending_ops.emplace([=](SecureDht& dht) mutable { + dht.query(hash, std::move(cb), bindOpDoneCallback(std::move(done_cb)), std::move(q)); + }); cv.notify_all(); } @@ -601,32 +656,34 @@ std::future<size_t> DhtRunner::listen(InfoHash hash, ValueCallback vcb, Value::Filter f, Where w) { auto ret_token = std::make_shared<std::promise<size_t>>(); - { - std::lock_guard<std::mutex> lck(storage_mtx); - pending_ops.emplace([=](SecureDht& dht) mutable { + if (running != State::Running) { + ret_token->set_value(0); + return ret_token->get_future(); + } + std::lock_guard<std::mutex> lck(storage_mtx); + pending_ops.emplace([=](SecureDht& dht) mutable { #ifdef OPENDHT_PROXY_CLIENT - auto tokenbGlobal = listener_token_++; - auto& listener = listeners_[tokenbGlobal]; - listener.hash = hash; - listener.f = std::move(f); - listener.w = std::move(w); - listener.gcb = [hash,vcb,tokenbGlobal,this](const std::vector<Sp<Value>>& vals, bool expired) { - if (not vcb(vals, expired)) { - cancelListen(hash, tokenbGlobal); - return false; - } - return true; - }; - if (auto token = dht.listen(hash, listener.gcb, listener.f, listener.w)) { - if (use_proxy) listener.tokenProxyDht = token; - else listener.tokenClassicDht = token; + auto tokenbGlobal = listener_token_++; + auto& listener = listeners_[tokenbGlobal]; + listener.hash = hash; + listener.f = std::move(f); + listener.w = std::move(w); + listener.gcb = [hash,vcb,tokenbGlobal,this](const std::vector<Sp<Value>>& vals, bool expired) { + if (not vcb(vals, expired)) { + cancelListen(hash, tokenbGlobal); + return false; } - ret_token->set_value(tokenbGlobal); + return true; + }; + if (auto token = dht.listen(hash, listener.gcb, listener.f, listener.w)) { + if (use_proxy) listener.tokenProxyDht = token; + else listener.tokenClassicDht = token; + } + ret_token->set_value(tokenbGlobal); #else - ret_token->set_value(dht.listen(hash, std::move(vcb), std::move(f), std::move(w))); + ret_token->set_value(dht.listen(hash, std::move(vcb), std::move(f), std::move(w))); #endif - }); - } + }); cv.notify_all(); return ret_token->get_future(); } @@ -640,73 +697,83 @@ DhtRunner::listen(const std::string& key, GetCallback vcb, Value::Filter f, Wher void DhtRunner::cancelListen(InfoHash h, size_t token) { - { - std::lock_guard<std::mutex> lck(storage_mtx); + std::lock_guard<std::mutex> lck(storage_mtx); + ongoing_ops++; #ifdef OPENDHT_PROXY_CLIENT - pending_ops.emplace([=](SecureDht&) { - auto it = listeners_.find(token); - if (it == listeners_.end()) return; - if (it->second.tokenClassicDht) - dht_->cancelListen(h, it->second.tokenClassicDht); - if (it->second.tokenProxyDht and dht_via_proxy_) - dht_via_proxy_->cancelListen(h, it->second.tokenProxyDht); - listeners_.erase(it); - }); + pending_ops.emplace([=](SecureDht&) { + auto it = listeners_.find(token); + if (it == listeners_.end()) return; + if (it->second.tokenClassicDht) + dht_->cancelListen(h, it->second.tokenClassicDht); + if (it->second.tokenProxyDht and dht_via_proxy_) + dht_via_proxy_->cancelListen(h, it->second.tokenProxyDht); + listeners_.erase(it); + opEnded(); + }); #else - pending_ops.emplace([=](SecureDht& dht) { - dht.cancelListen(h, token); - }); + pending_ops.emplace([=](SecureDht& dht) { + dht.cancelListen(h, token); + opEnded(); + }); #endif // OPENDHT_PROXY_CLIENT - } cv.notify_all(); } void DhtRunner::cancelListen(InfoHash h, std::shared_future<size_t> ftoken) { - { - std::lock_guard<std::mutex> lck(storage_mtx); + std::lock_guard<std::mutex> lck(storage_mtx); + ongoing_ops++; #ifdef OPENDHT_PROXY_CLIENT - pending_ops.emplace([=](SecureDht&) { - auto it = listeners_.find(ftoken.get()); - if (it == listeners_.end()) return; - if (it->second.tokenClassicDht) - dht_->cancelListen(h, it->second.tokenClassicDht); - if (it->second.tokenProxyDht and dht_via_proxy_) - dht_via_proxy_->cancelListen(h, it->second.tokenProxyDht); - listeners_.erase(it); - }); + pending_ops.emplace([=](SecureDht&) { + auto it = listeners_.find(ftoken.get()); + if (it == listeners_.end()) return; + if (it->second.tokenClassicDht) + dht_->cancelListen(h, it->second.tokenClassicDht); + if (it->second.tokenProxyDht and dht_via_proxy_) + dht_via_proxy_->cancelListen(h, it->second.tokenProxyDht); + listeners_.erase(it); + opEnded(); + }); #else - pending_ops.emplace([=](SecureDht& dht) { - dht.cancelListen(h, ftoken.get()); - }); + pending_ops.emplace([=](SecureDht& dht) { + dht.cancelListen(h, ftoken.get()); + opEnded(); + }); #endif // OPENDHT_PROXY_CLIENT - } cv.notify_all(); } void DhtRunner::put(InfoHash hash, Value&& value, DoneCallback cb, time_point created, bool permanent) { - { - std::lock_guard<std::mutex> lck(storage_mtx); - auto sv = std::make_shared<Value>(std::move(value)); - pending_ops.emplace([=](SecureDht& dht) { - dht.put(hash, sv, cb, created, permanent); - }); + if (running != State::Running) { + if (cb) cb(false, {}); + return; } + std::lock_guard<std::mutex> lck(storage_mtx); + ongoing_ops++; + pending_ops.emplace([=, + cb = std::move(cb), + sv = std::make_shared<Value>(std::move(value)) + ] (SecureDht& dht) mutable { + dht.put(hash, sv, bindOpDoneCallback(std::move(cb)), created, permanent); + }); cv.notify_all(); } void DhtRunner::put(InfoHash hash, std::shared_ptr<Value> value, DoneCallback cb, time_point created, bool permanent) { - { - std::lock_guard<std::mutex> lck(storage_mtx); - pending_ops.emplace([=](SecureDht& dht) { - dht.put(hash, value, cb, created, permanent); - }); + if (running != State::Running) { + if (cb) cb(false, {}); + return; } + std::lock_guard<std::mutex> lck(storage_mtx); + ongoing_ops++; + pending_ops.emplace([=, cb = std::move(cb)](SecureDht& dht) mutable { + dht.put(hash, value, bindOpDoneCallback(std::move(cb)), created, permanent); + }); cv.notify_all(); } @@ -719,36 +786,42 @@ DhtRunner::put(const std::string& key, Value&& value, DoneCallbackSimple cb, tim void DhtRunner::cancelPut(const InfoHash& h, Value::Id id) { - { - std::lock_guard<std::mutex> lck(storage_mtx); - pending_ops.emplace([=](SecureDht& dht) { - dht.cancelPut(h, id); - }); - } + std::lock_guard<std::mutex> lck(storage_mtx); + ongoing_ops++; + pending_ops.emplace([=](SecureDht& dht) { + dht.cancelPut(h, id); + opEnded(); + }); cv.notify_all(); } void DhtRunner::cancelPut(const InfoHash& h, const std::shared_ptr<Value>& value) { - { - std::lock_guard<std::mutex> lck(storage_mtx); - pending_ops.emplace([=](SecureDht& dht) { - dht.cancelPut(h, value->id); - }); - } + std::lock_guard<std::mutex> lck(storage_mtx); + ongoing_ops++; + pending_ops.emplace([=](SecureDht& dht) { + dht.cancelPut(h, value->id); + opEnded(); + }); cv.notify_all(); } void DhtRunner::putSigned(InfoHash hash, std::shared_ptr<Value> value, DoneCallback cb, bool permanent) { - { - std::lock_guard<std::mutex> lck(storage_mtx); - pending_ops.emplace([=](SecureDht& dht) { - dht.putSigned(hash, value, cb, permanent); - }); + if (running != State::Running) { + if (cb) cb(false, {}); + return; } + std::lock_guard<std::mutex> lck(storage_mtx); + ongoing_ops++; + pending_ops.emplace([=, + cb = std::move(cb), + value = std::move(value) + ](SecureDht& dht) mutable { + dht.putSigned(hash, value, bindOpDoneCallback(std::move(cb)), permanent); + }); cv.notify_all(); } @@ -767,12 +840,18 @@ DhtRunner::putSigned(const std::string& key, Value&& value, DoneCallbackSimple c void DhtRunner::putEncrypted(InfoHash hash, InfoHash to, std::shared_ptr<Value> value, DoneCallback cb, bool permanent) { - { - std::lock_guard<std::mutex> lck(storage_mtx); - pending_ops.emplace([=](SecureDht& dht) { - dht.putEncrypted(hash, to, value, cb, permanent); - }); + if (running != State::Running) { + if (cb) cb(false, {}); + return; } + std::lock_guard<std::mutex> lck(storage_mtx); + ongoing_ops++; + pending_ops.emplace([=, + cb = std::move(cb), + value = std::move(value) + ] (SecureDht& dht) mutable { + dht.putEncrypted(hash, to, value, bindOpDoneCallback(std::move(cb)), permanent); + }); cv.notify_all(); } @@ -817,7 +896,7 @@ DhtRunner::tryBootstrapContinuously() ++ping_count; try { bootstrap(SockAddr::resolve(it->first, it->second), [&](bool) { - if (not running) + if (running != State::Running) return; { std::unique_lock<std::mutex> blck(mtx); @@ -831,15 +910,15 @@ DhtRunner::tryBootstrapContinuously() } } // wait at least until the next BOOTSTRAP_PERIOD - bootstrap_cv.wait_until(blck, next, [&]() { return not running; }); + bootstrap_cv.wait_until(blck, next, [&]() { return running != State::Running; }); // wait for bootstrap requests to end. - if (running) - bootstrap_cv.wait(blck, [&]() { return not running or ping_count == 0; }); + if (running != State::Running) + bootstrap_cv.wait(blck, [&]() { return running != State::Running or ping_count == 0; }); } // update state { std::lock_guard<std::mutex> lck(dht_mtx); - bootstraping = running and + bootstraping = running == State::Running and status4 == NodeStatus::Disconnected and status6 == NodeStatus::Disconnected; } @@ -876,19 +955,28 @@ DhtRunner::clearBootstrap() void DhtRunner::bootstrap(std::vector<SockAddr> nodes, DoneCallbackSimple&& cb) { + if (running != State::Running) { + cb(false); + return; + } std::lock_guard<std::mutex> lck(storage_mtx); - pending_ops_prio.emplace([=](SecureDht& dht) mutable { + ongoing_ops++; + pending_ops_prio.emplace([ + cb = bindOpDoneCallback(std::move(cb)), + nodes = std::move(nodes) + ] (SecureDht& dht) mutable { auto rem = cb ? std::make_shared<std::pair<size_t, bool>>(nodes.size(), false) : nullptr; for (auto& node : nodes) { if (node.getPort() == 0) node.setPort(net::DHT_DEFAULT_PORT); - dht.pingNode(std::move(node), cb ? [rem,cb](bool ok) { + dht.pingNode(std::move(node), [rem,cb](bool ok) { auto& r = *rem; r.first--; r.second |= ok; - if (not r.first) + if (r.first == 0) { cb(r.second); - } : DoneCallbackSimple{}); + } + }); } }); cv.notify_all(); @@ -897,8 +985,13 @@ DhtRunner::bootstrap(std::vector<SockAddr> nodes, DoneCallbackSimple&& cb) void DhtRunner::bootstrap(const SockAddr& addr, DoneCallbackSimple&& cb) { + if (running != State::Running) { + if (cb) cb(false); + return; + } std::lock_guard<std::mutex> lck(storage_mtx); - pending_ops_prio.emplace([addr, cb](SecureDht& dht) mutable { + ongoing_ops++; + pending_ops_prio.emplace([addr, cb = bindOpDoneCallback(std::move(cb))](SecureDht& dht) mutable { dht.pingNode(std::move(addr), std::move(cb)); }); cv.notify_all(); @@ -907,48 +1000,52 @@ DhtRunner::bootstrap(const SockAddr& addr, DoneCallbackSimple&& cb) void DhtRunner::bootstrap(const InfoHash& id, const SockAddr& address) { - { - std::unique_lock<std::mutex> lck(storage_mtx); - pending_ops_prio.emplace([id, address](SecureDht& dht) mutable { - dht.insertNode(id, address); - }); - } + if (running != State::Running) + return; + std::unique_lock<std::mutex> lck(storage_mtx); + pending_ops_prio.emplace([id, address](SecureDht& dht) mutable { + dht.insertNode(id, address); + }); cv.notify_all(); } void DhtRunner::bootstrap(const std::vector<NodeExport>& nodes) { - { - std::lock_guard<std::mutex> lck(storage_mtx); - pending_ops_prio.emplace([=](SecureDht& dht) { - for (auto& node : nodes) - dht.insertNode(node); - }); - } + if (running != State::Running) + return; + std::lock_guard<std::mutex> lck(storage_mtx); + pending_ops_prio.emplace([=](SecureDht& dht) { + for (auto& node : nodes) + dht.insertNode(node); + }); cv.notify_all(); } void DhtRunner::connectivityChanged() { - { - std::lock_guard<std::mutex> lck(storage_mtx); - pending_ops_prio.emplace([=](SecureDht& dht) { - dht.connectivityChanged(); - }); - } + std::lock_guard<std::mutex> lck(storage_mtx); + pending_ops_prio.emplace([=](SecureDht& dht) { + dht.connectivityChanged(); + }); cv.notify_all(); } void -DhtRunner::findCertificate(InfoHash hash, std::function<void(const std::shared_ptr<crypto::Certificate>)> cb) { - { - std::lock_guard<std::mutex> lck(storage_mtx); - pending_ops.emplace([=](SecureDht& dht) { - dht.findCertificate(hash, cb); - }); +DhtRunner::findCertificate(InfoHash hash, std::function<void(const Sp<crypto::Certificate>&)> cb) { + if (running != State::Running) { + cb({}); + return; } + std::lock_guard<std::mutex> lck(storage_mtx); + ongoing_ops++; + pending_ops.emplace([this, hash, cb = std::move(cb)] (SecureDht& dht) { + dht.findCertificate(hash, [this, cb = std::move(cb)](const Sp<crypto::Certificate>& crt){ + cb(crt); + opEnded(); + }); + }); cv.notify_all(); }