2 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
18 * Copyright (c) 2017 ARM Limited.
20 * SPDX-License-Identifier: MIT
22 * Permission is hereby granted, free of charge, to any person obtaining a copy
23 * of this software and associated documentation files (the "Software"), to
24 * deal in the Software without restriction, including without limitation the
25 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
26 * sell copies of the Software, and to permit persons to whom the Software is
27 * furnished to do so, subject to the following conditions:
29 * The above copyright notice and this permission notice shall be included in all
30 * copies or substantial portions of the Software.
32 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
33 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
34 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
35 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
36 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
37 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
41 #include "arm_compute/runtime/CL/functions/CLTopKV2.h"
42 #include "arm_compute/runtime/CL/CLScheduler.h"
44 #include "arm_compute/core/CL/ICLTensor.h"
46 #include "../../topk_v2.h"
// Default constructor: zero-initializes all scalar state (k, bit widths, radix,
// buffer sizes, element count n), null-initializes the tensor pointers, and
// leaves every cl::Buffer member default-constructed (no device allocation yet).
// The GPU kernel members are initialized elsewhere; their initializer-list
// entries are commented out below.
// NOTE(review): the constructor signature line ('CLTopKV2::CLTopKV2()') appears
// to have been dropped by the extraction — the embedded numbering jumps from
// 46 to 52. Confirm against the original file.
52 : _k(0), _total_bits(0), _bits(0), _radix(0), _hist_buf_size(0), _glob_sum_buf_size(0), _n(0),
53 _input(nullptr), _values(nullptr), _indices(nullptr), _qs_idx_buf(), _qs_temp_buf(),
54 _hist_buf(), _glob_sum_buf(), _temp_buf(), _first_negative_idx_buf(), _in_key_buf(),
55 _out_key_buf(), _in_ind_buf(), _out_ind_buf(), _p_in_key_buf(nullptr),
56 _p_out_key_buf(nullptr), _p_in_ind_buf(nullptr), _p_out_ind_buf(nullptr) /*, _qs_kernel(),
57 _init_kernel(), _hist_kernel(), _scan_hist_kernel(), _glob_scan_hist_kernel(),
58 _paste_hist_kernel(), _reorder_kernel(), _find_first_negative_kernel(),
59 _reorder_negatives_kernel(), _store_kernel()*/
// Configures the TopKV2 function.
//
// @param input      Source tensor; only dimension 0 (row size) is read here.
// @param k          Number of top elements to extract.
// @param values     Output tensor receiving the top-k values.
// @param indices    Output tensor receiving the top-k indices.
// @param total_bits Bit width of the key type (e.g. 32 for float keys).
// @param bits       Radix-sort digit width; total_bits must be divisible by bits.
//
// NOTE(review): several lines appear dropped by the extraction (embedded
// numbering gaps 64->66, 66->68, 71->82, 86->90, 95->97, 97->99, 108->110,
// 138->141): likely the opening brace, assignments such as '_k = k;',
// '_bits = bits;', '_radix = ...', the 'if (env) topk_env = env;' guard, an
// '#ifdef' matching the '#endif' below, and the closing braces. Confirm
// against the original file.
63 void CLTopKV2::configure(ICLTensor *input, int k, ICLTensor *values, ICLTensor *indices,
64                          int total_bits, int bits)
66   _total_bits = total_bits;
// n = number of elements per row (innermost dimension of the input).
68   _n = input->info()->tensor_shape()[0];
70   // _total_bits should be divided by _bits.
71   ARM_COMPUTE_ERROR_ON((_total_bits % _bits) != 0);
82   // Disable GPU implementation
83   // TODO Enable GPU implementation with verification, or remove code
84   // Invalid result on GPU
// Environment variable selects the (currently disabled) GPU paths.
86   char *env = getenv("ACL_TOPKV2");
90   if (topk_env == "GPU_SINGLE")
// Single-kernel quicksort path: two scratch cl_int buffers of n elements each.
92     _qs_idx_buf = cl::Buffer(CLScheduler::get().context(),
93                              CL_MEM_ALLOC_HOST_PTR | CL_MEM_READ_WRITE, sizeof(cl_int) * _n);
94     _qs_temp_buf = cl::Buffer(CLScheduler::get().context(),
95                               CL_MEM_ALLOC_HOST_PTR | CL_MEM_READ_WRITE, sizeof(cl_int) * _n);
97     _qs_kernel.configure(input, values, indices, &_qs_idx_buf, &_qs_temp_buf, k, _n);
99   else if (topk_env == "GPU")
// Multi-kernel radix-sort path: requires n to fill whole work groups.
101     // n should be divided by (_GROUPS * _ITEMS)
102     ARM_COMPUTE_ERROR_ON((_n % (_GROUPS * _ITEMS)) != 0);
104     _hist_buf_size = _radix * _GROUPS * _ITEMS;
105     _glob_sum_buf_size = _HISTOSPLIT;
107     _hist_buf = cl::Buffer(CLScheduler::get().context(), CL_MEM_ALLOC_HOST_PTR | CL_MEM_READ_WRITE,
108                            sizeof(cl_int) * _hist_buf_size);
// NOTE(review): the assignment target for the next cl::Buffer expression
// (presumably '_glob_sum_buf =') appears to be on a dropped line
// (numbering jumps 108 -> 110). Confirm against the original file.
110         cl::Buffer(CLScheduler::get().context(), CL_MEM_ALLOC_HOST_PTR | CL_MEM_READ_WRITE,
111                    sizeof(cl_int) * _glob_sum_buf_size);
112     _temp_buf = cl::Buffer(CLScheduler::get().context(), CL_MEM_ALLOC_HOST_PTR | CL_MEM_READ_WRITE,
113                            sizeof(cl_int) * _glob_sum_buf_size);
114     _first_negative_idx_buf = cl::Buffer(CLScheduler::get().context(),
115                                          CL_MEM_ALLOC_HOST_PTR | CL_MEM_READ_WRITE, sizeof(cl_int));
// Double-buffered key/index storage; sorting ping-pongs between in/out buffers.
116     _in_key_buf = cl::Buffer(CLScheduler::get().context(),
117                              CL_MEM_ALLOC_HOST_PTR | CL_MEM_READ_WRITE, sizeof(cl_float) * _n);
118     _out_key_buf = cl::Buffer(CLScheduler::get().context(),
119                               CL_MEM_ALLOC_HOST_PTR | CL_MEM_READ_WRITE, sizeof(cl_float) * _n);
120     _in_ind_buf = cl::Buffer(CLScheduler::get().context(),
121                              CL_MEM_ALLOC_HOST_PTR | CL_MEM_READ_WRITE, sizeof(cl_int) * _n);
122     _out_ind_buf = cl::Buffer(CLScheduler::get().context(),
123                               CL_MEM_ALLOC_HOST_PTR | CL_MEM_READ_WRITE, sizeof(cl_int) * _n);
// Pointers track which buffer currently holds input vs. output; run_on_gpu()
// swaps them after each radix pass.
125     _p_in_key_buf = &_in_key_buf;
126     _p_out_key_buf = &_out_key_buf;
127     _p_in_ind_buf = &_in_ind_buf;
128     _p_out_ind_buf = &_out_ind_buf;
130     _init_kernel.configure(input, _p_in_key_buf, _p_in_ind_buf, _n);
131     _hist_kernel.configure(&_hist_buf, bits, _n);
132     _scan_hist_kernel.configure(&_hist_buf, &_glob_sum_buf, bits);
133     _glob_scan_hist_kernel.configure(&_glob_sum_buf, &_temp_buf, bits);
134     _paste_hist_kernel.configure(&_hist_buf, &_glob_sum_buf, bits);
135     _reorder_kernel.configure(&_hist_buf, bits, _n);
136     _find_first_negative_kernel.configure(&_first_negative_idx_buf, _n);
137     _reorder_negatives_kernel.configure(&_first_negative_idx_buf, _n);
138     _store_kernel.configure(values, indices, k, _n);
141 #endif // Disable GPU implementation
143   // DO NOTHING for CPU.
// Dispatch body of CLTopKV2::run(): selects the execution path based on the
// ACL_TOPKV2 environment variable ("GPU_SINGLE" -> single-kernel quicksort,
// "GPU" -> multi-kernel radix sort; any other value presumably falls through
// to the CPU path — the else branch is not visible here).
// NOTE(review): the function signature line and the 'if (env) topk_env = env;'
// guard appear to have been dropped by the extraction (embedded numbering
// gaps 143->149, 151->155). Confirm against the original file.
149   std::string topk_env;
151   char *env = getenv("ACL_TOPKV2");
155   if (topk_env == "GPU_SINGLE")
157     run_on_gpu_single_quicksort();
159   else if (topk_env == "GPU")
// Runs top-k via a single quicksort kernel enqueued on the CL scheduler,
// then blocks until the queue has drained.
171 void CLTopKV2::run_on_gpu_single_quicksort()
173   // This is a single threaded quick sort implementation.
174   CLScheduler::get().enqueue(_qs_kernel, false);
// Synchronize so results are visible to the caller before run() returns.
176   arm_compute::CLScheduler::get().sync();
// Runs top-k on the GPU via an LSB-first radix sort (histogram / prefix-sum /
// reorder passes), followed by a fix-up pass for negative keys and a final
// kernel that stores the top-k values and indices into the output tensors.
// NOTE(review): several lines appear dropped by the extraction (e.g. the
// buffer-swap temporaries before lines 215 and 220, the second argument of
// setBuffers at line 234, and the '#if 0'/brace lines around the debug dump
// below). Confirm against the original file.
179 void CLTopKV2::run_on_gpu()
181   cl::CommandQueue q = CLScheduler::get().queue();
183   // 1. CLTopKV2Init set key buffer and index buffer.
184   //   - Key buffer is set as the same value of the layer's input
185   //   - Values in the index buffer are set as their indices.
186   CLScheduler::get().enqueue(_init_kernel, false);
// One radix pass consumes '_bits' bits of the key per iteration.
188   int n_passes = _total_bits / _bits;
190   // 2. Repeat (total_bits/bits) times.
191   //   - total_bits is the number of bits of the data type (e.g., 32 for float)
192   //   - bits defines number of buckets (e.g. 16 buckets where bit is 4)
193   for (int pass = 0; pass < n_passes; ++pass)
// Wait for the previous pass before reusing the histogram buffers.
195     arm_compute::CLScheduler::get().sync();
197     // 2.1. Calculate histogram with _GROUPS * _ITEMS threads
198     _hist_kernel.setPass(pass, _p_in_key_buf);
199     CLScheduler::get().enqueue(_hist_kernel, false);
201     // 2.2. Calculate prefix sum locally with multiple threads
202     CLScheduler::get().enqueue(_scan_hist_kernel, false);
203     // 2.3. Calculate prefix sum within a work group
204     CLScheduler::get().enqueue(_glob_scan_hist_kernel, false);
205     // 2.4. Calculate global prefix sum
206     CLScheduler::get().enqueue(_paste_hist_kernel, false);
208     // 2.5. Reorder keys and indices based on the global prefix sum
209     _reorder_kernel.setPass(pass, _p_in_key_buf, _p_out_key_buf, _p_in_ind_buf, _p_out_ind_buf);
210     CLScheduler::get().enqueue(_reorder_kernel, false);
// Swap key buffers so the next pass reads this pass's output.
// NOTE(review): the 'cl::Buffer *tmp = _p_in_key_buf;' line appears dropped.
215     _p_in_key_buf = _p_out_key_buf;
216     _p_out_key_buf = tmp;
218     // swap index buffers
220     _p_in_ind_buf = _p_out_ind_buf;
221     _p_out_ind_buf = tmp;
224   // 3. Get the first negative index
225   //   Because we swap in_buf and out_buf at the end of the above for loop,
226   //   the output buffers are in bufs.
227   _find_first_negative_kernel.setOutputBuffer(_p_in_key_buf);
228   CLScheduler::get().enqueue(_find_first_negative_kernel, false);
230   // 4. Correct ordering of negatives
231   //   - Since radix sort does not consider negatives, negatives are considered as bigger values
233   // reordered data will be stored in _p_out_key_buf and _p_out_ind_buf
234   _reorder_negatives_kernel.setBuffers(_p_in_key_buf, _p_out_key_buf, _p_in_ind_buf,
236   CLScheduler::get().enqueue(_reorder_negatives_kernel, false);
238   // 5. Extract top k values from sorted keys and indices.
239   _store_kernel.setOutputBuffers(_p_out_key_buf, _p_out_ind_buf);
240   CLScheduler::get().enqueue(_store_kernel, false);
// Block until all kernels finish so outputs are valid on return.
242   arm_compute::CLScheduler::get().sync();
// ---------------------------------------------------------------------------
// The remainder of this function is a disabled debug dump: it reads each
// intermediate buffer back to the host and prints its contents. Presumably
// it was guarded by an '#if 0'-style block whose guard line was dropped by
// the extraction; the matching '#endif' is visible below.
// ---------------------------------------------------------------------------
245   // below code is left for debugging.
247   q.enqueueReadBuffer(_first_negative_idx_buf, CL_TRUE, 0, sizeof(cl_int), &first_neg);
248   std::cout << "first neg = " << first_neg << std::endl;
251   q.enqueueReadBuffer(*_p_in_key_buf, CL_TRUE, 0, sizeof(cl_float)*_n, in_key);
252   for(uint32_t i = 0 ; i < _n; ++i) {
253     std::cout << "in_key[" << i << "] = " << in_key[i] << std::endl;
257   q.enqueueReadBuffer(*_p_out_key_buf, CL_TRUE, 0, sizeof(cl_float)*_n, out_key);
258   for(uint32_t i = 0 ; i < _n; ++i) {
259     std::cout << "out_key[" << i << "] = " << out_key[i] << std::endl;
263   q.enqueueReadBuffer(*_p_in_ind_buf, CL_TRUE, 0, sizeof(cl_int)*_n, in_ind);
264   for(uint32_t i = 0 ; i < _n; ++i) {
265     std::cout << "in_ind[" << i << "] = " << in_ind[i] << std::endl;
269   q.enqueueReadBuffer(*_p_out_ind_buf, CL_TRUE, 0, sizeof(cl_int)*_n, out_ind);
270   for(uint32_t i = 0 ; i < _n; ++i) {
271     std::cout << "out_ind[" << i << "] = " << out_ind[i] << std::endl;
274   int hist_buf[_hist_buf_size];
275   q.enqueueReadBuffer(_hist_buf, CL_TRUE, 0, sizeof(cl_int)*_hist_buf_size, hist_buf);
276   for(uint32_t i = 0 ; i < _hist_buf_size; ++i) {
277     std::cout << "hist_buf[" << i << "] = " << hist_buf[i] << std::endl;
280   int glob_sum_buf[_glob_sum_buf_size];
281   q.enqueueReadBuffer(_glob_sum_buf, CL_TRUE, 0, sizeof(cl_int)*_glob_sum_buf_size, glob_sum_buf);
282   for(uint32_t i = 0 ; i < _glob_sum_buf_size; ++i) {
283     std::cout << "glob_sum_buf[" << i << "] = " << glob_sum_buf[i] << std::endl;
288 #endif // Disable GPU implementation
// Runs top-k on the CPU: maps the CL tensors' host buffers and dispatches to
// the templated nnfw::rt::optimized_ops::TopK implementation based on the
// input's data type (F32, S32, or QASYMM8). Throws std::runtime_error for
// unsupported types.
// NOTE(review): lines appear dropped by the extraction (embedded numbering
// gaps 293->299, 301->304) — presumably tensor map/unmap calls and a rank
// check ('if (rank > 2)' or similar) guarding the first throw below. Confirm
// against the original file.
290 void CLTopKV2::run_on_cpu()
292   cl::CommandQueue q = CLScheduler::get().queue();
293   // const Window& w = _topkv2_kernel.window();
299   // int row_size = (w[0].end() - w[0].start()) / w[0].step();
// Row size = innermost dimension; rank decides whether there are multiple rows.
300   int row_size = _input->info()->tensor_shape()[0];
301   int rank = _input->info()->num_dimensions();
// NOTE(review): this throw appears to belong to a dropped rank-check 'if'.
304     throw std::runtime_error("Not supported type.");
306   int row_num = (rank == 2 ? _input->info()->tensor_shape()[1] : 1);
308   if (_input->info()->data_type() == DataType::F32)
310     nnfw::rt::optimized_ops::TopK<float>(row_size, row_num, (float *)_input->buffer(), _k,
311                                          (int32 *)_indices->buffer(), (float *)_values->buffer());
313   else if (_input->info()->data_type() == DataType::S32)
315     nnfw::rt::optimized_ops::TopK<int32_t>(row_size, row_num, (int32_t *)_input->buffer(), _k,
316                                            (int32 *)_indices->buffer(),
317                                            (int32_t *)_values->buffer());
319   else if (_input->info()->data_type() == DataType::QASYMM8)
321     nnfw::rt::optimized_ops::TopK<uint8_t>(row_size, row_num, (uint8_t *)_input->buffer(), _k,
322                                            (int32 *)_indices->buffer(),
323                                            (uint8_t *)_values->buffer());
// Fallback for any data type not handled above.
327     throw std::runtime_error("Not supported type.");
335 } // namespace arm_compute