From 694f421a4fd9e31c7d2c5f8245003ea4113a7110 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20B=C3=A9raud?= <adrien.beraud@savoirfairelinux.com> Date: Thu, 20 May 2021 16:30:46 -0400 Subject: [PATCH] thread pool: add ExecutionContext --- include/opendht/thread_pool.h | 67 +++++++++++++++++++++++++++++++++++ tests/threadpooltester.cpp | 23 ++++++++++-- tests/threadpooltester.h | 2 ++ 3 files changed, 89 insertions(+), 3 deletions(-) diff --git a/include/opendht/thread_pool.h b/include/opendht/thread_pool.h index d7a2995c..f16af058 100644 --- a/include/opendht/thread_pool.h +++ b/include/opendht/thread_pool.h @@ -93,4 +93,71 @@ private: void schedule(); }; +class OPENDHT_PUBLIC ExecutionContext { +public: + ExecutionContext(ThreadPool& pool) + : threadPool_(pool), state_(std::make_shared<SharedState>()) + {} + + ~ExecutionContext() { + state_->destroy(); + } + + /** Wait for ongoing tasks to complete execution and drop other pending tasks */ + void stop() { + state_->destroy(false); + } + + void run(std::function<void()>&& task) { + std::lock_guard<std::mutex> lock(state_->mtx); + if (state_->shutdown_) return; + state_->pendingTasks++; + threadPool_.get().run([task = std::move(task), state = state_] { + state->run(task); + }); + } + +private: + struct SharedState { + std::mutex mtx {}; + std::condition_variable cv {}; + unsigned pendingTasks {0}; + unsigned ongoingTasks {0}; + /** When true, prevents new tasks to be scheduled */ + bool shutdown_ {false}; + /** When true, prevents scheduled tasks to be executed */ + std::atomic_bool destroyed {false}; + + void destroy(bool wait = true) { + std::unique_lock<std::mutex> lock(mtx); + if (destroyed) return; + if (wait) { + cv.wait(lock, [this] { return pendingTasks == 0 && ongoingTasks == 0; }); + } + shutdown_ = true; + if (not wait) { + cv.wait(lock, [this] { return ongoingTasks == 0; }); + } + destroyed = true; + } + + void run(const std::function<void()>& task) { + { + std::lock_guard<std::mutex> lock(mtx); + pendingTasks--; + ongoingTasks++; + } + if (destroyed) return; + task(); + { + std::lock_guard<std::mutex> lock(mtx); + ongoingTasks--; + cv.notify_all(); + } + } + }; + std::reference_wrapper<ThreadPool> threadPool_; + std::shared_ptr<SharedState> state_; +}; + } diff --git a/tests/threadpooltester.cpp b/tests/threadpooltester.cpp index 4295b87b..be948f94 100644 --- a/tests/threadpooltester.cpp +++ b/tests/threadpooltester.cpp @@ -59,7 +59,7 @@ ThreadPoolTester::testExecutor() auto executor8 = std::make_shared<dht::Executor>(pool, 8); constexpr unsigned N = 64 * 1024; - std::atomic_uint count1 {0}; + unsigned count1 {0}; std::atomic_uint count4 {0}; std::atomic_uint count8 {0}; for (unsigned i=0; i<N; i++) { @@ -69,7 +69,7 @@ ThreadPoolTester::testExecutor() } auto start = clock::now(); - while ((count1.load() != N || + while ((count1 != N || count4.load() != N || count8.load() != N) && clock::now() - start < std::chrono::seconds(20)) { @@ -78,11 +78,28 @@ ThreadPoolTester::testExecutor() executor1.reset(); executor4.reset(); executor8.reset(); - CPPUNIT_ASSERT_EQUAL(N, count1.load()); + CPPUNIT_ASSERT_EQUAL(N, count1); CPPUNIT_ASSERT_EQUAL(N, count4.load()); CPPUNIT_ASSERT_EQUAL(N, count8.load()); } +void +ThreadPoolTester::testContext() +{ + std::atomic_uint count {0}; + constexpr unsigned N = 64 * 1024; + + { + dht::ExecutionContext ctx(dht::ThreadPool::computation()); + for (unsigned i=0; i<N; i++) { + ctx.run([&] { count++; }); + } + } + + CPPUNIT_ASSERT_EQUAL(N, count.load()); + +} + void ThreadPoolTester::tearDown() { } diff --git a/tests/threadpooltester.h b/tests/threadpooltester.h index 4dc289c9..b8cbe15d 100644 --- a/tests/threadpooltester.h +++ b/tests/threadpooltester.h @@ -29,6 +29,7 @@ class ThreadPoolTester : public CppUnit::TestFixture { CPPUNIT_TEST_SUITE(ThreadPoolTester); CPPUNIT_TEST(testThreadPool); CPPUNIT_TEST(testExecutor); + CPPUNIT_TEST(testContext); CPPUNIT_TEST_SUITE_END(); public: @@ -43,6 +44,7 @@ class ThreadPoolTester : public CppUnit::TestFixture { void testThreadPool(); void testExecutor(); + void testContext(); }; } // namespace test -- GitLab