2 * Copyright (c) 2016, 2017 ARM Limited.
4 * SPDX-License-Identifier: MIT
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 #include "arm_compute/core/NEON/kernels/NEArithmeticSubtractionKernel.h"
26 #include "arm_compute/core/Error.h"
27 #include "arm_compute/core/Helpers.h"
28 #include "arm_compute/core/ITensor.h"
29 #include "arm_compute/core/TensorInfo.h"
30 #include "arm_compute/core/Validate.h"
37 using namespace arm_compute;
42 } // namespace arm_compute
46 void sub_wrap_U8_U8_U8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
48 Iterator input1(in1, window);
49 Iterator input2(in2, window);
50 Iterator output(out, window);
52 execute_window_loop(window, [&](const Coordinates & id)
54 const uint8x16_t ta1 = vld1q_u8(input1.ptr());
55 const uint8x16_t ta2 = vld1q_u8(input2.ptr());
57 vst1q_u8(output.ptr(), vsubq_u8(ta1, ta2));
59 input1, input2, output);
62 void sub_saturate_U8_U8_U8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
64 Iterator input1(in1, window);
65 Iterator input2(in2, window);
66 Iterator output(out, window);
68 execute_window_loop(window, [&](const Coordinates & id)
70 const uint8x16_t ta1 = vld1q_u8(input1.ptr());
71 const uint8x16_t ta2 = vld1q_u8(input2.ptr());
73 vst1q_u8(output.ptr(), vqsubq_u8(ta1, ta2));
75 input1, input2, output);
78 void sub_wrap_S16_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
80 Iterator input1(in1, window);
81 Iterator input2(in2, window);
82 Iterator output(out, window);
84 execute_window_loop(window, [&](const Coordinates & id)
86 const int16x8x2_t ta1 = vld2q_s16(reinterpret_cast<const int16_t *>(input1.ptr()));
87 const int16x8x2_t ta2 = vld2q_s16(reinterpret_cast<const int16_t *>(input2.ptr()));
89 const int16x8x2_t ta3 =
92 vsubq_s16(ta1.val[0], ta2.val[0]),
93 vsubq_s16(ta1.val[1], ta2.val[1])
97 vst2q_s16(reinterpret_cast<int16_t *>(output.ptr()), ta3);
99 input1, input2, output);
102 void sub_saturate_S16_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
104 Iterator input1(in1, window);
105 Iterator input2(in2, window);
106 Iterator output(out, window);
108 execute_window_loop(window, [&](const Coordinates & id)
110 const int16x8x2_t ta1 = vld2q_s16(reinterpret_cast<const int16_t *>(input1.ptr()));
111 const int16x8x2_t ta2 = vld2q_s16(reinterpret_cast<const int16_t *>(input2.ptr()));
113 const int16x8x2_t ta3 =
116 vqsubq_s16(ta1.val[0], ta2.val[0]),
117 vqsubq_s16(ta1.val[1], ta2.val[1])
121 vst2q_s16(reinterpret_cast<int16_t *>(output.ptr()), ta3);
123 input1, input2, output);
126 void sub_F32_F32_F32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
128 Iterator input1(in1, window);
129 Iterator input2(in2, window);
130 Iterator output(out, window);
132 execute_window_loop(window, [&](const Coordinates & id)
134 const float32x4x4_t ta1 = vld4q_f32(reinterpret_cast<const float *>(input1.ptr()));
135 const float32x4x4_t ta2 = vld4q_f32(reinterpret_cast<const float *>(input2.ptr()));
137 const float32x4x4_t ta3 =
140 vsubq_f32(ta1.val[0], ta2.val[0]),
141 vsubq_f32(ta1.val[1], ta2.val[1]),
142 vsubq_f32(ta1.val[2], ta2.val[2]),
143 vsubq_f32(ta1.val[3], ta2.val[3]),
147 vst4q_f32(reinterpret_cast<float *>(output.ptr()), ta3);
149 input1, input2, output);
151 void sub_wrap_S16_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
153 Iterator input1(in1, window);
154 Iterator input2(in2, window);
155 Iterator output(out, window);
157 execute_window_loop(window, [&](const Coordinates & id)
159 const uint8x16_t bv_0 = vld1q_u8(input2.ptr());
160 int16x8_t a1_0 = vld1q_s16(reinterpret_cast<const int16_t *>(input1.ptr()));
161 int16x8_t a2_0 = vld1q_s16(reinterpret_cast<const int16_t *>(input1.ptr()) + 8);
163 a1_0 = vsubq_s16(a1_0, vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(bv_0))));
164 a2_0 = vsubq_s16(a2_0, vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(bv_0))));
166 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()), a1_0);
167 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()) + 8, a2_0);
169 input1, input2, output);
172 void sub_saturate_S16_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
174 Iterator input1(in1, window);
175 Iterator input2(in2, window);
176 Iterator output(out, window);
178 execute_window_loop(window, [&](const Coordinates & id)
180 const uint8x16_t bv_0 = vld1q_u8(input2.ptr());
181 int16x8_t a1_0 = vld1q_s16(reinterpret_cast<const int16_t *>(input1.ptr()));
182 int16x8_t a2_0 = vld1q_s16(reinterpret_cast<const int16_t *>(input1.ptr()) + 8);
184 a1_0 = vqsubq_s16(a1_0, vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(bv_0))));
185 a2_0 = vqsubq_s16(a2_0, vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(bv_0))));
187 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()), a1_0);
188 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()) + 8, a2_0);
190 input1, input2, output);
193 void sub_wrap_U8_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
195 Iterator input1(in1, window);
196 Iterator input2(in2, window);
197 Iterator output(out, window);
199 execute_window_loop(window, [&](const Coordinates & id)
201 const uint8x16_t bv_0 = vld1q_u8(input1.ptr());
202 int16x8_t a1_0 = vld1q_s16(reinterpret_cast<const int16_t *>(input2.ptr()));
203 int16x8_t a2_0 = vld1q_s16(reinterpret_cast<const int16_t *>(input2.ptr()) + 8);
205 a1_0 = vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(bv_0))), a1_0);
206 a2_0 = vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(bv_0))), a2_0);
208 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()), a1_0);
209 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()) + 8, a2_0);
211 input1, input2, output);
214 void sub_saturate_U8_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
216 Iterator input1(in1, window);
217 Iterator input2(in2, window);
218 Iterator output(out, window);
220 execute_window_loop(window, [&](const Coordinates & id)
222 const uint8x16_t bv_0 = vld1q_u8(input1.ptr());
223 int16x8_t a1_0 = vld1q_s16(reinterpret_cast<const int16_t *>(input2.ptr()));
224 int16x8_t a2_0 = vld1q_s16(reinterpret_cast<const int16_t *>(input2.ptr()) + 8);
226 a1_0 = vqsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(bv_0))), a1_0);
227 a2_0 = vqsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(bv_0))), a2_0);
229 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()), a1_0);
230 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()) + 8, a2_0);
232 input1, input2, output);
235 void sub_wrap_U8_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
237 Iterator input1(in1, window);
238 Iterator input2(in2, window);
239 Iterator output(out, window);
241 execute_window_loop(window, [&](const Coordinates & id)
243 const uint8x16_t av_0 = vld1q_u8(input1.ptr());
244 const uint8x16_t bv_0 = vld1q_u8(input2.ptr());
246 const int16x8_t a1_0 = vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(av_0))),
247 vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(bv_0))));
248 const int16x8_t a2_0 = vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(av_0))),
249 vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(bv_0))));
251 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()), a1_0);
252 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()) + 8, a2_0);
254 input1, input2, output);
257 void sub_saturate_U8_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
259 Iterator input1(in1, window);
260 Iterator input2(in2, window);
261 Iterator output(out, window);
263 execute_window_loop(window, [&](const Coordinates & id)
265 const uint8x16_t av_0 = vld1q_u8(input1.ptr());
266 const uint8x16_t bv_0 = vld1q_u8(input2.ptr());
268 const int16x8_t a1_0 = vqsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(av_0))),
269 vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(bv_0))));
270 const int16x8_t a2_0 = vqsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(av_0))),
271 vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(bv_0))));
273 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()), a1_0);
274 vst1q_s16(reinterpret_cast<int16_t *>(output.ptr()) + 8, a2_0);
276 input1, input2, output);
280 NEArithmeticSubtractionKernel::NEArithmeticSubtractionKernel()
281 : _func(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr)
285 void NEArithmeticSubtractionKernel::configure(const ITensor *input1, const ITensor *input2, ITensor *output, ConvertPolicy policy)
287 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::S16, DataType::F32);
288 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::S16, DataType::F32);
289 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S16, DataType::F32);
291 /* If one of the inputs is 16bit then the output must be 16bit too: */
292 ARM_COMPUTE_ERROR_ON_MSG(output->info()->data_type() == DataType::U8 && (input1->info()->data_type() != DataType::U8 || input2->info()->data_type() != DataType::U8),
293 "Output can only be U8 if both inputs are U8");
295 static std::map<std::string, SubFunction *> map_function =
297 { "sub_wrap_U8_U8_U8", &sub_wrap_U8_U8_U8 },
298 { "sub_wrap_U8_U8_S16", &sub_wrap_U8_U8_S16 },
299 { "sub_saturate_U8_U8_U8", &sub_saturate_U8_U8_U8 },
300 { "sub_saturate_U8_U8_S16", &sub_saturate_U8_U8_S16 },
301 { "sub_wrap_U8_S16_S16", &sub_wrap_U8_S16_S16 },
302 { "sub_wrap_S16_U8_S16", &sub_wrap_S16_U8_S16 },
303 { "sub_saturate_U8_S16_S16", &sub_saturate_U8_S16_S16 },
304 { "sub_saturate_S16_U8_S16", &sub_saturate_S16_U8_S16 },
305 { "sub_wrap_S16_S16_S16", &sub_wrap_S16_S16_S16 },
306 { "sub_saturate_S16_S16_S16", &sub_saturate_S16_S16_S16 },
307 { "sub_wrap_F32_F32_F32", &sub_F32_F32_F32 },
308 { "sub_saturate_F32_F32_F32", &sub_F32_F32_F32 },
315 std::string function_to_call("sub_");
316 function_to_call += policy == ConvertPolicy::WRAP ? "wrap_" : "saturate_";
317 function_to_call += string_from_data_type(input1->info()->data_type()) + "_";
318 function_to_call += string_from_data_type(input2->info()->data_type()) + "_";
319 function_to_call += string_from_data_type(output->info()->data_type());
321 auto it = map_function.find(function_to_call);
323 if(it != map_function.end())
329 ARM_COMPUTE_ERROR("You called subtract with the wrong image formats");
332 constexpr unsigned int num_elems_processed_per_iteration = 16;
334 // Configure kernel window
335 Window win = calculate_max_window(*input1->info(), Steps(num_elems_processed_per_iteration));
336 AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration);
338 update_window_and_padding(win,
339 AccessWindowHorizontal(input1->info(), 0, num_elems_processed_per_iteration),
340 AccessWindowHorizontal(input2->info(), 0, num_elems_processed_per_iteration),
343 ValidRegion valid_region = intersect_valid_regions(input1->info()->valid_region(),
344 input2->info()->valid_region());
346 output_access.set_valid_region(win, valid_region);
348 INEKernel::configure(win);
351 void NEArithmeticSubtractionKernel::run(const Window &window)
353 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
354 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
355 ARM_COMPUTE_ERROR_ON(_func == nullptr);
357 (*_func)(_input1, _input2, _output, window);