From 694f421a4fd9e31c7d2c5f8245003ea4113a7110 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrien=20B=C3=A9raud?= <adrien.beraud@savoirfairelinux.com>
Date: Thu, 20 May 2021 16:30:46 -0400
Subject: [PATCH] thread pool: add ExecutionContext

---
 include/opendht/thread_pool.h | 67 +++++++++++++++++++++++++++++++++++
 tests/threadpooltester.cpp    | 23 ++++++++++--
 tests/threadpooltester.h      |  2 ++
 3 files changed, 89 insertions(+), 3 deletions(-)

diff --git a/include/opendht/thread_pool.h b/include/opendht/thread_pool.h
index d7a2995c..f16af058 100644
--- a/include/opendht/thread_pool.h
+++ b/include/opendht/thread_pool.h
@@ -93,4 +93,71 @@ private:
     void schedule();
 };
 
+class OPENDHT_PUBLIC ExecutionContext {
+public:
+    ExecutionContext(ThreadPool& pool)
+     : threadPool_(pool), state_(std::make_shared<SharedState>())
+    {}
+
+    ~ExecutionContext() {
+        state_->destroy();
+    }
+
+    /** Wait for ongoing tasks to complete execution and drop other pending tasks */
+    void stop() {
+        state_->destroy(false);
+    }
+
+    void run(std::function<void()>&& task) {
+        std::lock_guard<std::mutex> lock(state_->mtx);
+        if (state_->shutdown_) return;
+        state_->pendingTasks++;
+        threadPool_.get().run([task = std::move(task), state = state_] {
+            state->run(task);
+        });
+    }
+
+private:
+    struct SharedState {
+        std::mutex mtx {};
+        std::condition_variable cv {};
+        unsigned pendingTasks {0};
+        unsigned ongoingTasks {0};
+        /** When true, prevents new tasks to be scheduled */
+        bool shutdown_ {false};
+        /** When true, prevents scheduled tasks to be executed */
+        std::atomic_bool destroyed {false};
+
+        void destroy(bool wait = true) {
+            std::unique_lock<std::mutex> lock(mtx);
+            if (destroyed) return;
+            if (wait) {
+                cv.wait(lock, [this] { return pendingTasks == 0 && ongoingTasks == 0; });
+            }
+            shutdown_ = true;
+            if (not wait) {
+                cv.wait(lock, [this] { return ongoingTasks == 0; });
+            }
+            destroyed = true;
+        }
+
+        void run(const std::function<void()>& task) {
+            {
+                std::lock_guard<std::mutex> lock(mtx);
+                pendingTasks--;
+                ongoingTasks++;
+            }
+                if (destroyed) return;
+                task();
+            {
+                std::lock_guard<std::mutex> lock(mtx);
+                ongoingTasks--;
+                cv.notify_all();
+            }
+        }
+    };
+    std::reference_wrapper<ThreadPool> threadPool_;
+    std::shared_ptr<SharedState> state_;
+};
+
 }
diff --git a/tests/threadpooltester.cpp b/tests/threadpooltester.cpp
index 4295b87b..be948f94 100644
--- a/tests/threadpooltester.cpp
+++ b/tests/threadpooltester.cpp
@@ -59,7 +59,7 @@ ThreadPoolTester::testExecutor()
     auto executor8 = std::make_shared<dht::Executor>(pool, 8);
 
     constexpr unsigned N = 64 * 1024;
-    std::atomic_uint count1 {0};
+    unsigned count1 {0};
     std::atomic_uint count4 {0};
     std::atomic_uint count8 {0};
     for (unsigned i=0; i<N; i++) {
@@ -69,7 +69,7 @@ ThreadPoolTester::testExecutor()
     }
 
     auto start = clock::now();
-    while ((count1.load() != N ||
+    while ((count1 != N ||
             count4.load() != N ||
             count8.load() != N) && clock::now() - start < std::chrono::seconds(20))
     {
@@ -78,11 +78,28 @@ ThreadPoolTester::testExecutor()
     executor1.reset();
     executor4.reset();
     executor8.reset();
-    CPPUNIT_ASSERT_EQUAL(N, count1.load());
+    CPPUNIT_ASSERT_EQUAL(N, count1);
     CPPUNIT_ASSERT_EQUAL(N, count4.load());
     CPPUNIT_ASSERT_EQUAL(N, count8.load());
 }
 
+void
+ThreadPoolTester::testContext()
+{
+    std::atomic_uint count {0};
+    constexpr unsigned N = 64 * 1024;
+
+    {
+        dht::ExecutionContext ctx(dht::ThreadPool::computation());
+        for (unsigned i=0; i<N; i++) {
+            ctx.run([&] { count++; });
+        }
+    }
+
+    CPPUNIT_ASSERT_EQUAL(N, count.load());
+
+}
+
 void
 ThreadPoolTester::tearDown() {
 }
diff --git a/tests/threadpooltester.h b/tests/threadpooltester.h
index 4dc289c9..b8cbe15d 100644
--- a/tests/threadpooltester.h
+++ b/tests/threadpooltester.h
@@ -29,6 +29,7 @@ class ThreadPoolTester : public CppUnit::TestFixture {
     CPPUNIT_TEST_SUITE(ThreadPoolTester);
     CPPUNIT_TEST(testThreadPool);
     CPPUNIT_TEST(testExecutor);
+    CPPUNIT_TEST(testContext);
     CPPUNIT_TEST_SUITE_END();
 
  public:
@@ -43,6 +44,7 @@ class ThreadPoolTester : public CppUnit::TestFixture {
 
     void testThreadPool();
     void testExecutor();
+    void testContext();
 };
 
 }  // namespace test
-- 
GitLab