From bc32f848cbefbfd1ef01ef36cdc2b8cf18a9f635 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrien=20B=C3=A9raud?= <adrien.beraud@savoirfairelinux.com>
Date: Wed, 16 Feb 2022 00:43:59 -0500
Subject: [PATCH] thread pool: cleanup

---
 include/opendht/thread_pool.h |  9 ++++-----
 src/thread_pool.cpp           | 35 +++++++++++------------------------
 tests/threadpooltester.cpp    |  2 +-
 3 files changed, 16 insertions(+), 30 deletions(-)

diff --git a/include/opendht/thread_pool.h b/include/opendht/thread_pool.h
index f5b294b5..7dda70ef 100644
--- a/include/opendht/thread_pool.h
+++ b/include/opendht/thread_pool.h
@@ -65,15 +65,14 @@ public:
     void join();
 
 private:
-    struct ThreadState;
-    std::queue<std::function<void()>> tasks_ {};
-    std::vector<std::unique_ptr<ThreadState>> threads_;
-    unsigned readyThreads_ {0};
     std::mutex lock_ {};
     std::condition_variable cv_ {};
+    std::queue<std::function<void()>> tasks_ {};
+    std::vector<std::unique_ptr<std::thread>> threads_;
+    unsigned readyThreads_ {0};
+    bool running_ {true};
 
     const unsigned maxThreads_;
-    bool running_ {true};
 };
 
 class OPENDHT_PUBLIC Executor : public std::enable_shared_from_this<Executor> {
diff --git a/src/thread_pool.cpp b/src/thread_pool.cpp
index fb56365b..e344cdba 100644
--- a/src/thread_pool.cpp
+++ b/src/thread_pool.cpp
@@ -28,12 +28,6 @@ namespace dht {
 
 constexpr const size_t IO_THREADS_MAX {64};
 
-struct ThreadPool::ThreadState
-{
-    std::thread thread {};
-    std::atomic_bool run {true};
-};
-
 ThreadPool&
 ThreadPool::computation()
 {
@@ -67,14 +61,12 @@ void
 ThreadPool::run(std::function<void()>&& cb)
 {
     std::unique_lock<std::mutex> l(lock_);
-    if (not running_) return;
+    if (not cb or not running_) return;
 
     // launch new thread if necessary
     if (not readyThreads_ && threads_.size() < maxThreads_) {
-        threads_.emplace_back(new ThreadState());
-        auto& t = *threads_.back();
-        t.thread = std::thread([&]() {
-            while (t.run) {
+        threads_.emplace_back(std::make_unique<std::thread>([this]() {
+            while (true) {
                 std::function<void()> task;
 
                 // pick task from queue
@@ -82,10 +74,10 @@ ThreadPool::run(std::function<void()>&& cb)
                     std::unique_lock<std::mutex> l(lock_);
                     readyThreads_++;
                     cv_.wait(l, [&](){
-                        return not t.run or not tasks_.empty();
+                        return not running_ or not tasks_.empty();
                     });
                     readyThreads_--;
-                    if (not t.run)
+                    if (not running_)
                         break;
                     task = std::move(tasks_.front());
                     tasks_.pop();
@@ -93,14 +85,13 @@ ThreadPool::run(std::function<void()>&& cb)
 
                 // run task
                 try {
-                    if (task)
-                        task();
+                    task();
                 } catch (const std::exception& e) {
                     // LOG_ERR("Exception running task: %s", e.what());
                     std::cerr << "Exception running task: " << e.what() << std::endl;
                 }
             }
-        });
+        }));
     }
 
     // push task to queue
@@ -113,12 +104,8 @@ ThreadPool::run(std::function<void()>&& cb)
 void
 ThreadPool::stop()
 {
-    {
-        std::lock_guard<std::mutex> l(lock_);
-        running_ = false;
-    }
-    for (auto& t : threads_)
-        t->run = false;
+    std::lock_guard<std::mutex> l(lock_);
+    running_ = false;
     cv_.notify_all();
 }
 
@@ -127,7 +114,7 @@ ThreadPool::join()
 {
     stop();
     for (auto& t : threads_)
-        t->thread.join();
+        t->join();
     threads_.clear();
 }
 
@@ -147,7 +134,7 @@ Executor::run_(std::function<void()>&& task)
 {
     current_++;
     std::weak_ptr<Executor> w = shared_from_this();
-    threadPool_.get().run([w,task] {
+    threadPool_.get().run([w,task = std::move(task)] {
         try {
             task();
         } catch (const std::exception& e) {
diff --git a/tests/threadpooltester.cpp b/tests/threadpooltester.cpp
index 4a5199f6..00c49dfa 100644
--- a/tests/threadpooltester.cpp
+++ b/tests/threadpooltester.cpp
@@ -48,7 +48,7 @@ ThreadPoolTester::testThreadPool() {
         std::this_thread::sleep_for(std::chrono::milliseconds(10));
 
     pool.join();
-    CPPUNIT_ASSERT(count.load() == N);
+    CPPUNIT_ASSERT_EQUAL(N, count.load());
 }
 
 void
-- 
GitLab