Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / compute / ARMComputeEx / src / core / NEON / kernels / NECastBoolKernel.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 /*
18  * Copyright (c) 2016-2020 ARM Limited.
19  *
20  * SPDX-License-Identifier: MIT
21  *
22  * Permission is hereby granted, free of charge, to any person obtaining a copy
23  * of this software and associated documentation files (the "Software"), to
24  * deal in the Software without restriction, including without limitation the
25  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
26  * sell copies of the Software, and to permit persons to whom the Software is
27  * furnished to do so, subject to the following conditions:
28  *
29  * The above copyright notice and this permission notice shall be included in all
30  * copies or substantial portions of the Software.
31  *
32  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
33  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
34  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
35  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
36  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
37  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
38  * SOFTWARE.
39  */
40 #include "arm_compute/core/NEON/kernels/NECastBoolKernel.h"
41
42 #include "arm_compute/core/CPP/Validate.h"
43 #include "arm_compute/core/Error.h"
44 #include "arm_compute/core/Helpers.h"
45 #include "arm_compute/core/ITensor.h"
46 #include "arm_compute/core/NEON/NEMath.h"
47 #include "arm_compute/core/TensorInfo.h"
48 #include "arm_compute/core/Validate.h"
49 #include "arm_compute/core/utils/misc/SaturateCast.h"
50
51 #include "arm_compute/core/NEON/wrapper/wrapper.h"
52
53 using namespace arm_compute;
54
55 namespace
56 {
57 Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output)
58 {
59   ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(output);
60   ARM_COMPUTE_RETURN_ERROR_ON(input == output);
61   ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8);
62   ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S8,
63                                                        DataType::S16, DataType::U16, DataType::F16,
64                                                        DataType::U32, DataType::S32, DataType::F32);
65
66   // Validate in case of configured output
67   if (output->total_size() > 0)
68   {
69     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
70   }
71
72   return Status{};
73 }
74 } // namespace
75
76 NECastBoolKernel::NECastBoolKernel() : _input(nullptr), _output(nullptr) {}
77
78 void NECastBoolKernel::configure(const ITensor *input, ITensor *output)
79 {
80   ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
81
82   // Auto initialize output shape if not initialized (We can only auto-configure the shape, datatype
83   // must be given)
84   set_shape_if_empty(*output->info(), input->info()->tensor_shape());
85
86   _input = input;
87   _output = output;
88
89   ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info()));
90
91   // Configure kernel window
92   Window win = calculate_max_window(*input->info(), Steps());
93   Coordinates coord;
94   coord.set_num_dimensions(output->info()->num_dimensions());
95   output->info()->set_valid_region(ValidRegion(coord, output->info()->tensor_shape()));
96
97   ICPPKernel::configure(win);
98 }
99
100 Status NECastBoolKernel::validate(const ITensorInfo *input, const ITensorInfo *output)
101 {
102   ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output));
103   return Status{};
104 }
105
106 void NECastBoolKernel::run(const Window &window, const ThreadInfo &info)
107 {
108   ARM_COMPUTE_UNUSED(info);
109   ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
110   ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window);
111   ARM_COMPUTE_ERROR_ON_NULLPTR(_input, _output);
112   ARM_COMPUTE_ERROR_ON(_input == _output);
113
114   const auto window_start_x = static_cast<int>(window.x().start());
115   const auto window_end_x = static_cast<int>(window.x().end());
116   const int window_step_x = 16;
117
118   Window win{window};
119   win.set(Window::DimX, Window::Dimension(0, 1, 1));
120
121   Iterator input(_input, win);
122   Iterator output(_output, win);
123
124   const uint8_t true_val = 1;
125   const uint8x8_t mask_bool = vdup_n_u8(true_val);
126
127   switch (_output->info()->data_type())
128   {
129     case DataType::S8:
130     {
131       /* Conversion U8 -> S8 */
132       execute_window_loop(
133         win,
134         [&](const Coordinates &) {
135           const auto input_ptr = reinterpret_cast<const uint8_t *>(input.ptr());
136           const auto output_ptr = reinterpret_cast<int8_t *>(output.ptr());
137
138           int x = window_start_x;
139           for (; x <= (window_end_x - window_step_x); x += window_step_x)
140           {
141             const uint8x16_t texels_u8 = vld1q_u8(input_ptr + x);
142
143             vst1q_s8(output_ptr + x,
144                      vreinterpretq_s8_u8(vandq_u8(texels_u8, vdupq_n_u8(true_val))));
145           }
146
147           // Compute left-over elements
148           for (; x < window_end_x; ++x)
149           {
150             *(output_ptr + x) = static_cast<int8_t>(*(input_ptr + x) & true_val);
151           }
152         },
153         input, output);
154       break;
155     }
156     case DataType::S16:
157     {
158       /* Up-conversion U8 -> S16 */
159       execute_window_loop(
160         win,
161         [&](const Coordinates &) {
162           const auto input_ptr = reinterpret_cast<const uint8_t *>(input.ptr());
163           const auto output_ptr = reinterpret_cast<int16_t *>(output.ptr());
164
165           int x = window_start_x;
166           for (; x <= (window_end_x - window_step_x); x += window_step_x)
167           {
168             const uint8x16_t texels_u8 = vld1q_u8(input_ptr + x);
169
170             const int16x8x2_t texels = {
171               {vreinterpretq_s16_u16(vmovl_u8(vand_u8(vget_low_u8(texels_u8), mask_bool))),
172                vreinterpretq_s16_u16(vmovl_u8(vand_u8(vget_high_u8(texels_u8), mask_bool)))}};
173
174             vst1q_s16(output_ptr + x, texels.val[0]);
175             vst1q_s16(output_ptr + x + 8, texels.val[1]);
176           }
177
178           // Compute left-over elements
179           for (; x < window_end_x; ++x)
180           {
181             *(output_ptr + x) = static_cast<int32_t>(*(input_ptr + x) & true_val);
182           }
183         },
184         input, output);
185       break;
186     }
187     case DataType::S32:
188     {
189       /* Up-conversion U8 -> S32 */
190       execute_window_loop(
191         win,
192         [&](const Coordinates &) {
193           const auto input_ptr = reinterpret_cast<const uint8_t *>(input.ptr());
194           const auto output_ptr = reinterpret_cast<int32_t *>(output.ptr());
195
196           int x = window_start_x;
197           for (; x <= (window_end_x - window_step_x); x += window_step_x)
198           {
199             const uint8x16_t texels_u8 = vld1q_u8(input_ptr + x);
200
201             const int16x8x2_t texels = {
202               {vreinterpretq_s16_u16(vmovl_u8(vand_u8(vget_low_u8(texels_u8), mask_bool))),
203                vreinterpretq_s16_u16(vmovl_u8(vand_u8(vget_high_u8(texels_u8), mask_bool)))}};
204
205             vst1q_s32(output_ptr + x, vmovl_s16(vget_low_s16(texels.val[0])));
206             vst1q_s32(output_ptr + x + 4, vmovl_s16(vget_high_s16(texels.val[0])));
207             vst1q_s32(output_ptr + x + 8, vmovl_s16(vget_low_s16(texels.val[1])));
208             vst1q_s32(output_ptr + x + 12, vmovl_s16(vget_high_s16(texels.val[1])));
209           }
210
211           // Compute left-over elements
212           for (; x < window_end_x; ++x)
213           {
214             *(output_ptr + x) = static_cast<uint32_t>(*(input_ptr + x) & true_val);
215           }
216         },
217         input, output);
218       break;
219     }
220     case DataType::F32:
221     {
222       /* Up-conversion U8 -> F32 */
223       execute_window_loop(
224         win,
225         [&](const Coordinates &) {
226           const auto input_ptr = reinterpret_cast<const uint8_t *>(input.ptr());
227           const auto output_ptr = reinterpret_cast<float *>(output.ptr());
228
229           int x = window_start_x;
230           for (; x <= (window_end_x - window_step_x); x += window_step_x)
231           {
232             const uint8x16_t texels_u8 = vld1q_u8(input_ptr + x);
233
234             const int16x8x2_t texels = {
235               {vreinterpretq_s16_u16(vmovl_u8(vand_u8(vget_low_u8(texels_u8), mask_bool))),
236                vreinterpretq_s16_u16(vmovl_u8(vand_u8(vget_high_u8(texels_u8), mask_bool)))}};
237             vst1q_f32(output_ptr + x, vcvtq_f32_s32(vmovl_s16(vget_low_s16(texels.val[0]))));
238             vst1q_f32(output_ptr + x + 4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(texels.val[0]))));
239             vst1q_f32(output_ptr + x + 8, vcvtq_f32_s32(vmovl_s16(vget_low_s16(texels.val[1]))));
240             vst1q_f32(output_ptr + x + 12, vcvtq_f32_s32(vmovl_s16(vget_high_s16(texels.val[1]))));
241           }
242
243           // Compute left-over elements
244           for (; x < window_end_x; ++x)
245           {
246             auto in = static_cast<uint32_t>(*(input_ptr + x) & true_val);
247             *(output_ptr + x) = static_cast<float>(in);
248           }
249         },
250         input, output);
251       break;
252     }
253 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
254     case DataType::F16:
255     {
256       /* Up-conversion U8 -> F16 */
257       execute_window_loop(
258         win,
259         [&](const Coordinates &) {
260           const auto input_ptr = reinterpret_cast<const uint8_t *>(input.ptr());
261           const auto output_ptr = reinterpret_cast<float16_t *>(output.ptr());
262
263           int x = window_start_x;
264           for (; x <= (window_end_x - window_step_x); x += window_step_x)
265           {
266             const uint8x16_t texels_u8 = vld1q_u8(input_ptr + x);
267
268             const int16x8x2_t texels = {
269               {vreinterpretq_s16_u16(vmovl_u8(vand_u8(vget_low_u8(texels_u8), mask_bool))),
270                vreinterpretq_s16_u16(vmovl_u8(vand_u8(vget_high_u8(texels_u8), mask_bool)))}};
271             vst1q_f16(output_ptr + x, vcvtq_f16_s16(texels.val[0]));
272             vst1q_f16(output_ptr + x + 8, vcvtq_f16_s16(texels.val[1]));
273           }
274
275           // Compute left-over elements
276           for (; x < window_end_x; ++x)
277           {
278             *(output_ptr + x) = static_cast<float16_t>(*(input_ptr + x) & true_val);
279           }
280         },
281         input, output);
282       break;
283     }
284 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
285     case DataType::U8:
286     {
287       /* Conversion U8 -> S8 */
288       execute_window_loop(
289         win,
290         [&](const Coordinates &) {
291           const auto input_ptr = reinterpret_cast<const uint8_t *>(input.ptr());
292           const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr());
293
294           int x = window_start_x;
295           for (; x <= (window_end_x - window_step_x); x += window_step_x)
296           {
297             const uint8x16_t texels_u8 = vld1q_u8(input_ptr + x);
298
299             vst1q_u8(output_ptr + x, vandq_u8(texels_u8, vdupq_n_u8(true_val)));
300           }
301
302           // Compute left-over elements
303           for (; x < window_end_x; ++x)
304           {
305             *(output_ptr + x) = static_cast<uint8_t>(*(input_ptr + x) & true_val);
306           }
307         },
308         input, output);
309       break;
310     }
311     case DataType::U16:
312     {
313       /* Up-conversion U8 -> U16 */
314       execute_window_loop(
315         win,
316         [&](const Coordinates &) {
317           const auto input_ptr = reinterpret_cast<const uint8_t *>(input.ptr());
318           const auto output_ptr = reinterpret_cast<uint16_t *>(output.ptr());
319
320           int x = window_start_x;
321           for (; x <= (window_end_x - window_step_x); x += window_step_x)
322           {
323             const uint8x16_t texels_u8 = vld1q_u8(input_ptr + x);
324
325             const uint16x8x2_t texels = {{vmovl_u8(vand_u8(vget_low_u8(texels_u8), mask_bool)),
326                                           vmovl_u8(vand_u8(vget_high_u8(texels_u8), mask_bool))}};
327
328             vst1q_u16(output_ptr + x, texels.val[0]);
329             vst1q_u16(output_ptr + x + 8, texels.val[1]);
330           }
331
332           // Compute left-over elements
333           for (; x < window_end_x; ++x)
334           {
335             *(output_ptr + x) = static_cast<uint16_t>(*(input_ptr + x) & true_val);
336           }
337         },
338         input, output);
339       break;
340     }
341     default:
342       ARM_COMPUTE_ERROR("Output data type not supported");
343   }
344 }