arm_compute v17.03.1
[platform/upstream/armcl.git] / src / core / NEON / kernels / NENonMaximaSuppression3x3Kernel.cpp
1 /*
2  * Copyright (c) 2016, 2017 ARM Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "arm_compute/core/NEON/kernels/NENonMaximaSuppression3x3Kernel.h"
25
26 #include "arm_compute/core/AccessWindowAutoPadding.h"
27 #include "arm_compute/core/Error.h"
28 #include "arm_compute/core/Helpers.h"
29 #include "arm_compute/core/ITensor.h"
30 #include "arm_compute/core/TensorInfo.h"
31 #include "arm_compute/core/Types.h"
32 #include "arm_compute/core/Utils.h"
33 #include "arm_compute/core/Validate.h"
34
35 #include <arm_neon.h>
36 #include <cstddef>
37
38 using namespace arm_compute;
39
40 namespace arm_compute
41 {
42 class Coordinates;
43 } // namespace arm_compute
44
45 #ifdef ARM_COMPUTE_ENABLE_FP16
46 namespace fp16
47 {
48 inline void mask_top(const float16x8_t &vc, const float16x8_t &in0, const float16x8_t &in1, uint16x8_t &mask)
49 {
50     // vc > nc.val[0], vc > nc.val[1], vc > nc.val[2]
51     mask = vandq_u16(mask, vcgeq_f16(vc, in0));
52     mask = vandq_u16(mask, vcgeq_f16(vc, vextq_f16(in0, in1, 1)));
53     mask = vandq_u16(mask, vcgeq_f16(vc, vextq_f16(in0, in1, 2)));
54 }
55
56 inline void mask_middle(const float16x8_t &vc, const float16x8_t &in0, const float16x8_t &in1, uint16x8_t &mask)
57 {
58     // vc >= nc.val[0], vc > nc.val[2]
59     mask = vandq_u16(mask, vcgeq_f16(vc, in0));
60     mask = vandq_u16(mask, vcgtq_f16(vc, vextq_f16(in0, in1, 2)));
61 }
62
63 inline void mask_bottom(const float16x8_t &vc, const float16x8_t &in0, const float16x8_t &in1, uint16x8_t &mask)
64 {
65     // vc > nc.val[0], vc > nc.val[1], vc > nc.val[2]
66     mask = vandq_u16(mask, vcgtq_f16(vc, in0));
67     mask = vandq_u16(mask, vcgtq_f16(vc, vextq_f16(in0, in1, 1)));
68     mask = vandq_u16(mask, vcgtq_f16(vc, vextq_f16(in0, in1, 2)));
69 }
70
71 inline void non_maxima_suppression3x3_F32_F32(const void *__restrict in_ptr, void *__restrict out_ptr, const uint32_t in_stride)
72 {
73     auto       in  = static_cast<const float *__restrict>(in_ptr) - 1;
74     const auto out = static_cast<float *__restrict>(out_ptr);
75
76     // Get centre scores
77     const float16x8x2_t vc =
78     {
79         vcombine_f16(vcvt_f16_f32(vld1q_f32(in + 1)), vcvt_f16_f32(vld1q_f32(in + 5))),
80         vcombine_f16(vcvt_f16_f32(vld1q_f32(in + 9)), vcvt_f16_f32(vld1q_f32(in + 13)))
81     };
82
83     // Neighboring pixels
84     in -= in_stride;
85
86     static const float16x4_t  zero_f16x4 = vdup_n_f16(0);
87     static const uint16x8_t   zero_u16   = vdupq_n_u16(0);
88     static const uint16x8_t   true_mask  = vceqq_u16(zero_u16, zero_u16);
89     static const uint16x8x2_t true_mask_x2 =
90     {
91         true_mask,
92         true_mask
93     };
94
95     uint16x8x2_t mask = true_mask_x2;
96
97     // Top row
98     const float16x8_t tmp_top0 = vcombine_f16(vcvt_f16_f32(vld1q_f32(in)), vcvt_f16_f32(vld1q_f32(in + 4)));
99     const float16x8_t tmp_top1 = vcombine_f16(vcvt_f16_f32(vld1q_f32(in + 8)), vcvt_f16_f32(vld1q_f32(in + 12)));
100     const float16x8_t tmp_top2 = vcombine_f16(vcvt_f16_f32(vld1q_f32(in + 16)), zero_f16x4);
101
102     // vc >= nc.val[0], vc >= nc.val[1], vc >= nc.val[2]
103     mask_top(vc.val[0], tmp_top0, tmp_top1, mask.val[0]);
104     mask_top(vc.val[1], tmp_top1, tmp_top2, mask.val[1]);
105
106     in += in_stride;
107
108     // Middle row
109     const float16x8_t tmp_mid0 = vcombine_f16(vcvt_f16_f32(vld1q_f32(in)), vcvt_f16_f32(vld1q_f32(in + 4)));
110     const float16x8_t tmp_mid1 = vcombine_f16(vcvt_f16_f32(vld1q_f32(in + 8)), vcvt_f16_f32(vld1q_f32(in + 12)));
111     const float16x8_t tmp_mid2 = vcombine_f16(vcvt_f16_f32(vld1q_f32(in + 16)), zero_f16x4);
112
113     // vc >= nc.val[0], vc > nc.val[2]
114     mask_middle(vc.val[0], tmp_mid0, tmp_mid1, mask.val[0]);
115     mask_middle(vc.val[1], tmp_mid1, tmp_mid2, mask.val[1]);
116
117     in += in_stride;
118
119     // Bottom row
120     const float16x8_t tmp_bot0 = vcombine_f16(vcvt_f16_f32(vld1q_f32(in)), vcvt_f16_f32(vld1q_f32(in + 4)));
121     const float16x8_t tmp_bot1 = vcombine_f16(vcvt_f16_f32(vld1q_f32(in + 8)), vcvt_f16_f32(vld1q_f32(in + 12)));
122     const float16x8_t tmp_bot2 = vcombine_f16(vcvt_f16_f32(vld1q_f32(in + 16)), zero_f16x4);
123
124     // vc > nc.val[0], vc > nc.val[1], vc > nc.val[2]
125     mask_bottom(vc.val[0], tmp_bot0, tmp_bot1, mask.val[0]);
126     mask_bottom(vc.val[1], tmp_bot1, tmp_bot2, mask.val[1]);
127
128     // Store
129     static const float16x8_t zero_f16x8 = vdupq_n_f16(0);
130
131     const float16x8_t suppressed0 = vbslq_f16(mask.val[0], vc.val[0], zero_f16x8);
132     vst1q_f32(out + 0, vcvt_f32_f16(vget_low_f16(suppressed0)));
133     vst1q_f32(out + 4, vcvt_f32_f16(vget_high_f16(suppressed0)));
134
135     const float16x8_t suppressed1 = vbslq_f16(mask.val[1], vc.val[1], zero_f16x8);
136     vst1q_f32(out + 8, vcvt_f32_f16(vget_low_f16(suppressed1)));
137     vst1q_f32(out + 12, vcvt_f32_f16(vget_high_f16(suppressed1)));
138 }
139
140 inline void non_maxima_suppression3x3_U8_U8(const void *__restrict in_ptr, void *__restrict out_ptr, const uint32_t in_stride)
141 {
142     auto       in  = static_cast<const uint8_t *__restrict>(in_ptr) - 1;
143     const auto out = static_cast<uint8_t *__restrict>(out_ptr);
144
145     // Get centre scores
146     const uint8x16_t vc = vld1q_u8(in + 1);
147
148     // Neighboring pixels
149     in -= in_stride;
150
151     // Top row
152     const uint8x16_t l_nc_0 = vld1q_u8(in);
153     const uint8x16_t m_nc_0 = vld1q_u8(in + 1);
154     const uint8x16_t r_nc_0 = vld1q_u8(in + 2);
155
156     // Keep center scores if ...
157     // vc >= l_nc_0, vc >= m_nc_0, vc >= r_nc_0
158     uint8x16_t mask = vcgeq_u8(vc, l_nc_0);
159     mask            = vandq_u8(mask, vcgeq_u8(vc, m_nc_0));
160     mask            = vandq_u8(mask, vcgeq_u8(vc, r_nc_0));
161
162     in += in_stride;
163
164     // Middle row
165     const uint8x16_t l_nc_1 = vld1q_u8(in);
166     const uint8x16_t r_nc_1 = vld1q_u8(in + 2);
167
168     // ... and ...
169     // vc >= l_nc_1, vc > r_nc_1
170     mask = vandq_u8(mask, vcgeq_u8(vc, l_nc_1));
171     mask = vandq_u8(mask, vcgtq_u8(vc, r_nc_1));
172
173     in += in_stride;
174
175     // Bottom row
176     const uint8x16_t l_nc_2 = vld1q_u8(in);
177     const uint8x16_t m_nc_2 = vld1q_u8(in + 1);
178     const uint8x16_t r_nc_2 = vld1q_u8(in + 2);
179
180     // ... and ...
181     // vc > l_nc_2, vc > m_nc_2, vc > r_nc_2
182     mask = vandq_u8(mask, vcgtq_u8(vc, l_nc_2));
183     mask = vandq_u8(mask, vcgtq_u8(vc, m_nc_2));
184     mask = vandq_u8(mask, vcgtq_u8(vc, r_nc_2));
185
186     // Store
187     static const uint8x16_t zero = vdupq_n_u8(0);
188     vst1q_u8(out, vbslq_u8(mask, vc, zero));
189 }
190 } // namespace fp16
191
192 void NENonMaximaSuppression3x3FP16Kernel::configure(const ITensor *input, ITensor *output, bool border_undefined)
193 {
194     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::F32);
195     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::F32);
196     ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
197
198     _input  = input;
199     _output = output;
200
201     switch(input->info()->data_type())
202     {
203         case DataType::U8:
204             _func = &fp16::non_maxima_suppression3x3_U8_U8;
205             break;
206         default:
207             _func = &fp16::non_maxima_suppression3x3_F32_F32;
208             break;
209     }
210
211     const unsigned int processed_elements = 16;
212
213     // Configure kernel window
214     Window                  win = calculate_max_window(*input->info(), Steps(processed_elements), border_undefined, border_size());
215     AccessWindowAutoPadding output_access(output->info());
216
217     update_window_and_padding(win,
218                               AccessWindowAutoPadding(input->info()),
219                               output_access);
220
221     output_access.set_valid_region();
222
223     INEKernel::configure(win);
224 }
225 #endif
226
227 namespace
228 {
229 inline void non_maxima_suppression3x3_FLOAT_FLOAT(const void *__restrict input_ptr, void *__restrict output_ptr, const uint32_t input_stride)
230 {
231     auto       input  = static_cast<const float *__restrict>(input_ptr) - 1;
232     const auto output = static_cast<float *__restrict>(output_ptr);
233
234     /* Get centre scores */
235     const float32x4x4_t vc =
236     {
237         {
238             vld1q_f32(input + 1),
239             vld1q_f32(input + 5),
240             vld1q_f32(input + 9),
241             vld1q_f32(input + 13)
242         }
243     };
244
245     /* Neighboring pixels */
246     float32x4x4_t l_nc{ {} };
247     float32x4x4_t m_nc{ {} };
248     float32x4x4_t r_nc{ {} };
249
250     input -= input_stride;
251
252     /* Row0 - Low part */
253     float32x4_t tmp_low   = vld1q_f32(input);
254     float32x4_t tmp_high  = vld1q_f32(input + 4);
255     float32x4_t tmp_high1 = vld1q_f32(input + 8);
256
257     l_nc.val[0] = tmp_low;
258     m_nc.val[0] = vextq_f32(tmp_low, tmp_high, 1);
259     r_nc.val[0] = vextq_f32(tmp_low, tmp_high, 2);
260
261     tmp_low  = tmp_high;
262     tmp_high = tmp_high1;
263
264     l_nc.val[1] = tmp_low;
265     m_nc.val[1] = vextq_f32(tmp_low, tmp_high, 1);
266     r_nc.val[1] = vextq_f32(tmp_low, tmp_high, 2);
267
268     /* Row0 - High part */
269     tmp_low   = tmp_high1;
270     tmp_high  = vld1q_f32(input + 12);
271     tmp_high1 = vld1q_f32(input + 16);
272
273     l_nc.val[2] = tmp_low;
274     m_nc.val[2] = vextq_f32(tmp_low, tmp_high, 1);
275     r_nc.val[2] = vextq_f32(tmp_low, tmp_high, 2);
276
277     tmp_low  = tmp_high;
278     tmp_high = tmp_high1;
279
280     l_nc.val[3] = tmp_low;
281     m_nc.val[3] = vextq_f32(tmp_low, tmp_high, 1);
282     r_nc.val[3] = vextq_f32(tmp_low, tmp_high, 2);
283
284     /* mc >= nc.val[0], mc >= nc.val[1], mc >= nc.val[2] */
285     uint32x4x4_t mask{ {} };
286     mask.val[0] = vcgeq_f32(vc.val[0], l_nc.val[0]);
287     mask.val[0] = vandq_u32(mask.val[0], vcgeq_f32(vc.val[0], m_nc.val[0]));
288     mask.val[0] = vandq_u32(mask.val[0], vcgeq_f32(vc.val[0], r_nc.val[0]));
289     mask.val[1] = vcgeq_f32(vc.val[1], l_nc.val[1]);
290     mask.val[1] = vandq_u32(mask.val[1], vcgeq_f32(vc.val[1], m_nc.val[1]));
291     mask.val[1] = vandq_u32(mask.val[1], vcgeq_f32(vc.val[1], r_nc.val[1]));
292     mask.val[2] = vcgeq_f32(vc.val[2], l_nc.val[2]);
293     mask.val[2] = vandq_u32(mask.val[2], vcgeq_f32(vc.val[2], m_nc.val[2]));
294     mask.val[2] = vandq_u32(mask.val[2], vcgeq_f32(vc.val[2], r_nc.val[2]));
295     mask.val[3] = vcgeq_f32(vc.val[3], l_nc.val[3]);
296     mask.val[3] = vandq_u32(mask.val[3], vcgeq_f32(vc.val[3], m_nc.val[3]));
297     mask.val[3] = vandq_u32(mask.val[3], vcgeq_f32(vc.val[3], r_nc.val[3]));
298
299     input += input_stride;
300
301     /* Row1 - Low part */
302     tmp_low   = vld1q_f32(input);
303     tmp_high  = vld1q_f32(input + 4);
304     tmp_high1 = vld1q_f32(input + 8);
305
306     l_nc.val[0] = tmp_low;
307     r_nc.val[0] = vextq_f32(tmp_low, tmp_high, 2);
308
309     tmp_low  = tmp_high;
310     tmp_high = tmp_high1;
311
312     l_nc.val[1] = tmp_low;
313     r_nc.val[1] = vextq_f32(tmp_low, tmp_high, 2);
314
315     /* Row1 - High part */
316     tmp_low   = tmp_high1;
317     tmp_high  = vld1q_f32(input + 12);
318     tmp_high1 = vld1q_f32(input + 16);
319
320     l_nc.val[2] = tmp_low;
321     r_nc.val[2] = vextq_f32(tmp_low, tmp_high, 2);
322
323     tmp_low  = tmp_high;
324     tmp_high = tmp_high1;
325
326     l_nc.val[3] = tmp_low;
327     r_nc.val[3] = vextq_f32(tmp_low, tmp_high, 2);
328
329     /* mc >= nc.val[0], mc > nc.val[2] */
330     mask.val[0] = vandq_u32(mask.val[0], vcgeq_f32(vc.val[0], l_nc.val[0]));
331     mask.val[0] = vandq_u32(mask.val[0], vcgtq_f32(vc.val[0], r_nc.val[0]));
332     mask.val[1] = vandq_u32(mask.val[1], vcgeq_f32(vc.val[1], l_nc.val[1]));
333     mask.val[1] = vandq_u32(mask.val[1], vcgtq_f32(vc.val[1], r_nc.val[1]));
334     mask.val[2] = vandq_u32(mask.val[2], vcgeq_f32(vc.val[2], l_nc.val[2]));
335     mask.val[2] = vandq_u32(mask.val[2], vcgtq_f32(vc.val[2], r_nc.val[2]));
336     mask.val[3] = vandq_u32(mask.val[3], vcgeq_f32(vc.val[3], l_nc.val[3]));
337     mask.val[3] = vandq_u32(mask.val[3], vcgtq_f32(vc.val[3], r_nc.val[3]));
338
339     input += input_stride;
340
341     /* Row2 - Low part */
342     tmp_low   = vld1q_f32(input);
343     tmp_high  = vld1q_f32(input + 4);
344     tmp_high1 = vld1q_f32(input + 8);
345
346     l_nc.val[0] = tmp_low;
347     m_nc.val[0] = vextq_f32(tmp_low, tmp_high, 1);
348     r_nc.val[0] = vextq_f32(tmp_low, tmp_high, 2);
349
350     tmp_low  = tmp_high;
351     tmp_high = tmp_high1;
352
353     l_nc.val[1] = tmp_low;
354     m_nc.val[1] = vextq_f32(tmp_low, tmp_high, 1);
355     r_nc.val[1] = vextq_f32(tmp_low, tmp_high, 2);
356
357     /* Row2 - High part */
358     tmp_low   = tmp_high1;
359     tmp_high  = vld1q_f32(input + 12);
360     tmp_high1 = vld1q_f32(input + 16);
361
362     l_nc.val[2] = tmp_low;
363     m_nc.val[2] = vextq_f32(tmp_low, tmp_high, 1);
364     r_nc.val[2] = vextq_f32(tmp_low, tmp_high, 2);
365
366     tmp_low  = tmp_high;
367     tmp_high = tmp_high1;
368
369     l_nc.val[3] = tmp_low;
370     m_nc.val[3] = vextq_f32(tmp_low, tmp_high, 1);
371     r_nc.val[3] = vextq_f32(tmp_low, tmp_high, 2);
372
373     /* mc > nc.val[0], mc > nc.val[1], mc > nc.val[2] */
374     mask.val[0] = vandq_u32(mask.val[0], vcgtq_f32(vc.val[0], l_nc.val[0]));
375     mask.val[0] = vandq_u32(mask.val[0], vcgtq_f32(vc.val[0], m_nc.val[0]));
376     mask.val[0] = vandq_u32(mask.val[0], vcgtq_f32(vc.val[0], r_nc.val[0]));
377     mask.val[1] = vandq_u32(mask.val[1], vcgtq_f32(vc.val[1], l_nc.val[1]));
378     mask.val[1] = vandq_u32(mask.val[1], vcgtq_f32(vc.val[1], m_nc.val[1]));
379     mask.val[1] = vandq_u32(mask.val[1], vcgtq_f32(vc.val[1], r_nc.val[1]));
380     mask.val[2] = vandq_u32(mask.val[2], vcgtq_f32(vc.val[2], l_nc.val[2]));
381     mask.val[2] = vandq_u32(mask.val[2], vcgtq_f32(vc.val[2], m_nc.val[2]));
382     mask.val[2] = vandq_u32(mask.val[2], vcgtq_f32(vc.val[2], r_nc.val[2]));
383     mask.val[3] = vandq_u32(mask.val[3], vcgtq_f32(vc.val[3], l_nc.val[3]));
384     mask.val[3] = vandq_u32(mask.val[3], vcgtq_f32(vc.val[3], m_nc.val[3]));
385     mask.val[3] = vandq_u32(mask.val[3], vcgtq_f32(vc.val[3], r_nc.val[3]));
386
387     static const float32x4_t zero = vdupq_n_f32(0.f);
388
389     /* Store */
390     vst1q_f32(output + 0, vbslq_f32(mask.val[0], vc.val[0], zero));
391     vst1q_f32(output + 4, vbslq_f32(mask.val[1], vc.val[1], zero));
392     vst1q_f32(output + 8, vbslq_f32(mask.val[2], vc.val[2], zero));
393     vst1q_f32(output + 12, vbslq_f32(mask.val[3], vc.val[3], zero));
394 }
395
396 inline void non_maxima_suppression3x3_U8_U8(const void *__restrict input_ptr, void *__restrict output_ptr, const uint32_t input_stride)
397 {
398     auto       input  = static_cast<const uint8_t *__restrict>(input_ptr) - 1;
399     const auto output = static_cast<uint8_t *__restrict>(output_ptr);
400
401     /* Get centre scores */
402     const uint8x16_t vc = vld1q_u8(input + 1);
403
404     /* Neighboring pixels */
405     uint8x16_t l_nc{};
406     uint8x16_t m_nc{};
407     uint8x16_t r_nc{};
408
409     input -= input_stride;
410
411     /* Row0 */
412     l_nc = vld1q_u8(input);
413     m_nc = vld1q_u8(input + 1);
414     r_nc = vld1q_u8(input + 2);
415
416     /* mc >= l_nc, mc >= m_nc, mc >= r_nc */
417     uint8x16_t mask = vcgeq_u8(vc, l_nc);
418     mask            = vandq_u8(mask, vcgeq_u8(vc, m_nc));
419     mask            = vandq_u8(mask, vcgeq_u8(vc, r_nc));
420
421     input += input_stride;
422
423     /* Row1 */
424     l_nc = vld1q_u8(input);
425     r_nc = vld1q_u8(input + 2);
426
427     /* mc >= l_nc, mc > r_nc */
428     mask = vandq_u8(mask, vcgeq_u8(vc, l_nc));
429     mask = vandq_u8(mask, vcgtq_u8(vc, r_nc));
430
431     input += input_stride;
432
433     /* Row2 */
434     l_nc = vld1q_u8(input);
435     m_nc = vld1q_u8(input + 1);
436     r_nc = vld1q_u8(input + 2);
437
438     /* mc > l_nc, mc > m_nc, mc > r_nc */
439     mask = vandq_u8(mask, vcgtq_u8(vc, l_nc));
440     mask = vandq_u8(mask, vcgtq_u8(vc, m_nc));
441     mask = vandq_u8(mask, vcgtq_u8(vc, r_nc));
442
443     static const uint8x16_t zero = vdupq_n_u8(0);
444
445     /* Store */
446     vst1q_u8(output, vbslq_u8(mask, vc, zero));
447 }
448 } // namespace
449
450 NENonMaximaSuppression3x3Kernel::NENonMaximaSuppression3x3Kernel()
451     : _func(nullptr), _input(nullptr), _output(nullptr)
452 {
453 }
454
455 BorderSize NENonMaximaSuppression3x3Kernel::border_size() const
456 {
457     return BorderSize(1);
458 }
459
460 void NENonMaximaSuppression3x3Kernel::configure(const ITensor *input, ITensor *output, bool border_undefined)
461 {
462     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::F32);
463     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::F32);
464     ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
465
466     _input  = input;
467     _output = output;
468
469     if(input->info()->data_type() == DataType::U8)
470     {
471         _func = &non_maxima_suppression3x3_U8_U8;
472     }
473     else
474     {
475         _func = &non_maxima_suppression3x3_FLOAT_FLOAT;
476     }
477
478     const unsigned int processed_elements = 16;
479
480     // Configure kernel window
481     Window                  win = calculate_max_window(*input->info(), Steps(processed_elements), border_undefined, border_size());
482     AccessWindowAutoPadding output_access(output->info());
483
484     update_window_and_padding(win,
485                               AccessWindowAutoPadding(input->info()),
486                               output_access);
487
488     output_access.set_valid_region();
489
490     INEKernel::configure(win);
491 }
492
493 void NENonMaximaSuppression3x3Kernel::run(const Window &window)
494 {
495     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
496     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
497     ARM_COMPUTE_ERROR_ON(_func == nullptr);
498     Iterator input(_input, window);
499     Iterator output(_output, window);
500
501     const size_t input_stride = _input->info()->strides_in_bytes()[1] / element_size_from_data_type(_input->info()->data_type());
502
503     execute_window_loop(window, [&](const Coordinates & id)
504     {
505         _func(input.ptr(), output.ptr(), input_stride);
506     },
507     input, output);
508 }