diff --git a/include/opendht/thread_pool.h b/include/opendht/thread_pool.h index f5b294b566b35b992a13f02a05f91803a970966d..7dda70ef90f7d9fae42324e7d4d1ec1b861d8904 100644 --- a/include/opendht/thread_pool.h +++ b/include/opendht/thread_pool.h @@ -65,15 +65,14 @@ public: void join(); private: - struct ThreadState; - std::queue<std::function<void()>> tasks_ {}; - std::vector<std::unique_ptr<ThreadState>> threads_; - unsigned readyThreads_ {0}; std::mutex lock_ {}; std::condition_variable cv_ {}; + std::queue<std::function<void()>> tasks_ {}; + std::vector<std::unique_ptr<std::thread>> threads_; + unsigned readyThreads_ {0}; + bool running_ {true}; const unsigned maxThreads_; - bool running_ {true}; }; class OPENDHT_PUBLIC Executor : public std::enable_shared_from_this<Executor> { diff --git a/src/thread_pool.cpp b/src/thread_pool.cpp index fb56365b26e00591973b66aee3c05480125abd53..e344cdba182f66dc30c0ccaf92d3107f81372309 100644 --- a/src/thread_pool.cpp +++ b/src/thread_pool.cpp @@ -28,12 +28,6 @@ namespace dht { constexpr const size_t IO_THREADS_MAX {64}; -struct ThreadPool::ThreadState -{ - std::thread thread {}; - std::atomic_bool run {true}; -}; - ThreadPool& ThreadPool::computation() { @@ -67,14 +61,12 @@ void ThreadPool::run(std::function<void()>&& cb) { std::unique_lock<std::mutex> l(lock_); - if (not running_) return; + if (not cb or not running_) return; // launch new thread if necessary if (not readyThreads_ && threads_.size() < maxThreads_) { - threads_.emplace_back(new ThreadState()); - auto& t = *threads_.back(); - t.thread = std::thread([&]() { - while (t.run) { + threads_.emplace_back(std::make_unique<std::thread>([this]() { + while (true) { std::function<void()> task; // pick task from queue @@ -82,10 +74,10 @@ ThreadPool::run(std::function<void()>&& cb) std::unique_lock<std::mutex> l(lock_); readyThreads_++; cv_.wait(l, [&](){ - return not t.run or not tasks_.empty(); + return not running_ or not tasks_.empty(); }); readyThreads_--; - if (not t.run) + if (not running_) break; task = std::move(tasks_.front()); tasks_.pop(); @@ -93,14 +85,13 @@ ThreadPool::run(std::function<void()>&& cb) // run task try { - if (task) - task(); + task(); } catch (const std::exception& e) { // LOG_ERR("Exception running task: %s", e.what()); std::cerr << "Exception running task: " << e.what() << std::endl; } } - }); + })); } // push task to queue @@ -113,12 +104,8 @@ ThreadPool::run(std::function<void()>&& cb) void ThreadPool::stop() { - { - std::lock_guard<std::mutex> l(lock_); - running_ = false; - } - for (auto& t : threads_) - t->run = false; + std::lock_guard<std::mutex> l(lock_); + running_ = false; cv_.notify_all(); } @@ -127,7 +114,7 @@ ThreadPool::join() { stop(); for (auto& t : threads_) - t->thread.join(); + t->join(); threads_.clear(); } @@ -147,7 +134,7 @@ Executor::run_(std::function<void()>&& task) { current_++; std::weak_ptr<Executor> w = shared_from_this(); - threadPool_.get().run([w,task] { + threadPool_.get().run([w,task = std::move(task)] { try { task(); } catch (const std::exception& e) { diff --git a/tests/threadpooltester.cpp b/tests/threadpooltester.cpp index 4a5199f68ad7395c679a01c50c67f4f199620e1d..00c49dfa1ed1fb263b8aec00cb2b80ac5a022a1c 100644 --- a/tests/threadpooltester.cpp +++ b/tests/threadpooltester.cpp @@ -48,7 +48,7 @@ ThreadPoolTester::testThreadPool() { std::this_thread::sleep_for(std::chrono::milliseconds(10)); pool.join(); - CPPUNIT_ASSERT(count.load() == N); + CPPUNIT_ASSERT_EQUAL(N, count.load()); } void