Merge pull request #244 from forresti/fix-comment
[platform/upstream/armcl.git] / src / runtime / CPP / CPPScheduler.cpp
1 /*
2  * Copyright (c) 2016, 2017 ARM Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "arm_compute/runtime/CPP/CPPScheduler.h"
25
26 #include "arm_compute/core/CPP/ICPPKernel.h"
27 #include "arm_compute/core/Error.h"
28 #include "arm_compute/core/Helpers.h"
29 #include "arm_compute/core/Utils.h"
30
31 #include <condition_variable>
32 #include <iostream>
33 #include <mutex>
34 #include <system_error>
35 #include <thread>
36
37 namespace arm_compute
38 {
39 class Thread
40 {
41 public:
42     /** Start a new thread. */
43     Thread();
44
45     Thread(const Thread &) = delete;
46     Thread &operator=(const Thread &) = delete;
47     Thread(Thread &&)                 = delete;
48     Thread &operator=(Thread &&) = delete;
49
50     /** Destructor. Make the thread join. */
51     ~Thread();
52
53     /** Request the worker thread to start executing the given kernel
54      * This function will return as soon as the kernel has been sent to the worker thread.
55      * wait() needs to be called to ensure the execution is complete.
56      */
57     void start(ICPPKernel *kernel, const Window &window, const ThreadInfo &info);
58
59     /** Wait for the current kernel execution to complete. */
60     void wait();
61
62     /** Function ran by the worker thread. */
63     void worker_thread();
64
65 private:
66     std::thread             _thread;
67     ICPPKernel             *_kernel{ nullptr };
68     Window                  _window;
69     ThreadInfo              _info;
70     std::mutex              _m;
71     std::condition_variable _cv;
72     bool                    _wait_for_work{ false };
73     bool                    _job_complete{ true };
74     std::exception_ptr      _current_exception;
75 };
76
77 Thread::Thread()
78     : _thread(), _window(), _info(), _m(), _cv(), _current_exception(nullptr)
79 {
80     _thread = std::thread(&Thread::worker_thread, this);
81 }
82
83 Thread::~Thread()
84 {
85     // Make sure worker thread has ended
86     if(_thread.joinable())
87     {
88         start(nullptr, Window(), ThreadInfo());
89         _thread.join();
90     }
91 }
92
93 void Thread::start(ICPPKernel *kernel, const Window &window, const ThreadInfo &info)
94 {
95     _kernel = kernel;
96     _window = window;
97     _info   = info;
98
99     {
100         std::lock_guard<std::mutex> lock(_m);
101         _wait_for_work = true;
102         _job_complete  = false;
103     }
104     _cv.notify_one();
105 }
106
107 void Thread::wait()
108 {
109     {
110         std::unique_lock<std::mutex> lock(_m);
111         _cv.wait(lock, [&] { return _job_complete; });
112     }
113
114     if(_current_exception)
115     {
116         std::rethrow_exception(_current_exception);
117     }
118 }
119
120 void Thread::worker_thread()
121 {
122     while(true)
123     {
124         std::unique_lock<std::mutex> lock(_m);
125         _cv.wait(lock, [&] { return _wait_for_work; });
126         _wait_for_work = false;
127
128         _current_exception = nullptr;
129
130         // Time to exit
131         if(_kernel == nullptr)
132         {
133             return;
134         }
135
136         try
137         {
138             _window.validate();
139             _kernel->run(_window, _info);
140         }
141         catch(...)
142         {
143             _current_exception = std::current_exception();
144         }
145
146         _job_complete = true;
147         lock.unlock();
148         _cv.notify_one();
149     }
150 }
151
152 CPPScheduler &CPPScheduler::get()
153 {
154     static CPPScheduler scheduler;
155     return scheduler;
156 }
157
158 CPPScheduler::CPPScheduler()
159     : _num_threads(std::thread::hardware_concurrency()),
160       _threads(_num_threads - 1)
161 {
162 }
163
164 void CPPScheduler::set_num_threads(unsigned int num_threads)
165 {
166     _num_threads = num_threads == 0 ? std::thread::hardware_concurrency() : num_threads;
167     _threads.resize(_num_threads - 1);
168 }
169
170 unsigned int CPPScheduler::num_threads() const
171 {
172     return _num_threads;
173 }
174
175 void CPPScheduler::schedule(ICPPKernel *kernel, unsigned int split_dimension)
176 {
177     ARM_COMPUTE_ERROR_ON_MSG(!kernel, "The child class didn't set the kernel");
178
179     /** [Scheduler example] */
180     ThreadInfo info;
181     info.cpu_info = _info;
182
183     const Window      &max_window     = kernel->window();
184     const unsigned int num_iterations = max_window.num_iterations(split_dimension);
185     info.num_threads                  = std::min(num_iterations, _num_threads);
186
187     if(num_iterations == 0)
188     {
189         return;
190     }
191
192     if(!kernel->is_parallelisable() || info.num_threads == 1)
193     {
194         kernel->run(max_window, info);
195     }
196     else
197     {
198         int  t         = 0;
199         auto thread_it = _threads.begin();
200
201         for(; t < info.num_threads - 1; ++t, ++thread_it)
202         {
203             Window win     = max_window.split_window(split_dimension, t, info.num_threads);
204             info.thread_id = t;
205             thread_it->start(kernel, win, info);
206         }
207
208         // Run last part on main thread
209         Window win     = max_window.split_window(split_dimension, t, info.num_threads);
210         info.thread_id = t;
211         kernel->run(win, info);
212
213         try
214         {
215             for(auto &thread : _threads)
216             {
217                 thread.wait();
218             }
219         }
220         catch(const std::system_error &e)
221         {
222             std::cerr << "Caught system_error with code " << e.code() << " meaning " << e.what() << '\n';
223         }
224     }
225     /** [Scheduler example] */
226 }
227 } // namespace arm_compute