suanPan
thread_pool.hpp
Go to the documentation of this file.
1 #ifndef THREAD_POOL_HPP
2 #define THREAD_POOL_HPP
3 
4 #include <atomic>
5 #include <condition_variable>
6 #include <exception>
7 #include <functional>
8 #include <future>
9 #include <memory>
10 #include <mutex>
11 #include <queue>
12 #include <thread>
13 #include <type_traits>
14 #include <utility>
15 
/**
 * @brief A fixed-size pool of worker threads that execute queued tasks.
 *
 * Tasks are stored in a FIFO queue guarded by `tasks_mutex`. Each worker pops
 * one task at a time and runs it with the mutex released. The destructor
 * waits for outstanding tasks and then joins all workers.
 *
 * @note A task pushed via push_task() or push_loop() is invoked directly in
 *       the worker thread function without a try/catch, so an exception
 *       escaping such a task leaves the thread function. Use submit() when
 *       the task may throw; it routes the exception into the returned future.
 */
class [[nodiscard]] thread_pool {
    /// Integer type returned by std::thread::hardware_concurrency().
    using concurrency_t = std::invoke_result_t<decltype(std::thread::hardware_concurrency)>;

    /// True while the workers should keep processing; cleared to make them exit.
    std::atomic<bool> running = false;

    /// Signalled when a task is queued or when the pool is shutting down.
    std::condition_variable task_available_cv = {};

    /// Signalled by a worker after finishing a task while a waiter is present.
    std::condition_variable task_done_cv = {};

    /// FIFO queue of pending tasks; access is guarded by `tasks_mutex`.
    std::queue<std::function<void()>> tasks = {};

    /// Number of tasks that are queued or currently running.
    std::atomic<size_t> tasks_total = 0;

    /// Guards `tasks` and is the mutex used with both condition variables.
    mutable std::mutex tasks_mutex = {};

    /// Number of worker threads. Must stay declared before `threads`: the
    /// constructor sizes `threads` from this already-initialised member.
    concurrency_t thread_count = 0;

    /// Storage for the worker threads (`thread_count` elements).
    std::unique_ptr<std::thread[]> threads = nullptr;

    /// True while a call to wait_for_tasks() is blocked.
    std::atomic<bool> waiting = false;

    /// Mark the pool as running and start one worker per slot in `threads`.
    void create_threads() {
        running = true;
        for(concurrency_t i = 0; i < thread_count; ++i) threads[i] = std::thread(&thread_pool::worker, this);
    }

    /// Ask every worker to stop and join them. Tasks still queued are not run.
    void destroy_threads() {
        running = false;
        {
            // Notify under the lock so a worker cannot evaluate its wait
            // predicate, miss the flag change, and then block forever.
            const std::scoped_lock tasks_lock(tasks_mutex);
            task_available_cv.notify_all();
        }
        for(concurrency_t i = 0; i < thread_count; ++i) threads[i].join();
    }

    /**
     * @brief Resolve the number of threads to use.
     * @param thread_count_ Requested count; 0 means "use the hardware value".
     * @return The requested count if positive, otherwise
     *         std::thread::hardware_concurrency(), or 1 if that reports 0.
     */
    [[nodiscard]] static concurrency_t determine_thread_count(const concurrency_t thread_count_) {
        if(thread_count_ > 0) return thread_count_;
        const auto hardware_threads = std::thread::hardware_concurrency();
        return hardware_threads > 0 ? hardware_threads : 1;
    }

    /// Worker loop: wait for a task (or shutdown), run it unlocked, then
    /// update the bookkeeping under the lock.
    void worker() {
        while(running) {
            std::unique_lock tasks_lock(tasks_mutex);
            task_available_cv.wait(tasks_lock, [this] { return !tasks.empty() || !running; });
            // Woken for shutdown: the loop condition terminates the worker.
            if(!running) continue;
            std::function<void()> task = std::move(tasks.front());
            tasks.pop();
            // Run the task without holding the mutex so other workers and
            // producers are not blocked for its duration.
            tasks_lock.unlock();
            task();
            tasks_lock.lock();
            --tasks_total;
            if(waiting) task_done_cv.notify_one();
        }
    }

public:
    /**
     * @brief Construct the pool and start its worker threads.
     * @param thread_count_ Number of workers; 0 selects the value returned by
     *        determine_thread_count().
     */
    explicit thread_pool(const concurrency_t thread_count_ = 0)
        : thread_count(determine_thread_count(thread_count_))
        , threads(std::make_unique<std::thread[]>(thread_count)) { create_threads(); }

    // Workers hold a pointer to this object, so it must not be copied or moved.
    thread_pool(const thread_pool&) = delete;
    thread_pool(thread_pool&&) = delete;
    thread_pool& operator=(const thread_pool&) = delete;
    thread_pool& operator=(thread_pool&&) = delete;

    /// Wait for outstanding tasks, then stop and join all workers.
    ~thread_pool() {
        wait_for_tasks();
        destroy_threads();
    }

    /// @return The number of worker threads in the pool.
    [[nodiscard]] concurrency_t get_thread_count() const { return thread_count; }

    /**
     * @brief Split the index range between the two bounds into blocks and
     *        queue one task per block.
     *
     * Each task calls `loop(block_start, block_end)` with a half-open range.
     * If the bounds are given in descending order they are swapped. Nothing
     * is queued for an empty range.
     *
     * @param first_index_ First index of the range.
     * @param index_after_last_ One past the last index of the range.
     * @param loop Callable taking `(T start, T end)`; a copy is stored per block.
     * @param num_blocks Number of blocks; 0 means one block per worker thread.
     */
    template<typename F, typename T1, typename T2, typename T = std::common_type_t<T1, T2>> void push_loop(T1 first_index_, T2 index_after_last_, F&& loop, size_t num_blocks = 0) {
        T first_index = static_cast<T>(first_index_);
        T index_after_last = static_cast<T>(index_after_last_);
        if(num_blocks == 0) num_blocks = thread_count;
        if(index_after_last < first_index) std::swap(index_after_last, first_index);
        const size_t total_size = static_cast<size_t>(index_after_last - first_index);
        if(total_size == 0) return;
        size_t block_size = total_size / num_blocks;
        if(block_size == 0) {
            // Fewer indices than blocks: fall back to one index per block.
            block_size = 1;
            num_blocks = (total_size > 1) ? total_size : 1;
        }
        for(size_t i = 0; i < num_blocks; ++i) {
            const T block_first = static_cast<T>(i * block_size) + first_index;
            // The last block absorbs the remainder of the division.
            const T block_last = (i == num_blocks - 1) ? index_after_last : (static_cast<T>((i + 1) * block_size) + first_index);
            // Pass `loop` as an lvalue so every block stores its own copy.
            // Forwarding it here would move from an rvalue callable on the
            // first iteration and hand a moved-from object to later blocks.
            push_task(loop, block_first, block_last);
        }
    }

    /**
     * @brief Convenience overload of push_loop() for the range starting at 0.
     * @param index_after_last One past the last index of the range.
     * @param loop Callable taking `(start, end)`.
     * @param num_blocks Number of blocks; 0 means one block per worker thread.
     */
    template<typename F, typename T> void push_loop(const T index_after_last, F&& loop, const size_t num_blocks = 0) { push_loop(0, index_after_last, std::forward<F>(loop), num_blocks); }

    /**
     * @brief Queue a task with no return value and wake one worker.
     * @param task Callable to run.
     * @param args Arguments bound to the callable when it is queued.
     */
    template<typename F, typename... A> void push_task(F&& task, A&&... args) {
        // Build the type-erased task before taking the lock, then move it
        // into the queue so the critical section does no copying.
        std::function<void()> task_function = std::bind(std::forward<F>(task), std::forward<A>(args)...);
        {
            const std::scoped_lock tasks_lock(tasks_mutex);
            tasks.push(std::move(task_function));
            ++tasks_total;
        }
        task_available_cv.notify_one();
    }

    /**
     * @brief Queue a task and obtain a future for its result.
     *
     * An exception thrown by the task is captured and stored in the future
     * rather than propagating out of the worker.
     *
     * @param task Callable to run.
     * @param args Arguments bound to the callable when it is queued.
     * @return A future holding the task's return value (or its exception).
     */
    template<typename F, typename... A, typename R = std::invoke_result_t<std::decay_t<F>, std::decay_t<A>...>> [[nodiscard]] std::future<R> submit(F&& task, A&&... args) {
        std::function<R()> task_function = std::bind(std::forward<F>(task), std::forward<A>(args)...);
        // Shared ownership keeps the promise alive inside the copyable
        // closure required by std::function.
        std::shared_ptr<std::promise<R>> task_promise = std::make_shared<std::promise<R>>();
        push_task([task_function = std::move(task_function), task_promise] {
            try {
                if constexpr(std::is_void_v<R>) {
                    std::invoke(task_function);
                    task_promise->set_value();
                }
                else { task_promise->set_value(std::invoke(task_function)); }
            }
            catch(...) {
                // Best effort: set_exception itself may throw (for example if
                // the promise is already satisfied); that failure is ignored.
                try { task_promise->set_exception(std::current_exception()); }
                catch(...) {}
            }
        });
        return task_promise->get_future();
    }

    /**
     * @brief Block until no task is queued or running.
     *
     * If another call is already waiting (the `waiting` flag is set), this
     * call returns immediately without blocking.
     */
    void wait_for_tasks() {
        if(waiting) return;
        waiting = true;
        std::unique_lock tasks_lock(tasks_mutex);
        task_done_cv.wait(tasks_lock, [this] { return (tasks_total == 0); });
        waiting = false;
    }
};
135 
136 #endif
Definition: thread_pool.hpp:16
thread_pool(const concurrency_t thread_count_=0)
Definition: thread_pool.hpp:71
void push_task(F &&task, A &&... args)
Definition: thread_pool.hpp:98
std::future< R > submit(F &&task, A &&... args)
Definition: thread_pool.hpp:108
void wait_for_tasks()
Definition: thread_pool.hpp:127
void push_loop(T1 first_index_, T2 index_after_last_, F &&loop, size_t num_blocks=0)
Definition: thread_pool.hpp:82
~thread_pool()
Definition: thread_pool.hpp:75
concurrency_t get_thread_count() const
Definition: thread_pool.hpp:80
void push_loop(const T index_after_last, F &&loop, const size_t num_blocks=0)
Definition: thread_pool.hpp:96