2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
18 * Copyright (c) 2016-2020 ARM Limited.
20 * SPDX-License-Identifier: MIT
22 * Permission is hereby granted, free of charge, to any person obtaining a copy
23 * of this software and associated documentation files (the "Software"), to
24 * deal in the Software without restriction, including without limitation the
25 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
26 * sell copies of the Software, and to permit persons to whom the Software is
27 * furnished to do so, subject to the following conditions:
29 * The above copyright notice and this permission notice shall be included in all
30 * copies or substantial portions of the Software.
32 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
33 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
34 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
35 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
36 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
37 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
40 #include "arm_compute/core/NEON/kernels/NECastBoolKernel.h"
42 #include "arm_compute/core/CPP/Validate.h"
43 #include "arm_compute/core/Error.h"
44 #include "arm_compute/core/Helpers.h"
45 #include "arm_compute/core/ITensor.h"
46 #include "arm_compute/core/NEON/NEMath.h"
47 #include "arm_compute/core/TensorInfo.h"
48 #include "arm_compute/core/Validate.h"
49 #include "arm_compute/core/utils/misc/SaturateCast.h"
51 #include "arm_compute/core/NEON/wrapper/wrapper.h"
53 using namespace arm_compute;
// Checks whether an (input, output) tensor-info pair is acceptable for the bool cast kernel.
// The ARM_COMPUTE_RETURN_ERROR_ON_* macros presumably return an error Status at the first
// failing check - confirm against arm_compute/core/Error.h and Validate.h.
// NOTE(review): this listing carries stray leading line numbers and appears to have lost lines
// (braces, the final return statement); verify against the upstream file before editing logic.
57 Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output)
// Presumably rejects an F16 output when the CPU build has no FP16 support - TODO confirm.
59 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(output);
// The same info object must not be passed as both source and destination.
60 ARM_COMPUTE_RETURN_ERROR_ON(input == output);
// Source: single-channel U8 only. Destination: single-channel, one of the types listed below.
61 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8);
62 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S8,
63 DataType::S16, DataType::U16, DataType::F16,
64 DataType::U32, DataType::S32, DataType::F32);
// NOTE(review): DataType::U32 is accepted here, but no conversion that writes uint32_t is
// visible in NECastBoolKernel::run() in this view - confirm that U32 output is really handled.
66 // Validate in case of configured output
// An output whose total_size() is 0 is treated as not configured yet, so the shape comparison
// is only performed for outputs that already have a size.
67 if (output->total_size() > 0)
69 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
// Default constructor: the kernel starts with no tensors attached (_input and _output are null).
76 NECastBoolKernel::NECastBoolKernel() : _input(nullptr), _output(nullptr) {}
// Prepares the kernel for a given input/output tensor pair: fills in the output shape if it is
// still empty, validates the pair with validate_arguments(), marks the whole output shape as the
// valid region and configures the execution window from the input tensor info.
// NOTE(review): lines are missing from this listing - the declaration of `coord` and, presumably,
// the assignments of _input/_output (run() requires both to be non-null) are not visible here.
78 void NECastBoolKernel::configure(const ITensor *input, ITensor *output)
80 ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
82 // Auto initialize output shape if not initialized (only the shape is taken from the input here;
// nothing in this call sets the output data type, so the caller has to provide it)
84 set_shape_if_empty(*output->info(), input->info()->tensor_shape());
// Validation runs after the shape has been filled in, so the shape comparison inside
// validate_arguments() sees the final output shape.
89 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info()));
91 // Configure kernel window
// Steps() is default-constructed; the per-row vector stepping is done manually inside run().
92 Window win = calculate_max_window(*input->info(), Steps());
// `coord` is given the output's number of dimensions and used as the anchor of a valid region
// spanning the full output tensor shape.
94 coord.set_num_dimensions(output->info()->num_dimensions());
95 output->info()->set_valid_region(ValidRegion(coord, output->info()->tensor_shape()));
97 ICPPKernel::configure(win);
// Validation entry point: forwards the tensor-info pair to validate_arguments().
// ARM_COMPUTE_RETURN_ON_ERROR presumably returns early with the failing Status - TODO confirm.
// NOTE(review): the statement returning the success Status is not visible in this listing.
100 Status NECastBoolKernel::validate(const ITensorInfo *input, const ITensorInfo *output)
102 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output));
// Executes the bool cast over `window`: every U8 input element is AND-ed with 1 and the result
// (0 or 1) is written to the output in the output tensor's data type.
//
// Each branch follows the same pattern: a vector loop handling 16 input elements per iteration
// (window_step_x), followed by a scalar loop for the remaining elements of the row.
//
// Because the value is masked with `true_val` (1), only the least significant bit of each input
// byte is kept: a non-zero input with bit 0 clear (e.g. 2) produces 0.
//
// `info` is unused.
//
// NOTE(review): this listing carries stray leading line numbers and has lost lines (braces, the
// declaration of `win`, the `case` labels and the calls that receive the lambdas). The branch
// comments and the element type of `output_ptr` are the only visible indication of which output
// type each lambda serves - verify against the upstream file.
// NOTE(review): no branch writing uint32_t is visible although validate_arguments() accepts
// DataType::U32 - confirm whether that type ends up in the "not supported" error below.
106 void NECastBoolKernel::run(const Window &window, const ThreadInfo &info)
108 ARM_COMPUTE_UNUSED(info);
109 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
110 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window);
111 ARM_COMPUTE_ERROR_ON_NULLPTR(_input, _output);
112 ARM_COMPUTE_ERROR_ON(_input == _output);
// The X range of the incoming window is walked manually inside each lambda.
114 const auto window_start_x = static_cast<int>(window.x().start());
115 const auto window_end_x = static_cast<int>(window.x().end());
116 const int window_step_x = 16;
// X is collapsed to a single step in `win`, so the iterators advance over the other dimensions
// only and each lambda invocation processes one full row.
119 win.set(Window::DimX, Window::Dimension(0, 1, 1));
121 Iterator input(_input, win);
122 Iterator output(_output, win);
// Mask applied to every input byte; mask_bool is the 8-lane form used by the widening branches.
124 const uint8_t true_val = 1;
125 const uint8x8_t mask_bool = vdup_n_u8(true_val);
127 switch (_output->info()->data_type())
131 /* Conversion U8 -> S8 */
134 [&](const Coordinates &) {
135 const auto input_ptr = reinterpret_cast<const uint8_t *>(input.ptr());
136 const auto output_ptr = reinterpret_cast<int8_t *>(output.ptr());
138 int x = window_start_x;
139 for (; x <= (window_end_x - window_step_x); x += window_step_x)
141 const uint8x16_t texels_u8 = vld1q_u8(input_ptr + x);
143 vst1q_s8(output_ptr + x,
144 vreinterpretq_s8_u8(vandq_u8(texels_u8, vdupq_n_u8(true_val))));
147 // Compute left-over elements
148 for (; x < window_end_x; ++x)
150 *(output_ptr + x) = static_cast<int8_t>(*(input_ptr + x) & true_val);
158 /* Up-conversion U8 -> S16 */
161 [&](const Coordinates &) {
162 const auto input_ptr = reinterpret_cast<const uint8_t *>(input.ptr());
163 const auto output_ptr = reinterpret_cast<int16_t *>(output.ptr());
165 int x = window_start_x;
166 for (; x <= (window_end_x - window_step_x); x += window_step_x)
168 const uint8x16_t texels_u8 = vld1q_u8(input_ptr + x);
// Mask each 8-lane half, widen it to 16 bits, then reinterpret as signed.
170 const int16x8x2_t texels = {
171 {vreinterpretq_s16_u16(vmovl_u8(vand_u8(vget_low_u8(texels_u8), mask_bool))),
172 vreinterpretq_s16_u16(vmovl_u8(vand_u8(vget_high_u8(texels_u8), mask_bool)))}};
174 vst1q_s16(output_ptr + x, texels.val[0]);
175 vst1q_s16(output_ptr + x + 8, texels.val[1]);
178 // Compute left-over elements
179 for (; x < window_end_x; ++x)
// Cast to the destination element type so the store needs no implicit narrowing conversion.
181 *(output_ptr + x) = static_cast<int16_t>(*(input_ptr + x) & true_val);
189 /* Up-conversion U8 -> S32 */
192 [&](const Coordinates &) {
193 const auto input_ptr = reinterpret_cast<const uint8_t *>(input.ptr());
194 const auto output_ptr = reinterpret_cast<int32_t *>(output.ptr());
196 int x = window_start_x;
197 for (; x <= (window_end_x - window_step_x); x += window_step_x)
199 const uint8x16_t texels_u8 = vld1q_u8(input_ptr + x);
// Widen in two stages: 8 -> 16 bits here, 16 -> 32 bits at each store (4 lanes per store).
201 const int16x8x2_t texels = {
202 {vreinterpretq_s16_u16(vmovl_u8(vand_u8(vget_low_u8(texels_u8), mask_bool))),
203 vreinterpretq_s16_u16(vmovl_u8(vand_u8(vget_high_u8(texels_u8), mask_bool)))}};
205 vst1q_s32(output_ptr + x, vmovl_s16(vget_low_s16(texels.val[0])));
206 vst1q_s32(output_ptr + x + 4, vmovl_s16(vget_high_s16(texels.val[0])));
207 vst1q_s32(output_ptr + x + 8, vmovl_s16(vget_low_s16(texels.val[1])));
208 vst1q_s32(output_ptr + x + 12, vmovl_s16(vget_high_s16(texels.val[1])));
211 // Compute left-over elements
212 for (; x < window_end_x; ++x)
// Cast to the destination element type so the store needs no implicit sign conversion.
214 *(output_ptr + x) = static_cast<int32_t>(*(input_ptr + x) & true_val);
222 /* Up-conversion U8 -> F32 */
225 [&](const Coordinates &) {
226 const auto input_ptr = reinterpret_cast<const uint8_t *>(input.ptr());
227 const auto output_ptr = reinterpret_cast<float *>(output.ptr());
229 int x = window_start_x;
230 for (; x <= (window_end_x - window_step_x); x += window_step_x)
232 const uint8x16_t texels_u8 = vld1q_u8(input_ptr + x);
// Same two-stage widening as the S32 branch, followed by an integer-to-float conversion.
234 const int16x8x2_t texels = {
235 {vreinterpretq_s16_u16(vmovl_u8(vand_u8(vget_low_u8(texels_u8), mask_bool))),
236 vreinterpretq_s16_u16(vmovl_u8(vand_u8(vget_high_u8(texels_u8), mask_bool)))}};
237 vst1q_f32(output_ptr + x, vcvtq_f32_s32(vmovl_s16(vget_low_s16(texels.val[0]))));
238 vst1q_f32(output_ptr + x + 4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(texels.val[0]))));
239 vst1q_f32(output_ptr + x + 8, vcvtq_f32_s32(vmovl_s16(vget_low_s16(texels.val[1]))));
240 vst1q_f32(output_ptr + x + 12, vcvtq_f32_s32(vmovl_s16(vget_high_s16(texels.val[1]))));
243 // Compute left-over elements
244 for (; x < window_end_x; ++x)
246 auto in = static_cast<uint32_t>(*(input_ptr + x) & true_val);
247 *(output_ptr + x) = static_cast<float>(in);
// The F16 branch is only compiled when FP16 vector arithmetic is available.
253 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
256 /* Up-conversion U8 -> F16 */
259 [&](const Coordinates &) {
260 const auto input_ptr = reinterpret_cast<const uint8_t *>(input.ptr());
261 const auto output_ptr = reinterpret_cast<float16_t *>(output.ptr());
263 int x = window_start_x;
264 for (; x <= (window_end_x - window_step_x); x += window_step_x)
266 const uint8x16_t texels_u8 = vld1q_u8(input_ptr + x);
268 const int16x8x2_t texels = {
269 {vreinterpretq_s16_u16(vmovl_u8(vand_u8(vget_low_u8(texels_u8), mask_bool))),
270 vreinterpretq_s16_u16(vmovl_u8(vand_u8(vget_high_u8(texels_u8), mask_bool)))}};
271 vst1q_f16(output_ptr + x, vcvtq_f16_s16(texels.val[0]));
272 vst1q_f16(output_ptr + x + 8, vcvtq_f16_s16(texels.val[1]));
275 // Compute left-over elements
276 for (; x < window_end_x; ++x)
278 *(output_ptr + x) = static_cast<float16_t>(*(input_ptr + x) & true_val);
284 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
287 /* Conversion U8 -> U8 */
290 [&](const Coordinates &) {
291 const auto input_ptr = reinterpret_cast<const uint8_t *>(input.ptr());
292 const auto output_ptr = reinterpret_cast<uint8_t *>(output.ptr());
294 int x = window_start_x;
295 for (; x <= (window_end_x - window_step_x); x += window_step_x)
297 const uint8x16_t texels_u8 = vld1q_u8(input_ptr + x);
299 vst1q_u8(output_ptr + x, vandq_u8(texels_u8, vdupq_n_u8(true_val)));
302 // Compute left-over elements
303 for (; x < window_end_x; ++x)
305 *(output_ptr + x) = static_cast<uint8_t>(*(input_ptr + x) & true_val);
313 /* Up-conversion U8 -> U16 */
316 [&](const Coordinates &) {
317 const auto input_ptr = reinterpret_cast<const uint8_t *>(input.ptr());
318 const auto output_ptr = reinterpret_cast<uint16_t *>(output.ptr());
320 int x = window_start_x;
321 for (; x <= (window_end_x - window_step_x); x += window_step_x)
323 const uint8x16_t texels_u8 = vld1q_u8(input_ptr + x);
325 const uint16x8x2_t texels = {{vmovl_u8(vand_u8(vget_low_u8(texels_u8), mask_bool)),
326 vmovl_u8(vand_u8(vget_high_u8(texels_u8), mask_bool))}};
328 vst1q_u16(output_ptr + x, texels.val[0]);
329 vst1q_u16(output_ptr + x + 8, texels.val[1]);
332 // Compute left-over elements
333 for (; x < window_end_x; ++x)
335 *(output_ptr + x) = static_cast<uint16_t>(*(input_ptr + x) & true_val);
// Reached for output data types that none of the branches above handle.
342 ARM_COMPUTE_ERROR("Output data type not supported");