suanPan
🧮 An Open Source, Parallel and Heterogeneous Finite Element Analysis Framework
Loading...
Searching...
No Matches
thread_pool.hpp
Go to the documentation of this file.
1#ifndef THREAD_POOL_HPP
2#define THREAD_POOL_HPP
3
4#include <atomic>
5#include <condition_variable>
6#include <exception>
7#include <functional>
8#include <future>
9#include <memory>
10#include <mutex>
11#include <queue>
12#include <thread>
13#include <type_traits>
14#include <utility>
15
/**
 * @brief A fixed-size pool of worker threads that run queued tasks.
 *
 * Tasks are stored in a FIFO queue guarded by a single mutex. Worker threads
 * sleep on a condition variable until a task is available or the pool is
 * shutting down. The destructor first waits for all submitted tasks to finish
 * and then joins the workers.
 *
 * Exceptions thrown by tasks queued through push_task()/push_loop() are not
 * caught by the worker; use submit() to have them delivered through a future.
 */
class [[nodiscard]] thread_pool {
    /// Unsigned integer type returned by std::thread::hardware_concurrency().
    using concurrency_t = std::invoke_result_t<decltype(std::thread::hardware_concurrency)>;

    /// True while workers should keep polling; cleared by destroy_threads().
    std::atomic<bool> running{false};

    /// Signalled when a task is queued or when the pool shuts down.
    std::condition_variable task_available_cv{};

    /// Signalled by a worker after finishing a task while someone is waiting.
    std::condition_variable task_done_cv{};

    /// Pending tasks, protected by tasks_mutex.
    std::queue<std::function<void()>> tasks{};

    /// Number of tasks that are queued or currently running.
    std::atomic<size_t> tasks_total{0};

    /// Guards the task queue and is used with both condition variables.
    mutable std::mutex tasks_mutex{};

    /// Number of worker threads owned by the pool.
    concurrency_t thread_count{0};

    /// Storage for the worker threads (thread_count elements).
    std::unique_ptr<std::thread[]> threads{nullptr};

    /// True while a caller is blocked inside wait_for_tasks().
    std::atomic<bool> waiting{false};

    /// Start all workers; each one runs worker() on this pool.
    void create_threads() {
        running = true;
        for(concurrency_t i = 0; i < thread_count; ++i) threads[i] = std::thread(&thread_pool::worker, this);
    }

    /// Ask all workers to stop and join them.
    void destroy_threads() {
        running = false;
        {
            // Notify while holding the mutex so that a worker which has just
            // evaluated its wait predicate cannot miss the wake-up.
            const std::scoped_lock tasks_lock(tasks_mutex);
            task_available_cv.notify_all();
        }
        for(concurrency_t i = 0; i < thread_count; ++i) threads[i].join();
    }

    /**
     * @brief Resolve the number of threads to use.
     * @param thread_count_ Requested count; 0 means "decide automatically".
     * @return The requested count if positive, otherwise the hardware
     *         concurrency, or 1 if that is reported as 0.
     */
    [[nodiscard]] static concurrency_t determine_thread_count(const concurrency_t thread_count_) {
        if(thread_count_ > 0) return thread_count_;
        const concurrency_t detected = std::thread::hardware_concurrency();
        return detected > 0 ? detected : 1;
    }

    /// Worker loop: pop one task at a time and run it outside the lock.
    void worker() {
        while(running) {
            std::unique_lock tasks_lock(tasks_mutex);
            task_available_cv.wait(tasks_lock, [this] { return !tasks.empty() || !running; });
            if(running) {
                std::function<void()> task = std::move(tasks.front());
                tasks.pop();
                // Run the task without holding the mutex so other workers can
                // dequeue concurrently.
                tasks_lock.unlock();
                task();
                tasks_lock.lock();
                --tasks_total;
                if(waiting) task_done_cv.notify_one();
            }
        }
    }

public:
    /**
     * @brief Create the pool and start its workers.
     * @param thread_count_ Number of workers; 0 selects the value returned by
     *        determine_thread_count().
     */
    explicit thread_pool(const concurrency_t thread_count_ = 0)
        : thread_count(determine_thread_count(thread_count_))
        // thread_count is declared before threads, so it is already set here.
        , threads(std::make_unique<std::thread[]>(thread_count)) { create_threads(); }

    // The pool owns threads, a mutex and condition variables: not copyable or movable.
    thread_pool(const thread_pool&) = delete;
    thread_pool(thread_pool&&) = delete;
    thread_pool& operator=(const thread_pool&) = delete;
    thread_pool& operator=(thread_pool&&) = delete;

    /// Wait for outstanding tasks, then stop and join all workers.
    ~thread_pool() {
        wait_for_tasks();
        destroy_threads();
    }

    /// @return Number of worker threads in the pool.
    [[nodiscard]] concurrency_t get_thread_count() const { return thread_count; }

    /**
     * @brief Split the index range [first, last) into blocks and queue one task per block.
     *
     * Each task calls `loop(block_start, block_end)`. If the bounds are given
     * in reverse order they are swapped. An empty range queues nothing.
     *
     * @param first_index_ First index of the range.
     * @param index_after_last_ One past the last index of the range.
     * @param loop Callable taking the two bounds of a block.
     * @param num_blocks Desired number of blocks; 0 means one per worker. If
     *        the range has fewer elements than blocks, one block per element is used.
     */
    template<typename F, typename T1, typename T2, typename T = std::common_type_t<T1, T2>> void push_loop(T1 first_index_, T2 index_after_last_, F&& loop, size_t num_blocks = 0) {
        T first_index = static_cast<T>(first_index_);
        T index_after_last = static_cast<T>(index_after_last_);
        if(num_blocks == 0) num_blocks = thread_count;
        if(index_after_last < first_index) std::swap(index_after_last, first_index);
        const size_t total_size = static_cast<size_t>(index_after_last - first_index);
        size_t block_size = total_size / num_blocks;
        if(block_size == 0) {
            block_size = 1;
            num_blocks = (total_size > 1) ? total_size : 1;
        }
        if(total_size == 0) return;
        for(size_t i = 0; i < num_blocks; ++i) {
            const auto block_start = static_cast<T>(i * block_size) + first_index;
            // The last block absorbs the remainder of the division.
            const auto block_end = (i == num_blocks - 1) ? index_after_last : (static_cast<T>((i + 1) * block_size) + first_index);
            // Pass `loop` as an lvalue so every block binds its own copy;
            // forwarding an rvalue here would leave later blocks with a
            // moved-from callable.
            push_task(loop, block_start, block_end);
        }
    }

    /// Convenience overload of push_loop() for the range [0, index_after_last).
    template<typename F, typename T> void push_loop(const T index_after_last, F&& loop, const size_t num_blocks = 0) { push_loop(0, index_after_last, std::forward<F>(loop), num_blocks); }

    /**
     * @brief Queue a task without a way to retrieve its result.
     * @param task Callable to run.
     * @param args Arguments bound to the callable (stored by value via std::bind).
     */
    template<typename F, typename... A> void push_task(F&& task, A&&... args) {
        // Build the type-erased task before taking the lock to keep the
        // critical section short, then move it into the queue.
        std::function<void()> task_function = std::bind(std::forward<F>(task), std::forward<A>(args)...);
        {
            const std::scoped_lock tasks_lock(tasks_mutex);
            tasks.push(std::move(task_function));
            ++tasks_total;
        }
        task_available_cv.notify_one();
    }

    /**
     * @brief Queue a task and obtain a future for its result.
     * @param task Callable to run.
     * @param args Arguments bound to the callable.
     * @return Future that receives the return value, or the exception thrown by the task.
     */
    template<typename F, typename... A, typename R = std::invoke_result_t<std::decay_t<F>, std::decay_t<A>...>> [[nodiscard]] std::future<R> submit(F&& task, A&&... args) {
        std::function<R()> task_function = std::bind(std::forward<F>(task), std::forward<A>(args)...);
        // The promise is shared because the queued std::function must be copyable.
        std::shared_ptr<std::promise<R>> task_promise = std::make_shared<std::promise<R>>();
        push_task([task_function, task_promise] {
            try {
                if constexpr(std::is_void_v<R>) {
                    std::invoke(task_function);
                    task_promise->set_value();
                }
                else { task_promise->set_value(std::invoke(task_function)); }
            }
            catch(...) {
                try {
                    task_promise->set_exception(std::current_exception());
                }
                catch(...) {
                    // Deliberately ignored: nothing more can be done if the
                    // promise itself refuses the exception.
                }
            }
        });
        return task_promise->get_future();
    }

    /**
     * @brief Block until every queued and running task has finished.
     *
     * If another caller is already waiting, this call returns immediately
     * without waiting.
     */
    void wait_for_tasks() {
        // A single atomic exchange makes "check and claim" one step, so two
        // concurrent callers cannot both become waiters (only one would be
        // woken by notify_one()).
        if(waiting.exchange(true)) return;
        std::unique_lock tasks_lock(tasks_mutex);
        task_done_cv.wait(tasks_lock, [this] { return (tasks_total == 0); });
        waiting = false;
    }
};
141
142#endif
Definition thread_pool.hpp:16
thread_pool(const concurrency_t thread_count_=0)
Definition thread_pool.hpp:72
void push_task(F &&task, A &&... args)
Definition thread_pool.hpp:101
void wait_for_tasks()
Definition thread_pool.hpp:133
void push_loop(T1 first_index_, T2 index_after_last_, F &&loop, size_t num_blocks=0)
Definition thread_pool.hpp:83
~thread_pool()
Definition thread_pool.hpp:76
std::future< R > submit(F &&task, A &&... args)
Definition thread_pool.hpp:111
concurrency_t get_thread_count() const
Definition thread_pool.hpp:81
void push_loop(const T index_after_last, F &&loop, const size_t num_blocks=0)
Definition thread_pool.hpp:99