Skip to content
Snippets Groups Projects
Commit 694f421a authored by Adrien Béraud's avatar Adrien Béraud
Browse files

thread pool: add ExecutionContext

parent a178a8f8
No related branches found
No related tags found
No related merge requests found
...@@ -93,4 +93,71 @@ private: ...@@ -93,4 +93,71 @@ private:
void schedule(); void schedule();
}; };
class OPENDHT_PUBLIC ExecutionContext {
public:
ExecutionContext(ThreadPool& pool)
: threadPool_(pool), state_(std::make_shared<SharedState>())
{}
~ExecutionContext() {
state_->destroy();
}
/** Wait for ongoing tasks to complete execution and drop other pending tasks */
void stop() {
state_->destroy(false);
}
void run(std::function<void()>&& task) {
std::lock_guard<std::mutex> lock(state_->mtx);
if (state_->shutdown_) return;
state_->pendingTasks++;
threadPool_.get().run([task = std::move(task), state = state_] {
state->run(task);
});
}
private:
struct SharedState {
std::mutex mtx {};
std::condition_variable cv {};
unsigned pendingTasks {0};
unsigned ongoingTasks {0};
/** When true, prevents new tasks to be scheduled */
bool shutdown_ {false};
/** When true, prevents scheduled tasks to be executed */
std::atomic_bool destroyed {false};
void destroy(bool wait = true) {
std::unique_lock<std::mutex> lock(mtx);
if (destroyed) return;
if (wait) {
cv.wait(lock, [this] { return pendingTasks == 0 && ongoingTasks == 0; });
}
shutdown_ = true;
if (not wait) {
cv.wait(lock, [this] { return ongoingTasks == 0; });
}
destroyed = true;
}
void run(const std::function<void()>& task) {
{
std::lock_guard<std::mutex> lock(mtx);
pendingTasks--;
ongoingTasks++;
}
if (destroyed) return;
task();
{
std::lock_guard<std::mutex> lock(mtx);
ongoingTasks--;
cv.notify_all();
}
}
};
std::reference_wrapper<ThreadPool> threadPool_;
std::shared_ptr<SharedState> state_;
};
} }
...@@ -59,7 +59,7 @@ ThreadPoolTester::testExecutor() ...@@ -59,7 +59,7 @@ ThreadPoolTester::testExecutor()
auto executor8 = std::make_shared<dht::Executor>(pool, 8); auto executor8 = std::make_shared<dht::Executor>(pool, 8);
constexpr unsigned N = 64 * 1024; constexpr unsigned N = 64 * 1024;
std::atomic_uint count1 {0}; unsigned count1 {0};
std::atomic_uint count4 {0}; std::atomic_uint count4 {0};
std::atomic_uint count8 {0}; std::atomic_uint count8 {0};
for (unsigned i=0; i<N; i++) { for (unsigned i=0; i<N; i++) {
...@@ -69,7 +69,7 @@ ThreadPoolTester::testExecutor() ...@@ -69,7 +69,7 @@ ThreadPoolTester::testExecutor()
} }
auto start = clock::now(); auto start = clock::now();
while ((count1.load() != N || while ((count1 != N ||
count4.load() != N || count4.load() != N ||
count8.load() != N) && clock::now() - start < std::chrono::seconds(20)) count8.load() != N) && clock::now() - start < std::chrono::seconds(20))
{ {
...@@ -78,11 +78,28 @@ ThreadPoolTester::testExecutor() ...@@ -78,11 +78,28 @@ ThreadPoolTester::testExecutor()
executor1.reset(); executor1.reset();
executor4.reset(); executor4.reset();
executor8.reset(); executor8.reset();
CPPUNIT_ASSERT_EQUAL(N, count1.load()); CPPUNIT_ASSERT_EQUAL(N, count1);
CPPUNIT_ASSERT_EQUAL(N, count4.load()); CPPUNIT_ASSERT_EQUAL(N, count4.load());
CPPUNIT_ASSERT_EQUAL(N, count8.load()); CPPUNIT_ASSERT_EQUAL(N, count8.load());
} }
void
ThreadPoolTester::testContext()
{
std::atomic_uint count {0};
constexpr unsigned N = 64 * 1024;
{
dht::ExecutionContext ctx(dht::ThreadPool::computation());
for (unsigned i=0; i<N; i++) {
ctx.run([&] { count++; });
}
}
CPPUNIT_ASSERT_EQUAL(N, count.load());
}
void void
ThreadPoolTester::tearDown() { ThreadPoolTester::tearDown() {
} }
......
...@@ -29,6 +29,7 @@ class ThreadPoolTester : public CppUnit::TestFixture { ...@@ -29,6 +29,7 @@ class ThreadPoolTester : public CppUnit::TestFixture {
CPPUNIT_TEST_SUITE(ThreadPoolTester); CPPUNIT_TEST_SUITE(ThreadPoolTester);
CPPUNIT_TEST(testThreadPool); CPPUNIT_TEST(testThreadPool);
CPPUNIT_TEST(testExecutor); CPPUNIT_TEST(testExecutor);
CPPUNIT_TEST(testContext);
CPPUNIT_TEST_SUITE_END(); CPPUNIT_TEST_SUITE_END();
public: public:
...@@ -43,6 +44,7 @@ class ThreadPoolTester : public CppUnit::TestFixture { ...@@ -43,6 +44,7 @@ class ThreadPoolTester : public CppUnit::TestFixture {
void testThreadPool(); void testThreadPool();
void testExecutor(); void testExecutor();
void testContext();
}; };
} // namespace test } // namespace test
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment