/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */
// CLASS HEADER
#include "thread-pool.h"

// EXTERNAL INCLUDES
#include <cmath>
#include <condition_variable>
#include <cstdint>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <utility>
#include <vector>
/**
 * @brief Constructs a T on the heap and hands ownership to a std::unique_ptr.
 *
 * Behaves like C++14 std::make_unique for non-array types: every argument is
 * perfectly forwarded to T's constructor.
 * @param args Constructor arguments for T
 * @return Owning pointer to the newly constructed object
 */
template<typename T, typename... Args>
std::unique_ptr<T> make_unique(Args&&... args)
{
  T* const rawObject = new T(std::forward<Args>(args)...);
  return std::unique_ptr<T>{rawObject};
}
33 * WorkerThread executes tasks submitted to the pool
39 * @brief Constructor of worker thread
40 * @param index Thread index assigned to the object during pool initialisation
42 explicit WorkerThread(uint32_t index);
45 * @brief Destructor of the worker thread
49 WorkerThread(const WorkerThread& other) = delete;
50 WorkerThread& operator=(const WorkerThread& other) = delete;
53 * @brief Adds task to the task queue
54 * @param task Task to be executed by the thread
56 void AddTask(Task task);
59 * @brief Wakes up thread.
64 * @brief Waits for the thread to complete all the tasks currently in the queue.
70 * @brief Internal thread loop function
72 void WaitAndExecute();
77 std::mutex mTaskQueueMutex;
78 std::condition_variable mConditionVariable;
80 bool mTerminating{false};
83 void WorkerThread::WaitAndExecute()
90 std::unique_lock<std::mutex> lock{mTaskQueueMutex};
92 mConditionVariable.wait(lock, [this]() -> bool {
93 return !mTaskQueue.empty() || mTerminating;
101 task = mTaskQueue.front();
107 std::lock_guard<std::mutex> lock{mTaskQueueMutex};
111 mConditionVariable.notify_one();
116 WorkerThread::WorkerThread(uint32_t index)
119 // Have to pass "this" as an argument because WaitAndExecute is a member function.
120 mWorker = std::thread{&WorkerThread::WaitAndExecute, this};
123 WorkerThread::~WorkerThread()
125 if(mWorker.joinable())
131 std::lock_guard<std::mutex> lock{mTaskQueueMutex};
133 mConditionVariable.notify_one();
140 void WorkerThread::AddTask(Task task)
142 std::lock_guard<std::mutex> lock{mTaskQueueMutex};
143 mTaskQueue.push(std::move(task));
144 mConditionVariable.notify_one();
147 void WorkerThread::Notify()
149 std::lock_guard<std::mutex> lock{mTaskQueueMutex};
150 mConditionVariable.notify_one();
153 void WorkerThread::Wait()
155 std::unique_lock<std::mutex> lock{mTaskQueueMutex};
156 mConditionVariable.wait(lock, [this]() -> bool {
157 return mTaskQueue.empty();
// ThreadPool -----------------------------------------------------------------------------------------------
163 struct ThreadPool::Impl
165 std::vector<std::unique_ptr<WorkerThread>> mWorkers;
166 uint32_t mWorkerIndex{0u};
169 ThreadPool::ThreadPool()
171 mImpl = make_unique<Impl>();
174 ThreadPool::~ThreadPool() = default;
176 bool ThreadPool::Initialize(uint32_t threadCount)
179 * Get the system's supported thread count.
181 auto thread_count = threadCount + 1;
184 thread_count = std::thread::hardware_concurrency();
192 * Spawn the worker threads.
194 for(auto i = 0u; i < thread_count - 1; i++)
197 * The workers will execute an infinite loop function
198 * and will wait for a job to enter the job queue. Once a job is in the the queue
199 * the threads will wake up to acquire and execute it.
201 mImpl->mWorkers.push_back(make_unique<WorkerThread>(i));
207 void ThreadPool::Wait()
209 for(auto& worker : mImpl->mWorkers)
215 SharedFuture ThreadPool::SubmitTask(uint32_t workerIndex, const Task& task)
217 auto future = std::shared_ptr<Future<void>>(new Future<void>);
218 mImpl->mWorkers[workerIndex]->AddTask([task, future](uint32_t index) {
221 future->mPromise.set_value();
227 SharedFuture ThreadPool::SubmitTasks(const std::vector<Task>& tasks)
229 auto future = std::shared_ptr<Future<void>>(new Future<void>);
231 mImpl->mWorkers[mImpl->mWorkerIndex++ % static_cast<uint32_t>(mImpl->mWorkers.size())]->AddTask(
232 [future, tasks](uint32_t index) {
233 for(auto& task : tasks)
238 future->mPromise.set_value();
244 UniqueFutureGroup ThreadPool::SubmitTasks(const std::vector<Task>& tasks, uint32_t threadMask)
246 auto retval = make_unique<FutureGroup<void>>();
249 * Use square root of number of sumbitted tasks to estimate optimal number of threads
250 * used to execute jobs
252 auto threads = uint32_t(std::log2(float(tasks.size())));
256 threads = threadMask;
259 if(threads > mImpl->mWorkers.size())
261 threads = uint32_t(mImpl->mWorkers.size());
268 auto payloadPerThread = uint32_t(tasks.size() / threads);
269 auto remaining = uint32_t(tasks.size() % threads);
271 uint32_t taskIndex = 0;
272 uint32_t taskSize = uint32_t(remaining + payloadPerThread); // add 'remaining' tasks to the very first job list
274 for(auto wt = 0u; wt < threads; ++wt)
276 auto future = std::shared_ptr<Future<void>>(new Future<void>);
277 retval->mFutures.emplace_back(future);
278 mImpl->mWorkers[mImpl->mWorkerIndex++ % static_cast<uint32_t>(mImpl->mWorkers.size())]->AddTask(
279 [future, tasks, taskIndex, taskSize](uint32_t index) {
280 auto begin = tasks.begin() + int(taskIndex);
281 auto end = begin + int(taskSize);
282 for(auto it = begin; it < end; ++it)
286 future->mPromise.set_value();
289 taskIndex += taskSize;
290 taskSize = payloadPerThread;
296 size_t ThreadPool::GetWorkerCount() const
298 return mImpl->mWorkers.size();