 * Copyright (c) 2018 Samsung Electronics Co., Ltd.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
18 #include "thread-pool.h"
/**
 * @brief C++11-compatible stand-in for std::make_unique (which is C++14).
 *
 * Perfect-forwards the arguments to T's constructor and wraps the freshly
 * allocated object in a std::unique_ptr<T>.
 *
 * @param args Constructor arguments forwarded to T
 * @return Owning pointer to the newly constructed T
 */
template<typename T, typename... Args>
std::unique_ptr<T> make_unique(Args&&... args)
{
  return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
}
35 * WorkerThread executes tasks submitted to the pool
42 * @brief Constructor of worker thread
43 * @param index Thread index assigned to the object during pool initialisation
45 explicit WorkerThread( uint32_t index );
48 * @brief Destructor of the worker thread
52 WorkerThread(const WorkerThread &other) = delete;
53 WorkerThread &operator=(const WorkerThread &other) = delete;
56 * @brief Adds task to the task queue
57 * @param task Task to be executed by the thread
59 void AddTask( Task task );
62 * @brief Wakes up thread.
67 * @brief Waits for the thread to complete all the tasks currently in the queue.
74 * @brief Internal thread loop function
76 void WaitAndExecute();
81 std::mutex mTaskQueueMutex;
82 std::condition_variable mConditionVariable;
84 bool mTerminating {false} ;
87 void WorkerThread::WaitAndExecute()
94 std::unique_lock< std::mutex > lock{ mTaskQueueMutex };
96 mConditionVariable.wait( lock, [ this ]() -> bool {
97 return !mTaskQueue.empty() || mTerminating;
105 task = mTaskQueue.front();
111 std::lock_guard< std::mutex > lock{ mTaskQueueMutex };
115 mConditionVariable.notify_one();
120 WorkerThread::WorkerThread(uint32_t index) : mIndex( index )
122 // Have to pass "this" as an argument because WaitAndExecute is a member function.
123 mWorker = std::thread{ &WorkerThread::WaitAndExecute, this };
126 WorkerThread::~WorkerThread()
128 if( mWorker.joinable() )
134 std::lock_guard< std::mutex > lock{ mTaskQueueMutex };
136 mConditionVariable.notify_one();
143 void WorkerThread::AddTask( Task task )
145 std::lock_guard< std::mutex > lock{ mTaskQueueMutex };
146 mTaskQueue.push( std::move( task ) );
147 mConditionVariable.notify_one();
150 void WorkerThread::Notify()
152 std::lock_guard< std::mutex > lock{ mTaskQueueMutex };
153 mConditionVariable.notify_one();
156 void WorkerThread::Wait()
158 std::unique_lock< std::mutex > lock{ mTaskQueueMutex };
159 mConditionVariable.wait( lock, [ this ]() -> bool {
160 return mTaskQueue.empty();
// ThreadPool -----------------------------------------------------------------------------------------------
166 struct ThreadPool::Impl
168 std::vector<std::unique_ptr<WorkerThread>> mWorkers;
169 uint32_t mWorkerIndex{ 0u };
172 ThreadPool::ThreadPool()
174 mImpl = make_unique<Impl>();
177 ThreadPool::~ThreadPool() = default;
179 bool ThreadPool::Initialize( uint32_t threadCount )
182 * Get the system's supported thread count.
184 auto thread_count = threadCount + 1;
187 thread_count = std::thread::hardware_concurrency();
195 * Spawn the worker threads.
197 for( auto i = 0u; i < thread_count - 1; i++ )
200 * The workers will execute an infinite loop function
201 * and will wait for a job to enter the job queue. Once a job is in the the queue
202 * the threads will wake up to acquire and execute it.
204 mImpl->mWorkers.push_back( make_unique< WorkerThread >( i ) );
211 void ThreadPool::Wait()
213 for( auto& worker : mImpl->mWorkers )
219 SharedFuture ThreadPool::SubmitTask( uint32_t workerIndex, const Task& task )
221 auto future = std::shared_ptr< Future< void > >( new Future< void > );
222 mImpl->mWorkers[workerIndex]->AddTask( [task, future]( uint32_t index )
226 future->mPromise.set_value();
232 SharedFuture ThreadPool::SubmitTasks( const std::vector< Task >& tasks )
234 auto future = std::shared_ptr< Future< void > >( new Future< void > );
236 mImpl->mWorkers[ mImpl->mWorkerIndex++ % static_cast< uint32_t >( mImpl->mWorkers.size() )]->AddTask(
237 [ future, tasks ]( uint32_t index ) {
238 for( auto& task : tasks )
243 future->mPromise.set_value();
250 UniqueFutureGroup ThreadPool::SubmitTasks( const std::vector< Task >& tasks, uint32_t threadMask )
252 auto retval = make_unique<FutureGroup<void>>();
255 * Use square root of number of sumbitted tasks to estimate optimal number of threads
256 * used to execute jobs
258 auto threads = uint32_t(std::log2(float(tasks.size())));
260 if( threadMask != 0 )
262 threads = threadMask;
265 if( threads > mImpl->mWorkers.size() )
267 threads = uint32_t(mImpl->mWorkers.size());
274 auto payloadPerThread = uint32_t(tasks.size() / threads);
275 auto remaining = uint32_t(tasks.size() % threads);
277 uint32_t taskIndex = 0;
278 uint32_t taskSize = uint32_t(remaining + payloadPerThread); // add 'remaining' tasks to the very first job list
280 for( auto wt = 0u; wt < threads; ++wt )
282 auto future = std::shared_ptr< Future< void > >( new Future< void > );
283 retval->mFutures.emplace_back( future );
284 mImpl->mWorkers[ mImpl->mWorkerIndex++ % static_cast< uint32_t >( mImpl->mWorkers.size() )]->AddTask(
285 [ future, tasks, taskIndex, taskSize ]( uint32_t index ) {
286 auto begin = tasks.begin() + int(taskIndex);
287 auto end = begin + int(taskSize);
288 for( auto it = begin; it < end; ++it )
292 future->mPromise.set_value();
295 taskIndex += taskSize;
296 taskSize = payloadPerThread;
302 size_t ThreadPool::GetWorkerCount() const
304 return mImpl->mWorkers.size();