diff --git a/include/opendht/dht.h b/include/opendht/dht.h index ed234110d86d1d00fe5d192184b8b8a50ba4542b..df7ef24588d7a13e7a9f7c66d6671a451f7a6e37 100644 --- a/include/opendht/dht.h +++ b/include/opendht/dht.h @@ -95,7 +95,7 @@ public: /** * Performs final operations before quitting. */ - void shutdown(ShutdownCallback cb) override; + void shutdown(ShutdownCallback cb, bool stop = false) override; /** * Returns true if the node is running (have access to an open socket). diff --git a/include/opendht/dht_interface.h b/include/opendht/dht_interface.h index 80045bff78a90f51e6ac13a12d6764f09b5bb1d8..4efcabde0aed13b24b90f0fce8a4fd5f51cb7409 100644 --- a/include/opendht/dht_interface.h +++ b/include/opendht/dht_interface.h @@ -61,8 +61,10 @@ public: /** * Performs final operations before quitting. + * stop: if true, cancel ongoing operations and call their 'done' + * callbacks synchronously. */ - virtual void shutdown(ShutdownCallback cb) = 0; + virtual void shutdown(ShutdownCallback cb, bool stop = false) = 0; /** * Returns true if the node is running (have access to an open socket). diff --git a/include/opendht/dht_proxy_client.h b/include/opendht/dht_proxy_client.h index 5ba3b7e75912bef40d5333b4b3dbc34bf1db8319..b4b45090bf74abd2cfb05b222bb7380920b09fb8 100644 --- a/include/opendht/dht_proxy_client.h +++ b/include/opendht/dht_proxy_client.h @@ -85,7 +85,7 @@ public: /** * Performs final operations before quitting. */ - void shutdown(ShutdownCallback cb) override; + void shutdown(ShutdownCallback cb, bool) override; /** * Returns true if the node is running (have access to an open socket). diff --git a/include/opendht/dhtrunner.h b/include/opendht/dhtrunner.h index 845a1f891a24417b7255d0cd2b8773938eb8d30a..31d3ae26834fed2174c836a320365c46446b8e47 100644 --- a/include/opendht/dhtrunner.h +++ b/include/opendht/dhtrunner.h @@ -408,7 +408,7 @@ public: /** * Gracefuly disconnect from network. */ - void shutdown(ShutdownCallback cb = {}); + void shutdown(ShutdownCallback cb = {}, bool stop = false); /** * Quit and wait for all threads to terminate. diff --git a/include/opendht/securedht.h b/include/opendht/securedht.h index 8db38aad5144040011cc7a343074551f2333eeb7..c944e568c60c38659ae99c6c419be6ad1537e01b 100644 --- a/include/opendht/securedht.h +++ b/include/opendht/securedht.h @@ -146,8 +146,8 @@ public: /** * SecureDht to Dht proxy */ - void shutdown(ShutdownCallback cb) override { - dht_->shutdown(cb); + void shutdown(ShutdownCallback cb, bool stop = false) override { + dht_->shutdown(cb, stop); } void dumpTables() const override { dht_->dumpTables(); diff --git a/src/dht.cpp b/src/dht.cpp index 331e418b745b5a5a4706cf08318ff662eb87c270..af0cc5ded356621f7c2fc903872a827998f79fa2 100644 --- a/src/dht.cpp +++ b/src/dht.cpp @@ -77,11 +77,27 @@ Dht::Kad::getStatus(time_point now) const } void -Dht::shutdown(ShutdownCallback cb) +Dht::shutdown(ShutdownCallback cb, bool stop) { if (not persistPath.empty()) saveState(persistPath); + if (stop) { + for (auto dht : {&dht4, &dht6}) { + for (auto& sr : dht->searches) { + for (const auto& r : sr.second->callbacks) + r.second.done_cb(false, {}); + sr.second->callbacks.clear(); + for (const auto& a : sr.second->announce) { + if (a.callback) a.callback(false, {}); + } + sr.second->announce.clear(); + sr.second->listeners.clear(); + } + } + network_engine.clear(); + } + if (not maintain_storage) { if (cb) cb(); return; diff --git a/src/dht_proxy_client.cpp b/src/dht_proxy_client.cpp index 60cd50d5221c410b2d9fd78e40cd395b096a1796..eb434970013705543db676be92f9a23f086b684b 100644 --- a/src/dht_proxy_client.cpp +++ b/src/dht_proxy_client.cpp @@ -242,7 +242,7 @@ DhtProxyClient::cancelAllListeners() } void -DhtProxyClient::shutdown(ShutdownCallback cb) +DhtProxyClient::shutdown(ShutdownCallback cb, bool) { stop(); if (cb) diff --git a/src/dhtrunner.cpp b/src/dhtrunner.cpp index c3d440d3c706dce8d2ce6b8d26ab2f670f711cc5..b43854bed771ba42034c9a96f5fc2ea96103393d 100644 --- a/src/dhtrunner.cpp +++ b/src/dhtrunner.cpp @@ -268,29 +268,31 @@ DhtRunner::run(const Config& config, Context&& context) } void -DhtRunner::shutdown(ShutdownCallback cb) { +DhtRunner::shutdown(ShutdownCallback cb, bool stop) { + std::unique_lock<std::mutex> lck(storage_mtx); auto expected = State::Running; if (not running.compare_exchange_strong(expected, State::Stopping)) { if (expected == State::Stopping and ongoing_ops) { - std::lock_guard<std::mutex> lck(storage_mtx); shutdownCallbacks_.emplace_back(std::move(cb)); } - else if (cb) cb(); + else if (cb) { + lck.unlock(); + cb(); + } return; } if (logger_) logger_->d("[runner %p] state changed to Stopping, %zu ongoing ops", this, ongoing_ops.load()); - std::lock_guard<std::mutex> lck(storage_mtx); ongoing_ops++; shutdownCallbacks_.emplace_back(std::move(cb)); - pending_ops_prio.emplace([=](SecureDht&) mutable { + pending_ops.emplace([=](SecureDht&) mutable { auto onShutdown = [this]{ opEnded(); }; #ifdef OPENDHT_PROXY_CLIENT if (dht_via_proxy_) - dht_via_proxy_->shutdown(onShutdown); + dht_via_proxy_->shutdown(onShutdown, stop); #endif if (dht_) - dht_->shutdown(onShutdown); + dht_->shutdown(onShutdown, stop); }); cv.notify_all(); } @@ -319,11 +321,11 @@ DhtRunner::bindOpDoneCallback(DoneCallbackSimple&& cb) { bool DhtRunner::checkShutdown() { - if (running != State::Stopping or ongoing_ops) - return false; decltype(shutdownCallbacks_) cbs; { std::lock_guard<std::mutex> lck(storage_mtx); + if (running != State::Stopping or ongoing_ops) + return false; cbs = std::move(shutdownCallbacks_); } for (auto& cb : cbs) @@ -355,9 +357,13 @@ DhtRunner::join() { std::lock_guard<std::mutex> lck(storage_mtx); + if (ongoing_ops and logger_) { + logger_->w("[runner %p] stopping with %zu remaining ops", this, ongoing_ops.load()); + } pending_ops = decltype(pending_ops)(); pending_ops_prio = decltype(pending_ops_prio)(); ongoing_ops = 0; + shutdownCallbacks_.clear(); } { std::lock_guard<std::mutex> lck(dht_mtx); @@ -709,11 +715,12 @@ DhtRunner::loop_() void DhtRunner::get(InfoHash hash, GetCallback vcb, DoneCallback dcb, Value::Filter f, Where w) { + std::unique_lock<std::mutex> lck(storage_mtx); if (running != State::Running) { + lck.unlock(); if (dcb) dcb(false, {}); return; } - std::lock_guard<std::mutex> lck(storage_mtx); ongoing_ops++; pending_ops.emplace([=](SecureDht& dht) mutable { dht.get(hash, std::move(vcb), bindOpDoneCallback(std::move(dcb)), std::move(f), std::move(w)); @@ -728,11 +735,12 @@ DhtRunner::get(const std::string& key, GetCallback vcb, DoneCallbackSimple dcb, } void DhtRunner::query(const InfoHash& hash, QueryCallback cb, DoneCallback done_cb, Query q) { + std::unique_lock<std::mutex> lck(storage_mtx); if (running != State::Running) { + lck.unlock(); if (done_cb) done_cb(false, {}); return; } - std::lock_guard<std::mutex> lck(storage_mtx); ongoing_ops++; pending_ops.emplace([=](SecureDht& dht) mutable { dht.query(hash, std::move(cb), bindOpDoneCallback(std::move(done_cb)), std::move(q)); @@ -744,11 +752,12 @@ std::future<size_t> DhtRunner::listen(InfoHash hash, ValueCallback vcb, Value::Filter f, Where w) { auto ret_token = std::make_shared<std::promise<size_t>>(); + std::unique_lock<std::mutex> lck(storage_mtx); if (running != State::Running) { + lck.unlock(); ret_token->set_value(0); return ret_token->get_future(); } - std::lock_guard<std::mutex> lck(storage_mtx); pending_ops.emplace([=](SecureDht& dht) mutable { #ifdef OPENDHT_PROXY_CLIENT auto tokenbGlobal = listener_token_++; @@ -829,11 +838,12 @@ DhtRunner::cancelListen(InfoHash h, std::shared_future<size_t> ftoken) void DhtRunner::put(InfoHash hash, Value&& value, DoneCallback cb, time_point created, bool permanent) { + std::unique_lock<std::mutex> lck(storage_mtx); if (running != State::Running) { + lck.unlock(); if (cb) cb(false, {}); return; } - std::lock_guard<std::mutex> lck(storage_mtx); ongoing_ops++; pending_ops.emplace([=, cb = std::move(cb), @@ -847,11 +857,12 @@ DhtRunner::put(InfoHash hash, Value&& value, DoneCallback cb, time_point created void DhtRunner::put(InfoHash hash, std::shared_ptr<Value> value, DoneCallback cb, time_point created, bool permanent) { + std::unique_lock<std::mutex> lck(storage_mtx); if (running != State::Running) { + lck.unlock(); if (cb) cb(false, {}); return; } - std::lock_guard<std::mutex> lck(storage_mtx); ongoing_ops++; pending_ops.emplace([=, cb = std::move(cb)](SecureDht& dht) mutable { dht.put(hash, value, bindOpDoneCallback(std::move(cb)), created, permanent); @@ -888,11 +899,12 @@ DhtRunner::cancelPut(const InfoHash& h, const std::shared_ptr<Value>& value) void DhtRunner::putSigned(InfoHash hash, std::shared_ptr<Value> value, DoneCallback cb, bool permanent) { + std::unique_lock<std::mutex> lck(storage_mtx); if (running != State::Running) { + lck.unlock(); if (cb) cb(false, {}); return; } - std::lock_guard<std::mutex> lck(storage_mtx); ongoing_ops++; pending_ops.emplace([=, cb = std::move(cb), @@ -918,11 +930,12 @@ DhtRunner::putSigned(const std::string& key, Value&& value, DoneCallbackSimple c void DhtRunner::putEncrypted(InfoHash hash, InfoHash to, std::shared_ptr<Value> value, DoneCallback cb, bool permanent) { + std::unique_lock<std::mutex> lck(storage_mtx); if (running != State::Running) { + lck.unlock(); if (cb) cb(false, {}); return; } - std::lock_guard<std::mutex> lck(storage_mtx); ongoing_ops++; pending_ops.emplace([=, cb = std::move(cb), @@ -1008,11 +1021,12 @@ DhtRunner::bootstrap(std::vector<SockAddr> nodes, DoneCallbackSimple&& cb) void DhtRunner::bootstrap(const SockAddr& addr, DoneCallbackSimple&& cb) { + std::unique_lock<std::mutex> lck(storage_mtx); if (running != State::Running) { + lck.unlock(); if (cb) cb(false); return; } - std::lock_guard<std::mutex> lck(storage_mtx); ongoing_ops++; pending_ops_prio.emplace([addr, cb = bindOpDoneCallback(std::move(cb))](SecureDht& dht) mutable { dht.pingNode(std::move(addr), std::move(cb)); @@ -1023,9 +1037,9 @@ DhtRunner::bootstrap(const SockAddr& addr, DoneCallbackSimple&& cb) void DhtRunner::bootstrap(const InfoHash& id, const SockAddr& address) { + std::lock_guard<std::mutex> lck(storage_mtx); if (running != State::Running) return; - std::unique_lock<std::mutex> lck(storage_mtx); pending_ops_prio.emplace([id, address](SecureDht& dht) mutable { dht.insertNode(id, address); }); @@ -1035,9 +1049,9 @@ DhtRunner::bootstrap(const InfoHash& id, const SockAddr& address) void DhtRunner::bootstrap(const std::vector<NodeExport>& nodes) { + std::lock_guard<std::mutex> lck(storage_mtx); if (running != State::Running) return; - std::lock_guard<std::mutex> lck(storage_mtx); pending_ops_prio.emplace([=](SecureDht& dht) { for (auto& node : nodes) dht.insertNode(node); @@ -1061,11 +1075,12 @@ DhtRunner::connectivityChanged() void DhtRunner::findCertificate(InfoHash hash, std::function<void(const Sp<crypto::Certificate>&)> cb) { + std::unique_lock<std::mutex> lck(storage_mtx); if (running != State::Running) { + lck.unlock(); cb({}); return; } - std::lock_guard<std::mutex> lck(storage_mtx); ongoing_ops++; pending_ops.emplace([this, hash, cb = std::move(cb)] (SecureDht& dht) { dht.findCertificate(hash, [this, cb = std::move(cb)](const Sp<crypto::Certificate>& crt){ diff --git a/tests/dhtproxytester.cpp b/tests/dhtproxytester.cpp index 21650bcc68fbc7b2f2ad0eecd0d76386af18dfff..649a9cec385b745b63c8f1fe952aa6fc909309a5 100644 --- a/tests/dhtproxytester.cpp +++ b/tests/dhtproxytester.cpp @@ -289,4 +289,47 @@ DhtProxyTester::testFuzzy() CPPUNIT_ASSERT(value->data == mtu); } +void +DhtProxyTester::testShutdownStop() +{ + constexpr size_t N = 40000; + constexpr unsigned C = 100; + + // Arrange + auto key = dht::InfoHash::get("testShutdownStop"); + std::vector<std::shared_ptr<dht::Value>> values; + std::vector<uint8_t> mtu; + mtu.reserve(N); + for (size_t i = 0; i < N; i++) + mtu.emplace_back((i % 2) ? 'T' : 'M'); + + std::atomic_uint callback_count {0}; + + // Act + for (size_t i = 0; i < C; i++) { + auto nodeTest = std::make_shared<dht::DhtRunner>(); + nodeTest->run(0, clientConfig); + nodeTest->put(key, dht::Value(mtu), [&](bool ok) { + callback_count++; + }); + nodeTest->get(key, [&](const std::vector<std::shared_ptr<dht::Value>>& vals){ + values.insert(values.end(), vals.begin(), vals.end()); + return true; + },[&](bool ok){ + callback_count++; + }); + bool done = false; + std::condition_variable cv; + std::mutex cv_m; + nodeTest->shutdown([&]{ + std::lock_guard<std::mutex> lk(cv_m); + done = true; + cv.notify_all(); + }, true); + std::unique_lock<std::mutex> lk(cv_m); + CPPUNIT_ASSERT(cv.wait_for(lk, 10s, [&]{ return done; })); + } + CPPUNIT_ASSERT_EQUAL(2*C, callback_count.load()); +} + } // namespace test diff --git a/tests/dhtproxytester.h b/tests/dhtproxytester.h index 6937c3e9d021c14919e1d61248c1bc6650dd4e9c..675975fb56943e563484315a1ee987c17b49d6cc 100644 --- a/tests/dhtproxytester.h +++ b/tests/dhtproxytester.h @@ -36,6 +36,7 @@ class DhtProxyTester : public CppUnit::TestFixture { CPPUNIT_TEST(testResubscribeGetValues); CPPUNIT_TEST(testPutGet40KChars); CPPUNIT_TEST(testFuzzy); + CPPUNIT_TEST(testShutdownStop); CPPUNIT_TEST_SUITE_END(); public: @@ -68,6 +69,8 @@ class DhtProxyTester : public CppUnit::TestFixture { void testFuzzy(); + void testShutdownStop(); + private: dht::DhtRunner::Config clientConfig {}; dht::DhtRunner nodePeer;