arm_compute v18.05
[platform/upstream/armcl.git] / src / core / NEON / kernels / NECannyEdgeKernel.cpp
1 /*
2  * Copyright (c) 2016, 2017 ARM Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "arm_compute/core/NEON/kernels/NECannyEdgeKernel.h"
25
26 #include "arm_compute/core/AccessWindowStatic.h"
27 #include "arm_compute/core/Error.h"
28 #include "arm_compute/core/Helpers.h"
29 #include "arm_compute/core/ITensor.h"
30 #include "arm_compute/core/TensorInfo.h"
31 #include "arm_compute/core/Types.h"
32 #include "arm_compute/core/Utils.h"
33 #include "arm_compute/core/Validate.h"
34
35 #include <arm_neon.h>
36 #include <cstddef>
37 #include <cstdint>
38 #include <tuple>
39
40 using namespace arm_compute;
41
42 namespace arm_compute
43 {
44 class Coordinates;
45 } // namespace arm_compute
46
47 namespace
48 {
49 constexpr int NO_EDGE = 0;
50 constexpr int EDGE    = 255;
51 constexpr int MAYBE   = 127;
52 } // namespace
53
54 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
55 namespace fp16
56 {
57 inline uint8x8_t phase_quantization(const float32x4x2_t &gx, const float32x4x2_t &gy)
58 {
59     // Constant use for evaluating score1 and score3
60     static const float32x4_t const45 = vdupq_n_f32(0.70710678118655f);
61     static const float32x4_t zero    = vdupq_n_f32(0.0f);
62     static const float32x4_t one     = vdupq_n_f32(1.0f);
63     static const float32x4_t two     = vdupq_n_f32(2.0f);
64     static const float32x4_t three   = vdupq_n_f32(3.0f);
65
66     // Score0: (1, 0)
67     const float32x4x2_t score0 =
68     {
69         vabsq_f32(gx.val[0]),
70         vabsq_f32(gx.val[1])
71     };
72
73     // Score2: ( 0, 1 )
74     const float32x4x2_t score2 =
75     {
76         vabsq_f32(gy.val[0]),
77         vabsq_f32(gy.val[1])
78     };
79
80     // Score1 and Score3: ( sqrt(2) / 2, sqrt(2) / 2 ) - ( -sqrt(2) / 2, sqrt(2) / 2 )
81     float32x4x2_t score1 =
82     {
83         vmulq_f32(gy.val[0], const45),
84         vmulq_f32(gy.val[1], const45)
85     };
86
87     float32x4x2_t score3 = score1;
88
89     score1.val[0] = vmlaq_f32(score1.val[0], gx.val[0], const45);
90     score1.val[1] = vmlaq_f32(score1.val[1], gx.val[1], const45);
91     score3.val[0] = vmlsq_f32(score3.val[0], gx.val[0], const45);
92     score3.val[1] = vmlsq_f32(score3.val[1], gx.val[1], const45);
93
94     score1.val[0] = vabsq_f32(score1.val[0]);
95     score1.val[1] = vabsq_f32(score1.val[1]);
96     score3.val[0] = vabsq_f32(score3.val[0]);
97     score3.val[1] = vabsq_f32(score3.val[1]);
98
99     float32x4x2_t phase =
100     {
101         zero,
102         zero
103     };
104
105     float32x4x2_t old_score = score0;
106
107     // score1 > old_score?
108     uint32x4x2_t mask =
109     {
110         vcgtq_f32(score1.val[0], old_score.val[0]),
111         vcgtq_f32(score1.val[1], old_score.val[1])
112     };
113
114     phase.val[0]     = vbslq_f32(mask.val[0], one, phase.val[0]);
115     phase.val[1]     = vbslq_f32(mask.val[1], one, phase.val[1]);
116     old_score.val[0] = vbslq_f32(mask.val[0], score1.val[0], old_score.val[0]);
117     old_score.val[1] = vbslq_f32(mask.val[1], score1.val[1], old_score.val[1]);
118
119     // score2 > old_score?
120     mask.val[0] = vcgtq_f32(score2.val[0], old_score.val[0]);
121     mask.val[1] = vcgtq_f32(score2.val[1], old_score.val[1]);
122
123     phase.val[0]     = vbslq_f32(mask.val[0], two, phase.val[0]);
124     phase.val[1]     = vbslq_f32(mask.val[1], two, phase.val[1]);
125     old_score.val[0] = vbslq_f32(mask.val[0], score2.val[0], old_score.val[0]);
126     old_score.val[1] = vbslq_f32(mask.val[1], score2.val[1], old_score.val[1]);
127
128     // score3 > old_score?
129     mask.val[0] = vcgtq_f32(score3.val[0], old_score.val[0]);
130     mask.val[1] = vcgtq_f32(score3.val[1], old_score.val[1]);
131
132     phase.val[0]     = vbslq_f32(mask.val[0], three, phase.val[0]);
133     phase.val[1]     = vbslq_f32(mask.val[1], three, phase.val[1]);
134     old_score.val[0] = vbslq_f32(mask.val[0], score3.val[0], old_score.val[0]);
135     old_score.val[1] = vbslq_f32(mask.val[1], score3.val[1], old_score.val[1]);
136
137     // Convert from float32x4_t to uint8x8_t
138     return vmovn_u16(vcombine_u16(vmovn_u32(vcvtq_u32_f32(phase.val[0])),
139                                   vmovn_u32(vcvtq_u32_f32(phase.val[1]))));
140 }
141
142 inline uint8x8_t phase_quantization(float16x8_t gx, float16x8_t gy)
143 {
144     // Constant use for evaluating score1 and score3
145     static const float16x8_t const45 = vdupq_n_f16(0.70710678118655f);
146     static const float16x8_t zero    = vdupq_n_f16(0.0f);
147     static const float16x8_t one     = vdupq_n_f16(1.0f);
148     static const float16x8_t two     = vdupq_n_f16(2.0f);
149     static const float16x8_t three   = vdupq_n_f16(3.0f);
150
151     // Score0: (1, 0)
152     const float16x8_t score0 = vabsq_f16(gx);
153
154     // Score2: ( 0, 1 )
155     const float16x8_t score2 = vabsq_f16(gy);
156
157     // Score1 and Score3: ( sqrt(2) / 2, sqrt(2) / 2 ) - ( -sqrt(2) / 2, sqrt(2) / 2 )
158     float16x8_t score1 = vmulq_f16(gy, const45);
159     float16x8_t score3 = score1;
160
161     score1 = vfmaq_f16(score1, gx, const45);
162     score3 = vfmsq_f16(score3, gx, const45);
163
164     score1 = vabsq_f16(score1);
165     score3 = vabsq_f16(score3);
166
167     float16x8_t phase     = zero;
168     float16x8_t old_score = score0;
169
170     // score1 > old_score?
171     uint16x8_t mask = vcgtq_f16(score1, old_score);
172
173     phase     = vbslq_f16(mask, one, phase);
174     old_score = vbslq_f16(mask, score1, old_score);
175
176     // score2 > old_score?
177     mask = vcgtq_f16(score2, old_score);
178
179     phase     = vbslq_f16(mask, two, phase);
180     old_score = vbslq_f16(mask, score2, old_score);
181
182     // score3 > old_score?
183     mask = vcgtq_f16(score3, old_score);
184
185     phase = vbslq_f16(mask, three, phase);
186
187     // Convert from float16x8_t to uint8x8_t
188     return vmovn_u16(vcvtq_u16_f16(phase));
189 }
190
191 /** Computes the gradient phase if gradient_size = 3 or 5. The output is quantized.
192  *         0 = 0°, 1 = 45°, 2 = 90°, 3 = 135°
193  *
194  * @param[in] gx Gx component
195  * @param[in] gy Gy component
196  *
197  * @return quantized phase for 8 pixels
198  */
199 inline uint8x8_t phase_quantization_S16_S16(int16x8_t gx, int16x8_t gy)
200 {
201     return phase_quantization(vcvtq_f16_s16(gx), vcvtq_f16_s16(gy));
202 }
203
204 /** Computes the gradient phase if gradient_size = 7. The output is quantized.
205  *         0 = 0°, 1 = 45°, 2 = 90°, 3 = 135°
206  *
207  * @param[in] gx Gx component
208  * @param[in] gy Gy component
209  *
210  * @return quantized phase for 8 pixels
211  */
212 inline uint8x8_t phase_quantization_S32_S32(const int32x4x2_t &gx, const int32x4x2_t &gy)
213 {
214     // Convert to float
215     const float32x4x2_t gx_f32 =
216     {
217         vcvtq_f32_s32(gx.val[0]),
218         vcvtq_f32_s32(gx.val[1])
219     };
220
221     const float32x4x2_t gy_f32 =
222     {
223         vcvtq_f32_s32(gy.val[0]),
224         vcvtq_f32_s32(gy.val[1])
225     };
226
227     return phase_quantization(gx_f32, gy_f32);
228 }
229
230 /** Computes the magnitude using the L1-norm type if gradient_size = 3 or 5
231  *
232  * @param[in] gx Gx component
233  * @param[in] gy Gy component
234  *
235  * @return magnitude for 8 pixels
236  */
237 inline uint16x8_t mag_l1_S16_S16(int16x8_t gx, int16x8_t gy)
238 {
239     return vaddq_u16(vreinterpretq_u16_s16(vabsq_s16(gx)),
240                      vreinterpretq_u16_s16(vabsq_s16(gy)));
241 }
242
243 /** Computes the magnitude using the L1-norm type if gradient_size = 7
244  *
245  * @param[in] gx Gx component
246  * @param[in] gy Gy component
247  *
248  * @return magnitude for 8 pixels
249  */
250 inline uint32x4x2_t mag_l1_S32_S32(const int32x4x2_t &gx, const int32x4x2_t &gy)
251 {
252     const uint32x4x2_t gx_abs =
253     {
254         vreinterpretq_u32_s32(vabsq_s32(gx.val[0])),
255         vreinterpretq_u32_s32(vabsq_s32(gx.val[1]))
256     };
257
258     const uint32x4x2_t gy_abs =
259     {
260         vreinterpretq_u32_s32(vabsq_s32(gy.val[0])),
261         vreinterpretq_u32_s32(vabsq_s32(gy.val[1]))
262     };
263
264     const uint32x4x2_t out =
265     {
266         vaddq_u32(gx_abs.val[0], gy_abs.val[0]),
267         vaddq_u32(gx_abs.val[1], gy_abs.val[1])
268     };
269
270     return out;
271 }
272
273 inline float32x4x2_t mag_l2(const float32x4x2_t &gx, const float32x4x2_t &gy)
274 {
275     // x^2 ...
276     float32x4x2_t mag =
277     {
278         vmulq_f32(gx.val[0], gx.val[0]),
279         vmulq_f32(gx.val[1], gx.val[1])
280     };
281
282     // ... + y^2
283     mag.val[0] = vmlaq_f32(mag.val[0], gy.val[0], gy.val[0]);
284     mag.val[1] = vmlaq_f32(mag.val[1], gy.val[1], gy.val[1]);
285
286     // sqrt(...)
287     mag.val[0] = vmulq_f32(vrsqrteq_f32(mag.val[0]), mag.val[0]);
288     mag.val[1] = vmulq_f32(vrsqrteq_f32(mag.val[1]), mag.val[1]);
289
290     return mag;
291 }
292
293 inline float16x8_t mag_l2(float16x8_t gx, float16x8_t gy)
294 {
295     // x^2 ...
296     float16x8_t mag = vmulq_f16(gx, gx);
297
298     // ... + y^2
299     mag = vfmaq_f16(mag, gy, gy);
300
301     // sqrt(...)
302     mag = vmulq_f16(vrsqrteq_f16(mag), mag);
303
304     return mag;
305 }
306
307 /** Computes the magnitude using L2-norm if gradient_size = 3 or 5
308  *
309  * @param[in] gx Gx component
310  * @param[in] gy Gy component
311  *
312  * @return magnitude for 8 pixels
313  */
314 inline uint16x8_t mag_l2_S16_S16(int16x8_t gx, int16x8_t gy)
315 {
316     /* Compute magnitude using L2 normalization */
317     const float16x8_t gx2 = vcvtq_f16_s16(gx);
318     const float16x8_t gy2 = vcvtq_f16_s16(gy);
319     const float16x8_t mag = mag_l2(gx2, gy2);
320
321     /* Store magnitude - Convert to uint16x8 */
322     return vcvtq_u16_f16(mag);
323 }
324
325 /** Computes the magnitude using L2-norm if gradient_size = 7
326  *
327  * @param[in] gx Gx component
328  * @param[in] gy Gy component
329  *
330  * @return magnitude for 8 pixels
331  */
332 inline uint32x4x2_t mag_l2_S32_S32(const int32x4x2_t &gx, const int32x4x2_t &gy)
333 {
334     // Compute magnitude using L2 normalization
335     float32x4x2_t gx2 =
336     {
337         vcvtq_f32_s32(gx.val[0]),
338         vcvtq_f32_s32(gx.val[1])
339     };
340
341     float32x4x2_t gy2 =
342     {
343         vcvtq_f32_s32(gy.val[0]),
344         vcvtq_f32_s32(gy.val[1])
345     };
346
347     const float32x4x2_t mag = mag_l2(gx2, gy2);
348     const uint32x4x2_t  mag32 =
349     {
350         vcvtq_u32_f32(mag.val[0]),
351         vcvtq_u32_f32(mag.val[1])
352     };
353
354     return mag32;
355 }
356
357 /** Gradient function used when the gradient size = 3 or 5 and when the norm_type = L1-norm
358  *
359  * @param[in]  in1_ptr  Pointer to source image. Gx image. Data type supported S16
360  * @param[in]  in2_ptr  Pointer to source image. Gy image. Data type supported S16
361  * @param[out] out1_ptr Pointer to destination image. Magnitude. Data type supported U16
362  * @param[out] out2_ptr Pointer to destination image. Quantized phase. Data type supported U8
363  */
364 void mag_phase_l1norm_S16_S16_U16_U8(const void *__restrict in1_ptr, const void *__restrict in2_ptr, void *__restrict out1_ptr, void *__restrict out2_ptr)
365 {
366     const auto in1  = static_cast<const int16_t *__restrict>(in1_ptr);
367     const auto in2  = static_cast<const int16_t *__restrict>(in2_ptr);
368     const auto out1 = static_cast<uint16_t *__restrict>(out1_ptr);
369     const auto out2 = static_cast<uint8_t *__restrict>(out2_ptr);
370
371     const int16x8x4_t gx =
372     {
373         vld1q_s16(in1),
374         vld1q_s16(in1 + 8),
375         vld1q_s16(in1 + 16),
376         vld1q_s16(in1 + 24)
377     };
378
379     const int16x8x4_t gy =
380     {
381         vld1q_s16(in2),
382         vld1q_s16(in2 + 8),
383         vld1q_s16(in2 + 16),
384         vld1q_s16(in2 + 24)
385     };
386
387     // Compute and store phase
388     vst1_u8(out2 + 0, phase_quantization_S16_S16(gx.val[0], gy.val[0]));
389     vst1_u8(out2 + 8, phase_quantization_S16_S16(gx.val[1], gy.val[1]));
390     vst1_u8(out2 + 16, phase_quantization_S16_S16(gx.val[2], gy.val[2]));
391     vst1_u8(out2 + 24, phase_quantization_S16_S16(gx.val[3], gy.val[3]));
392
393     // Compute ans store magnitude using L1 normalization
394     vst1q_u16(out1 + 0, mag_l1_S16_S16(gx.val[0], gy.val[0]));
395     vst1q_u16(out1 + 8, mag_l1_S16_S16(gx.val[1], gy.val[1]));
396     vst1q_u16(out1 + 16, mag_l1_S16_S16(gx.val[2], gy.val[2]));
397     vst1q_u16(out1 + 24, mag_l1_S16_S16(gx.val[3], gy.val[3]));
398 }
399
400 /** Gradient function used when the gradient size = 3 or 5 and when the norm_type = L2-norm
401  *
402  * @param[in]  in1_ptr  Pointer to source image. Gx image. Data type supported S16
403  * @param[in]  in2_ptr  Pointer to source image. Gy image. Data type supported S16
404  * @param[out] out1_ptr Pointer to destination image. Magnitude. Data type supported U16
405  * @param[out] out2_ptr Pointer to destination image. Quantized phase. Data type supported U8
406  */
407 void mag_phase_l2norm_S16_S16_U16_U8(const void *__restrict in1_ptr, const void *__restrict in2_ptr, void *__restrict out1_ptr, void *__restrict out2_ptr)
408 {
409     const auto in1  = static_cast<const int16_t *__restrict>(in1_ptr);
410     const auto in2  = static_cast<const int16_t *__restrict>(in2_ptr);
411     const auto out1 = static_cast<uint16_t *__restrict>(out1_ptr);
412     const auto out2 = static_cast<uint8_t *__restrict>(out2_ptr);
413
414     const int16x8x4_t gx =
415     {
416         vld1q_s16(in1),
417         vld1q_s16(in1 + 8),
418         vld1q_s16(in1 + 16),
419         vld1q_s16(in1 + 24)
420     };
421
422     const int16x8x4_t gy =
423     {
424         vld1q_s16(in2),
425         vld1q_s16(in2 + 8),
426         vld1q_s16(in2 + 16),
427         vld1q_s16(in2 + 24)
428     };
429
430     // Compute and store phase
431     vst1_u8(out2 + 0, phase_quantization_S16_S16(gx.val[0], gy.val[0]));
432     vst1_u8(out2 + 8, phase_quantization_S16_S16(gx.val[1], gy.val[1]));
433     vst1_u8(out2 + 16, phase_quantization_S16_S16(gx.val[2], gy.val[2]));
434     vst1_u8(out2 + 24, phase_quantization_S16_S16(gx.val[3], gy.val[3]));
435
436     // Compute and store magnitude using L2 normalization
437     vst1q_u16(out1 + 0, mag_l2_S16_S16(gx.val[0], gy.val[0]));
438     vst1q_u16(out1 + 8, mag_l2_S16_S16(gx.val[1], gy.val[1]));
439     vst1q_u16(out1 + 16, mag_l2_S16_S16(gx.val[2], gy.val[2]));
440     vst1q_u16(out1 + 24, mag_l2_S16_S16(gx.val[3], gy.val[3]));
441 }
442
443 /** Gradient function used when the gradient size = 7 and when the norm_type = L1-norm
444  *
445  * @param[in]  in1_ptr  Pointer to source image. Gx image. Data type supported S32
446  * @param[in]  in2_ptr  Pointer to source image. Gy image. Data type supported S32
447  * @param[out] out1_ptr Pointer to destination image. Magnitude. Data type supported U32
448  * @param[out] out2_ptr Pointer to destination image. Quantized phase. Data type supported U8
449  */
450 void mag_phase_l1norm_S32_S32_U32_U8(const void *__restrict in1_ptr, const void *__restrict in2_ptr, void *__restrict out1_ptr, void *__restrict out2_ptr)
451 {
452     auto in1  = static_cast<const int32_t *__restrict>(in1_ptr);
453     auto in2  = static_cast<const int32_t *__restrict>(in2_ptr);
454     auto out1 = static_cast<uint32_t *__restrict>(out1_ptr);
455     auto out2 = static_cast<uint8_t *__restrict>(out2_ptr);
456
457     // Process low and high part
458     for(size_t i = 0; i < 2; ++i, in1 += 16, in2 += 16, out1 += 16, out2 += 16)
459     {
460         const int32x4x2_t gx0 =
461         {
462             vld1q_s32(in1 + 0),
463             vld1q_s32(in1 + 4)
464         };
465
466         const int32x4x2_t gx1 =
467         {
468             vld1q_s32(in1 + 8),
469             vld1q_s32(in1 + 12)
470         };
471
472         const int32x4x2_t gy0 =
473         {
474             vld1q_s32(in2 + 0),
475             vld1q_s32(in2 + 4)
476         };
477
478         const int32x4x2_t gy1 =
479         {
480             vld1q_s32(in2 + 8),
481             vld1q_s32(in2 + 12)
482         };
483
484         // Compute and store phase
485         vst1_u8(out2 + 0, phase_quantization_S32_S32(gx0, gy0));
486         vst1_u8(out2 + 8, phase_quantization_S32_S32(gx1, gy1));
487
488         // Compute magnitude using L1 normalization
489         const uint32x4x2_t mag0 = mag_l1_S32_S32(gx0, gy0);
490         const uint32x4x2_t mag1 = mag_l1_S32_S32(gx1, gy1);
491
492         // Store magnitude
493         vst1q_u32(out1 + 0, mag0.val[0]);
494         vst1q_u32(out1 + 4, mag0.val[1]);
495         vst1q_u32(out1 + 8, mag1.val[0]);
496         vst1q_u32(out1 + 12, mag1.val[1]);
497     }
498 }
499
500 /** Gradient function used when the gradient size = 7 and when the norm_type = L2-norm
501  *
502  * @param[in]  in1_ptr  Pointer to source image. Gx image. Data type supported S32
503  * @param[in]  in2_ptr  Pointer to source image. Gy image. Data type supported S32
504  * @param[out] out1_ptr Pointer to destination image. Magnitude. Data type supported U32
505  * @param[out] out2_ptr Pointer to destination image. Quantized phase. Data type supported U8
506  */
507 void mag_phase_l2norm_S32_S32_U32_U8(const void *__restrict in1_ptr, const void *__restrict in2_ptr, void *__restrict out1_ptr, void *__restrict out2_ptr)
508 {
509     auto in1  = static_cast<const int32_t *__restrict>(in1_ptr);
510     auto in2  = static_cast<const int32_t *__restrict>(in2_ptr);
511     auto out1 = static_cast<uint32_t *__restrict>(out1_ptr);
512     auto out2 = static_cast<uint8_t *__restrict>(out2_ptr);
513
514     // Process low and high part
515     for(size_t i = 0; i < 2; ++i, in1 += 16, in2 += 16, out1 += 16, out2 += 16)
516     {
517         const int32x4x2_t gx0 =
518         {
519             vld1q_s32(in1 + 0),
520             vld1q_s32(in1 + 4)
521         };
522
523         const int32x4x2_t gx1 =
524         {
525             vld1q_s32(in1 + 8),
526             vld1q_s32(in1 + 12)
527         };
528
529         const int32x4x2_t gy0 =
530         {
531             vld1q_s32(in2 + 0),
532             vld1q_s32(in2 + 4)
533         };
534
535         const int32x4x2_t gy1 =
536         {
537             vld1q_s32(in2 + 8),
538             vld1q_s32(in2 + 12)
539         };
540
541         // Compute and store phase
542         vst1_u8(out2 + 0, phase_quantization_S32_S32(gx0, gy0));
543         vst1_u8(out2 + 8, phase_quantization_S32_S32(gx1, gy1));
544
545         // Compute magnitude using L2 normalization
546         const uint32x4x2_t mag0 = mag_l2_S32_S32(gx0, gy0);
547         const uint32x4x2_t mag1 = mag_l2_S32_S32(gx1, gy1);
548
549         // Store magnitude
550         vst1q_u32(out1 + 0, mag0.val[0]);
551         vst1q_u32(out1 + 4, mag0.val[1]);
552         vst1q_u32(out1 + 8, mag1.val[0]);
553         vst1q_u32(out1 + 12, mag1.val[1]);
554     }
555 }
556
557 inline uint16x4_t non_max_U32_helper(const uint32_t *in, const uint16x4_t pc, const uint32_t stride_mag, const int32_t lower_thr, const int32_t upper_thr)
558 {
559     // Phase for 4 pixel
560     const uint32x4_t pc32 = vmovl_u16(pc);
561
562     // Get magnitude for 4 pixel
563     uint32x4_t mc = vld1q_u32(in);
564
565     // Angle_quantized: 0 = 0°, 1 = 45°, 2 = 90°, 3 = 135°
566     // 0 degree
567     const uint32x4_t mk0_0 = vld1q_u32(in - 1);
568     const uint32x4_t mk0_1 = vld1q_u32(in + 1);
569     uint32x4_t       mask0 = vceqq_u32(pc32, vdupq_n_u32(0));
570     mask0                  = vandq_u32(mask0, vcgeq_u32(mc, mk0_0));
571     mask0                  = vandq_u32(mask0, vcgeq_u32(mc, mk0_1));
572
573     // 45 degree
574     const uint32x4_t mk45_0 = vld1q_u32(in - stride_mag - 1);
575     const uint32x4_t mk45_1 = vld1q_u32(in + stride_mag + 1);
576     uint32x4_t       mask1  = vceqq_u32(pc32, vdupq_n_u32(1));
577     mask1                   = vandq_u32(mask1, vcgeq_u32(mc, mk45_0));
578     mask1                   = vandq_u32(mask1, vcgeq_u32(mc, mk45_1));
579
580     // 90 degree
581     const uint32x4_t mk90_0 = vld1q_u32(in - stride_mag);
582     const uint32x4_t mk90_1 = vld1q_u32(in + stride_mag);
583     uint32x4_t       mask2  = vceqq_u32(pc32, vdupq_n_u32(2));
584     mask2                   = vandq_u32(mask2, vcgeq_u32(mc, mk90_0));
585     mask2                   = vandq_u32(mask2, vcgeq_u32(mc, mk90_1));
586
587     // 135 degree
588     const uint32x4_t mk135_0 = vld1q_u32(in - stride_mag + 1);
589     const uint32x4_t mk135_1 = vld1q_u32(in + stride_mag - 1);
590     uint32x4_t       mask3   = vceqq_u32(pc32, vdupq_n_u32(3));
591     mask3                    = vandq_u32(mask3, vcgeq_u32(mc, mk135_0));
592     mask3                    = vandq_u32(mask3, vcgeq_u32(mc, mk135_1));
593
594     // Merge masks
595     mask0 = vorrq_u32(mask0, mask1);
596     mask2 = vorrq_u32(mask2, mask3);
597     mask0 = vorrq_u32(mask0, mask2);
598
599     mc = vbslq_u32(mask0, mc, vdupq_n_u32(0));
600
601     // mc > upper_thr
602     mask0 = vcgtq_u32(mc, vdupq_n_u32(upper_thr));
603
604     // mc <= lower_thr
605     mask1 = vcleq_u32(mc, vdupq_n_u32(lower_thr));
606
607     // mc <= upper_thr && mc > lower_thr
608     mask2 = vcleq_u32(mc, vdupq_n_u32(upper_thr));
609     mask2 = vandq_u32(mask2, vcgtq_u32(mc, vdupq_n_u32(lower_thr)));
610
611     mc = vbslq_u32(mask0, vdupq_n_u32(EDGE), mc);
612     mc = vbslq_u32(mask1, vdupq_n_u32(NO_EDGE), mc);
613     mc = vbslq_u32(mask2, vdupq_n_u32(MAYBE), mc);
614
615     return vmovn_u32(mc);
616 }
617
618 /** Computes edge tracing when is called by edge_trace_U8_U8 recursively
619  *
620  * @param[in]  in         Pointer to source image. Data type supported U8
621  * @param[out] out        Pointer to destination image. Data type supported U8
622  * @param[in]  in_stride  Stride of the input image
623  * @param[in]  out_stride Stride of the output image
624  */
625 void edge_trace_recursive_U8_U8(uint8_t *__restrict in, uint8_t *__restrict out, const int32_t in_stride, const int32_t out_stride)
626 {
627     // Look for MAYBE pixels in 8 directions
628     *out = EDGE;
629
630     // (-1, 0)
631     uint8_t pixel = *(in - 1);
632
633     if(pixel == MAYBE)
634     {
635         // Touched a MAYBE point. MAYBE becomes EDGE
636         *(in - 1) = EDGE;
637
638         edge_trace_recursive_U8_U8(in - 1, out - 1, in_stride, out_stride);
639     }
640
641     // (+1, 0)
642     pixel = *(in + 1);
643
644     if(pixel == MAYBE)
645     {
646         // Touched a MAYBE point. MAYBE becomes EDGE
647         *(in + 1) = EDGE;
648
649         edge_trace_recursive_U8_U8(in + 1, out + 1, in_stride, out_stride);
650     }
651
652     in -= in_stride;
653     out -= out_stride;
654
655     // (-1, -1)
656     pixel = *(in - 1);
657
658     if(pixel == MAYBE)
659     {
660         // Touched a MAYBE point. MAYBE becomes EDGE
661         *(in - 1) = EDGE;
662
663         edge_trace_recursive_U8_U8(in - 1, out - 1, in_stride, out_stride);
664     }
665
666     // (0, -1)
667     pixel = *in;
668
669     if(pixel == MAYBE)
670     {
671         // Touched a MAYBE point. MAYBE becomes EDGE
672         *in = EDGE;
673
674         edge_trace_recursive_U8_U8(in, out, in_stride, out_stride);
675     }
676
677     // (+1, -1)
678     pixel = *(in + 1);
679
680     if(pixel == MAYBE)
681     {
682         // Touched a MAYBE point. MAYBE becomes EDGE
683         *(in + 1) = EDGE;
684
685         edge_trace_recursive_U8_U8(in + 1, out + 1, in_stride, out_stride);
686     }
687
688     in += in_stride * 2;
689     out += out_stride * 2;
690
691     // (-1, +1)
692     pixel = *(in - 1);
693
694     if(pixel == MAYBE)
695     {
696         // Touched a MAYBE point. MAYBE becomes EDGE
697         *(in - 1) = EDGE;
698
699         edge_trace_recursive_U8_U8(in - 1, out - 1, in_stride, out_stride);
700     }
701
702     // (0, +1)
703     pixel = *in;
704
705     if(pixel == MAYBE)
706     {
707         // Touched a MAYBE point. MAYBE becomes EDGE
708         *in = EDGE;
709
710         edge_trace_recursive_U8_U8(in, out, in_stride, out_stride);
711     }
712
713     // (+1, +1)
714     pixel = *(in + 1);
715
716     if(pixel == MAYBE)
717     {
718         // Touched a MAYBE point. MAYBE becomes EDGE
719         *(in + 1) = EDGE;
720
721         edge_trace_recursive_U8_U8(in + 1, out + 1, in_stride, out_stride);
722     }
723 }
724 } // namespace fp16
725
726 void NEGradientFP16Kernel::configure(const ITensor *gx, const ITensor *gy, ITensor *magnitude, ITensor *phase, int32_t norm_type)
727 {
728     ARM_COMPUTE_ERROR_ON_NULLPTR(gx, gy, magnitude, phase);
729
730     set_shape_if_empty(*magnitude->info(), gx->info()->tensor_shape());
731     set_shape_if_empty(*phase->info(), gx->info()->tensor_shape());
732
733     Format magnitude_format = gx->info()->data_type() == DataType::S16 ? Format::U16 : Format::U32;
734     set_format_if_unknown(*magnitude->info(), magnitude_format);
735     set_format_if_unknown(*phase->info(), Format::U8);
736
737     ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(gx, gy, magnitude, phase);
738     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(gx, 1, DataType::S16, DataType::S32);
739     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(gy, 1, DataType::S16, DataType::S32);
740     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(magnitude, 1, DataType::U16, DataType::U32);
741     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(phase, 1, DataType::U8);
742     ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(gx, gy);
743     ARM_COMPUTE_ERROR_ON_MSG(element_size_from_data_type(gx->info()->data_type()) != element_size_from_data_type(magnitude->info()->data_type()), "Magnitude must have the same element size as Gx and Gy");
744
745     _gx        = gx;
746     _gy        = gy;
747     _magnitude = magnitude;
748     _phase     = phase;
749
750     if(_gx->info()->data_type() == DataType::S16)
751     {
752         if(norm_type == 1)
753         {
754             _func = &fp16::mag_phase_l1norm_S16_S16_U16_U8;
755         }
756         else
757         {
758             _func = &fp16::mag_phase_l2norm_S16_S16_U16_U8;
759         }
760     }
761     else
762     {
763         if(norm_type == 1)
764         {
765             _func = &fp16::mag_phase_l1norm_S32_S32_U32_U8;
766         }
767         else
768         {
769             _func = &fp16::mag_phase_l2norm_S32_S32_U32_U8;
770         }
771     }
772
773     constexpr unsigned int num_elems_processed_per_iteration = 32;
774
775     // Configure kernel window
776     Window win = calculate_max_window(*_gx->info(), Steps(num_elems_processed_per_iteration));
777
778     AccessWindowHorizontal gx_access(_gx->info(), 0, num_elems_processed_per_iteration);
779     AccessWindowHorizontal gy_access(_gy->info(), 0, num_elems_processed_per_iteration);
780     AccessWindowHorizontal mag_access(_magnitude->info(), 0, num_elems_processed_per_iteration);
781     AccessWindowHorizontal phase_access(_phase->info(), 0, num_elems_processed_per_iteration);
782
783     update_window_and_padding(win, gx_access, gy_access, mag_access, phase_access);
784
785     mag_access.set_valid_region(win, _gx->info()->valid_region());
786     phase_access.set_valid_region(win, _gx->info()->valid_region());
787
788     INEKernel::configure(win);
789 }
790 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
791
792 namespace
793 {
794 inline uint8x8_t phase_quantization(const float32x4x2_t &gx, const float32x4x2_t &gy)
795 {
796     // Constant use for evaluating score1 and score3
797     static const float32x4_t const45 = vdupq_n_f32(0.70710678118655f);
798     static const float32x4_t zero    = vdupq_n_f32(0.0f);
799     static const float32x4_t one     = vdupq_n_f32(1.0f);
800     static const float32x4_t two     = vdupq_n_f32(2.0f);
801     static const float32x4_t three   = vdupq_n_f32(3.0f);
802
803     // Score0: (1, 0)
804     const float32x4x2_t score0 =
805     {
806         {
807             vabsq_f32(gx.val[0]),
808             vabsq_f32(gx.val[1])
809         }
810     };
811
812     // Score2: ( 0, 1 )
813     const float32x4x2_t score2 =
814     {
815         {
816             vabsq_f32(gy.val[0]),
817             vabsq_f32(gy.val[1])
818         }
819     };
820
821     // Score1 and Score3: ( sqrt(2) / 2, sqrt(2) / 2 ) - ( -sqrt(2) / 2, sqrt(2) / 2 )
822     float32x4x2_t score1 =
823     {
824         {
825             vmulq_f32(gy.val[0], const45),
826             vmulq_f32(gy.val[1], const45)
827         }
828     };
829
830     float32x4x2_t score3 = score1;
831
832     score1.val[0] = vmlaq_f32(score1.val[0], gx.val[0], const45);
833     score1.val[1] = vmlaq_f32(score1.val[1], gx.val[1], const45);
834     score3.val[0] = vmlsq_f32(score3.val[0], gx.val[0], const45);
835     score3.val[1] = vmlsq_f32(score3.val[1], gx.val[1], const45);
836
837     score1.val[0] = vabsq_f32(score1.val[0]);
838     score1.val[1] = vabsq_f32(score1.val[1]);
839     score3.val[0] = vabsq_f32(score3.val[0]);
840     score3.val[1] = vabsq_f32(score3.val[1]);
841
842     float32x4x2_t phase =
843     {
844         {
845             zero,
846             zero
847         }
848     };
849
850     float32x4x2_t old_score = score0;
851
852     // score1 > old_score?
853     uint32x4x2_t mask =
854     {
855         {
856             vcgtq_f32(score1.val[0], old_score.val[0]),
857             vcgtq_f32(score1.val[1], old_score.val[1])
858         }
859     };
860
861     phase.val[0]     = vbslq_f32(mask.val[0], one, phase.val[0]);
862     phase.val[1]     = vbslq_f32(mask.val[1], one, phase.val[1]);
863     old_score.val[0] = vbslq_f32(mask.val[0], score1.val[0], old_score.val[0]);
864     old_score.val[1] = vbslq_f32(mask.val[1], score1.val[1], old_score.val[1]);
865
866     // score2 > old_score?
867     mask.val[0] = vcgtq_f32(score2.val[0], old_score.val[0]);
868     mask.val[1] = vcgtq_f32(score2.val[1], old_score.val[1]);
869
870     phase.val[0]     = vbslq_f32(mask.val[0], two, phase.val[0]);
871     phase.val[1]     = vbslq_f32(mask.val[1], two, phase.val[1]);
872     old_score.val[0] = vbslq_f32(mask.val[0], score2.val[0], old_score.val[0]);
873     old_score.val[1] = vbslq_f32(mask.val[1], score2.val[1], old_score.val[1]);
874
875     // score3 > old_score?
876     mask.val[0] = vcgtq_f32(score3.val[0], old_score.val[0]);
877     mask.val[1] = vcgtq_f32(score3.val[1], old_score.val[1]);
878
879     phase.val[0]     = vbslq_f32(mask.val[0], three, phase.val[0]);
880     phase.val[1]     = vbslq_f32(mask.val[1], three, phase.val[1]);
881     old_score.val[0] = vbslq_f32(mask.val[0], score3.val[0], old_score.val[0]);
882     old_score.val[1] = vbslq_f32(mask.val[1], score3.val[1], old_score.val[1]);
883
884     // Convert from float32x4_t to uint8x8_t
885     return vmovn_u16(vcombine_u16(vmovn_u32(vcvtq_u32_f32(phase.val[0])),
886                                   vmovn_u32(vcvtq_u32_f32(phase.val[1]))));
887 }
888
889 /* Computes the gradient phase if gradient_size = 3 or 5. The output is quantized.
890  * 0 = 0°, 1 = 45°, 2 = 90°, 3 = 135°
891  *
892  * @param[in] gx Gx component
893  * @param[in] gy Gy component
894  *
895  * @return quantized phase for 8 pixels
896  */
897 inline uint8x8_t phase_quantization_S16_S16(int16x8_t gx, int16x8_t gy)
898 {
899     // Convert to float
900     const float32x4x2_t gx_f32 =
901     {
902         {
903             vcvtq_f32_s32(vmovl_s16(vget_low_s16(gx))),
904             vcvtq_f32_s32(vmovl_s16(vget_high_s16(gx)))
905         }
906     };
907
908     const float32x4x2_t gy_f32 =
909     {
910         {
911             vcvtq_f32_s32(vmovl_s16(vget_low_s16(gy))),
912             vcvtq_f32_s32(vmovl_s16(vget_high_s16(gy)))
913         }
914     };
915
916     return phase_quantization(gx_f32, gy_f32);
917 }
918
919 /* Computes the gradient phase if gradient_size = 7. The output is quantized.
920  * 0 = 0°, 1 = 45°, 2 = 90°, 3 = 135°
921  *
922  * @param[in] gx Gx component
923  * @param[in] gy Gy component
924  *
925  * @return quantized phase for 8 pixels
926  */
927 inline uint8x8_t phase_quantization_S32_S32(const int32x4x2_t &gx, const int32x4x2_t &gy)
928 {
929     // Convert to float
930     const float32x4x2_t gx_f32 =
931     {
932         {
933             vcvtq_f32_s32(gx.val[0]),
934             vcvtq_f32_s32(gx.val[1])
935         }
936     };
937
938     const float32x4x2_t gy_f32 =
939     {
940         {
941             vcvtq_f32_s32(gy.val[0]),
942             vcvtq_f32_s32(gy.val[1])
943         }
944     };
945
946     return phase_quantization(gx_f32, gy_f32);
947 }
948
949 /* Computes the magnitude using the L1-norm type if gradient_size = 3 or 5
950  *
951  * @param[in] gx Gx component
952  * @param[in] gy Gy component
953  *
954  * @return magnitude for 8 pixels
955  */
956 inline uint16x8_t mag_l1_S16_S16(int16x8_t gx, int16x8_t gy)
957 {
958     return vaddq_u16(vreinterpretq_u16_s16(vabsq_s16(gx)),
959                      vreinterpretq_u16_s16(vabsq_s16(gy)));
960 }
961
962 /* Computes the magnitude using the L1-norm type if gradient_size = 7
963  *
964  * @param[in] gx Gx component
965  * @param[in] gy Gy component
966  *
967  * @return magnitude for 8 pixels
968  */
969 inline uint32x4x2_t mag_l1_S32_S32(const int32x4x2_t &gx, const int32x4x2_t &gy)
970 {
971     const uint32x4x2_t gx_abs =
972     {
973         {
974             vreinterpretq_u32_s32(vabsq_s32(gx.val[0])),
975             vreinterpretq_u32_s32(vabsq_s32(gx.val[1]))
976         }
977     };
978
979     const uint32x4x2_t gy_abs =
980     {
981         {
982             vreinterpretq_u32_s32(vabsq_s32(gy.val[0])),
983             vreinterpretq_u32_s32(vabsq_s32(gy.val[1]))
984         }
985     };
986
987     const uint32x4x2_t output =
988     {
989         {
990             vaddq_u32(gx_abs.val[0], gy_abs.val[0]),
991             vaddq_u32(gx_abs.val[1], gy_abs.val[1])
992         }
993     };
994
995     return output;
996 }
997
998 inline float32x4x2_t mag_l2(const float32x4x2_t &gx, const float32x4x2_t &gy)
999 {
1000     // x^2 ...
1001     float32x4x2_t magnitude =
1002     {
1003         {
1004             vmulq_f32(gx.val[0], gx.val[0]),
1005             vmulq_f32(gx.val[1], gx.val[1])
1006         }
1007     };
1008
1009     // ... + y^2
1010     magnitude.val[0] = vmlaq_f32(magnitude.val[0], gy.val[0], gy.val[0]);
1011     magnitude.val[1] = vmlaq_f32(magnitude.val[1], gy.val[1], gy.val[1]);
1012
1013     // sqrt(...)
1014     magnitude.val[0] = vmulq_f32(vrsqrteq_f32(magnitude.val[0]), magnitude.val[0]);
1015     magnitude.val[1] = vmulq_f32(vrsqrteq_f32(magnitude.val[1]), magnitude.val[1]);
1016
1017     return magnitude;
1018 }
1019
1020 /* Computes the magnitude using L2-norm if gradient_size = 3 or 5
1021  *
1022  * @param[in] gx Gx component
1023  * @param[in] gy Gy component
1024  *
1025  * @return magnitude for 8 pixels
1026  */
1027 inline uint16x8_t mag_l2_S16_S16(int16x8_t gx, int16x8_t gy)
1028 {
1029     // Compute magnitude using L2 normalization
1030     const float32x4x2_t gx2 =
1031     {
1032         {
1033             vcvtq_f32_s32(vmovl_s16(vget_low_s16(gx))),
1034             vcvtq_f32_s32(vmovl_s16(vget_high_s16(gx)))
1035         }
1036     };
1037
1038     const float32x4x2_t gy2 =
1039     {
1040         {
1041             vcvtq_f32_s32(vmovl_s16(vget_low_s16(gy))),
1042             vcvtq_f32_s32(vmovl_s16(vget_high_s16(gy)))
1043         }
1044     };
1045
1046     const float32x4x2_t magnitude = mag_l2(gx2, gy2);
1047
1048     // Store magnitude - Convert to uint16x8
1049     return vcombine_u16(vmovn_u32(vcvtq_u32_f32(magnitude.val[0])),
1050                         vmovn_u32(vcvtq_u32_f32(magnitude.val[1])));
1051 }
1052
1053 /* Computes the magnitude using L2-norm if gradient_size = 7
1054  *
1055  * @param[in] gx Gx component
1056  * @param[in] gy Gy component
1057  *
1058  * @return magnitude for 8 pixels
1059  */
1060 inline uint32x4x2_t mag_l2_S32_S32(const int32x4x2_t &gx, const int32x4x2_t &gy)
1061 {
1062     // Compute magnitude using L2 normalization
1063     float32x4x2_t gx2 =
1064     {
1065         {
1066             vcvtq_f32_s32(gx.val[0]),
1067             vcvtq_f32_s32(gx.val[1])
1068         }
1069     };
1070
1071     float32x4x2_t gy2 =
1072     {
1073         {
1074             vcvtq_f32_s32(gy.val[0]),
1075             vcvtq_f32_s32(gy.val[1])
1076         }
1077     };
1078
1079     const float32x4x2_t magnitude = mag_l2(gx2, gy2);
1080     const uint32x4x2_t  mag32 =
1081     {
1082         {
1083             vcvtq_u32_f32(magnitude.val[0]),
1084             vcvtq_u32_f32(magnitude.val[1])
1085         }
1086     };
1087
1088     return mag32;
1089 }
1090
1091 /* Gradient function used when the gradient size = 3 or 5 and when the norm_type = L1-norm
1092  *
1093  * @param[in]  gx_ptr        Pointer to source image. Gx image. Data type supported S16
1094  * @param[in]  gy_ptr        Pointer to source image. Gy image. Data type supported S16
1095  * @param[out] magnitude_ptr Pointer to destination image. Magnitude. Data type supported U16
1096  * @param[out] phase_ptr     Pointer to destination image. Quantized phase. Data type supported U8
1097  */
1098 void mag_phase_l1norm_S16_S16_U16_U8(const void *__restrict gx_ptr, const void *__restrict gy_ptr, void *__restrict magnitude_ptr, void *__restrict phase_ptr)
1099 {
1100     const auto gx        = static_cast<const int16_t *__restrict>(gx_ptr);
1101     const auto gy        = static_cast<const int16_t *__restrict>(gy_ptr);
1102     const auto magnitude = static_cast<uint16_t *__restrict>(magnitude_ptr);
1103     const auto phase     = static_cast<uint8_t *__restrict>(phase_ptr);
1104
1105     const int16x8x4_t gx_val =
1106     {
1107         {
1108             vld1q_s16(gx),
1109             vld1q_s16(gx + 8),
1110             vld1q_s16(gx + 16),
1111             vld1q_s16(gx + 24)
1112         }
1113     };
1114
1115     const int16x8x4_t gy_val =
1116     {
1117         {
1118             vld1q_s16(gy),
1119             vld1q_s16(gy + 8),
1120             vld1q_s16(gy + 16),
1121             vld1q_s16(gy + 24)
1122         }
1123     };
1124
1125     // Compute and store phase
1126     vst1_u8(phase + 0, phase_quantization_S16_S16(gx_val.val[0], gy_val.val[0]));
1127     vst1_u8(phase + 8, phase_quantization_S16_S16(gx_val.val[1], gy_val.val[1]));
1128     vst1_u8(phase + 16, phase_quantization_S16_S16(gx_val.val[2], gy_val.val[2]));
1129     vst1_u8(phase + 24, phase_quantization_S16_S16(gx_val.val[3], gy_val.val[3]));
1130
1131     // Compute ans store magnitude using L1 normalization
1132     vst1q_u16(magnitude + 0, mag_l1_S16_S16(gx_val.val[0], gy_val.val[0]));
1133     vst1q_u16(magnitude + 8, mag_l1_S16_S16(gx_val.val[1], gy_val.val[1]));
1134     vst1q_u16(magnitude + 16, mag_l1_S16_S16(gx_val.val[2], gy_val.val[2]));
1135     vst1q_u16(magnitude + 24, mag_l1_S16_S16(gx_val.val[3], gy_val.val[3]));
1136 }
1137
1138 /* Gradient function used when the gradient size = 3 or 5 and when the norm_type = L2-norm
1139  *
1140  * @param[in]  gx_ptr        Pointer to source image. Gx image. Data type supported S16
1141  * @param[in]  gy_ptr        Pointer to source image. Gy image. Data type supported S16
1142  * @param[out] magnitude_ptr Pointer to destination image. Magnitude. Data type supported U16
1143  * @param[out] phase_ptr     Pointer to destination image. Quantized phase. Data type supported U8
1144  */
1145 void mag_phase_l2norm_S16_S16_U16_U8(const void *__restrict gx_ptr, const void *__restrict gy_ptr, void *__restrict magnitude_ptr, void *__restrict phase_ptr)
1146 {
1147     const auto gx        = static_cast<const int16_t *__restrict>(gx_ptr);
1148     const auto gy        = static_cast<const int16_t *__restrict>(gy_ptr);
1149     const auto magnitude = static_cast<uint16_t *__restrict>(magnitude_ptr);
1150     const auto phase     = static_cast<uint8_t *__restrict>(phase_ptr);
1151
1152     const int16x8x4_t gx_val =
1153     {
1154         {
1155             vld1q_s16(gx),
1156             vld1q_s16(gx + 8),
1157             vld1q_s16(gx + 16),
1158             vld1q_s16(gx + 24)
1159         }
1160     };
1161
1162     const int16x8x4_t gy_val =
1163     {
1164         {
1165             vld1q_s16(gy),
1166             vld1q_s16(gy + 8),
1167             vld1q_s16(gy + 16),
1168             vld1q_s16(gy + 24)
1169         }
1170     };
1171
1172     // Compute and store phase
1173     vst1_u8(phase + 0, phase_quantization_S16_S16(gx_val.val[0], gy_val.val[0]));
1174     vst1_u8(phase + 8, phase_quantization_S16_S16(gx_val.val[1], gy_val.val[1]));
1175     vst1_u8(phase + 16, phase_quantization_S16_S16(gx_val.val[2], gy_val.val[2]));
1176     vst1_u8(phase + 24, phase_quantization_S16_S16(gx_val.val[3], gy_val.val[3]));
1177
1178     // Compute and store magnitude using L2 normalization
1179     vst1q_u16(magnitude + 0, mag_l2_S16_S16(gx_val.val[0], gy_val.val[0]));
1180     vst1q_u16(magnitude + 8, mag_l2_S16_S16(gx_val.val[1], gy_val.val[1]));
1181     vst1q_u16(magnitude + 16, mag_l2_S16_S16(gx_val.val[2], gy_val.val[2]));
1182     vst1q_u16(magnitude + 24, mag_l2_S16_S16(gx_val.val[3], gy_val.val[3]));
1183 }
1184
1185 /* Gradient function used when the gradient size = 7 and when the norm_type = L1-norm
1186  *
1187  * @param[in]  gx_ptr        Pointer to source image. Gx image. Data type supported S32
1188  * @param[in]  gy_ptr        Pointer to source image. Gy image. Data type supported S32
1189  * @param[out] magnitude_ptr Pointer to destination image. Magnitude. Data type supported U32
1190  * @param[out] phase_ptr     Pointer to destination image. Quantized phase. Data type support U8
1191  */
1192 void mag_phase_l1norm_S32_S32_U32_U8(const void *__restrict gx_ptr, const void *__restrict gy_ptr, void *__restrict magnitude_ptr, void *__restrict phase_ptr)
1193 {
1194     auto gx        = static_cast<const int32_t *__restrict>(gx_ptr);
1195     auto gy        = static_cast<const int32_t *__restrict>(gy_ptr);
1196     auto magnitude = static_cast<uint32_t *__restrict>(magnitude_ptr);
1197     auto phase     = static_cast<uint8_t *__restrict>(phase_ptr);
1198
1199     // Process low and high part
1200     for(size_t i = 0; i < 2; ++i, gx += 16, gy += 16, magnitude += 16, phase += 16)
1201     {
1202         const int32x4x2_t gx0 =
1203         {
1204             {
1205                 vld1q_s32(gx + 0),
1206                 vld1q_s32(gx + 4)
1207             }
1208         };
1209
1210         const int32x4x2_t gx1 =
1211         {
1212             {
1213                 vld1q_s32(gx + 8),
1214                 vld1q_s32(gx + 12)
1215             }
1216         };
1217
1218         const int32x4x2_t gy0 =
1219         {
1220             {
1221                 vld1q_s32(gy + 0),
1222                 vld1q_s32(gy + 4)
1223             }
1224         };
1225
1226         const int32x4x2_t gy1 =
1227         {
1228             {
1229                 vld1q_s32(gy + 8),
1230                 vld1q_s32(gy + 12)
1231             }
1232         };
1233
1234         // Compute and store phase
1235         vst1_u8(phase + 0, phase_quantization_S32_S32(gx0, gy0));
1236         vst1_u8(phase + 8, phase_quantization_S32_S32(gx1, gy1));
1237
1238         // Compute magnitude using L1 normalization
1239         const uint32x4x2_t mag0 = mag_l1_S32_S32(gx0, gy0);
1240         const uint32x4x2_t mag1 = mag_l1_S32_S32(gx1, gy1);
1241
1242         // Store magnitude
1243         vst1q_u32(magnitude + 0, mag0.val[0]);
1244         vst1q_u32(magnitude + 4, mag0.val[1]);
1245         vst1q_u32(magnitude + 8, mag1.val[0]);
1246         vst1q_u32(magnitude + 12, mag1.val[1]);
1247     }
1248 }
1249
1250 /* Gradient function used when the gradient size = 7 and when the norm_type = L2-norm
1251  *
1252  * @param[in]  gx_ptr        Pointer to source image. Gx image. Data type supported S32
1253  * @param[in]  gy_ptr        Pointer to source image. Gy image. Data type supported S32
1254  * @param[out] magnitude_ptr Pointer to destination image. Magnitude. Data type supported U32
1255  * @param[out] phase_ptr     Pointer to destination image. Quantized phase. Data type supported U8
1256  */
1257 void mag_phase_l2norm_S32_S32_U32_U8(const void *__restrict gx_ptr, const void *__restrict gy_ptr, void *__restrict magnitude_ptr, void *__restrict phase_ptr)
1258 {
1259     auto gx        = static_cast<const int32_t *__restrict>(gx_ptr);
1260     auto gy        = static_cast<const int32_t *__restrict>(gy_ptr);
1261     auto magnitude = static_cast<uint32_t *__restrict>(magnitude_ptr);
1262     auto phase     = static_cast<uint8_t *__restrict>(phase_ptr);
1263
1264     // Process low and high part
1265     for(size_t i = 0; i < 2; ++i, gx += 16, gy += 16, magnitude += 16, phase += 16)
1266     {
1267         const int32x4x2_t gx0 =
1268         {
1269             {
1270                 vld1q_s32(gx + 0),
1271                 vld1q_s32(gx + 4)
1272             }
1273         };
1274
1275         const int32x4x2_t gx1 =
1276         {
1277             {
1278                 vld1q_s32(gx + 8),
1279                 vld1q_s32(gx + 12)
1280             }
1281         };
1282
1283         const int32x4x2_t gy0 =
1284         {
1285             {
1286                 vld1q_s32(gy + 0),
1287                 vld1q_s32(gy + 4)
1288             }
1289         };
1290
1291         const int32x4x2_t gy1 =
1292         {
1293             {
1294                 vld1q_s32(gy + 8),
1295                 vld1q_s32(gy + 12)
1296             }
1297         };
1298
1299         // Compute and store phase
1300         vst1_u8(phase + 0, phase_quantization_S32_S32(gx0, gy0));
1301         vst1_u8(phase + 8, phase_quantization_S32_S32(gx1, gy1));
1302
1303         // Compute magnitude using L2 normalization
1304         const uint32x4x2_t mag0 = mag_l2_S32_S32(gx0, gy0);
1305         const uint32x4x2_t mag1 = mag_l2_S32_S32(gx1, gy1);
1306
1307         // Store magnitude
1308         vst1q_u32(magnitude + 0, mag0.val[0]);
1309         vst1q_u32(magnitude + 4, mag0.val[1]);
1310         vst1q_u32(magnitude + 8, mag1.val[0]);
1311         vst1q_u32(magnitude + 12, mag1.val[1]);
1312     }
1313 }
1314
1315 /* Computes non-maxima suppression and hysteresis when the gradient size = 3 or 5
1316  *
1317  * @param[in]  magnitude_ptr Pointer to source image. Magnitude. Data type supported U16
1318  * @param[in]  phase_ptr     Pointer to source image. Quantized phase. Data type supported U8
1319  * @param[out] output_ptr    Pointer to output image. Data type supported U8
1320  * @param[in]  stride_mag    Stride of magnitude image
1321  * @param[in]  lower_thr     Lower threshold used for the hysteresis
1322  * @param[in]  upper_thr     Upper threshold used for the hysteresis
1323  */
1324 void non_max_suppression_U16_U8_U8(const void *__restrict magnitude_ptr, const void *__restrict phase_ptr, void *__restrict output_ptr, const uint32_t stride_mag, const int32_t lower_thr,
1325                                    const int32_t upper_thr)
1326 {
1327     const auto magnitude = static_cast<const uint16_t *__restrict>(magnitude_ptr);
1328     const auto phase     = static_cast<const uint8_t *__restrict>(phase_ptr);
1329     const auto output    = static_cast<uint8_t *__restrict>(output_ptr);
1330
1331     // Get magnitude and phase of the centre pixels
1332     uint16x8_t mc = vld1q_u16(magnitude);
1333
1334     // Angle_quantized: 0 = 0°, 1 = 45°, 2 = 90°, 3 = 135°
1335     const uint16x8_t pc16 = vmovl_u8(vld1_u8(phase));
1336
1337     // 0 degree
1338     const uint16x8_t mk0_0 = vld1q_u16(magnitude - 1);
1339     const uint16x8_t mk0_1 = vld1q_u16(magnitude + 1);
1340     uint16x8_t       mask0 = vceqq_u16(pc16, vdupq_n_u16(0));
1341     mask0                  = vandq_u16(mask0, vcgeq_u16(mc, mk0_0));
1342     mask0                  = vandq_u16(mask0, vcgeq_u16(mc, mk0_1));
1343
1344     // 45 degree
1345     const uint16x8_t mk45_0 = vld1q_u16(magnitude - stride_mag - 1);
1346     const uint16x8_t mk45_1 = vld1q_u16(magnitude + stride_mag + 1);
1347     uint16x8_t       mask1  = vceqq_u16(pc16, vdupq_n_u16(1));
1348     mask1                   = vandq_u16(mask1, vcgeq_u16(mc, mk45_0));
1349     mask1                   = vandq_u16(mask1, vcgeq_u16(mc, mk45_1));
1350
1351     // 90 degree
1352     const uint16x8_t mk90_0 = vld1q_u16(magnitude - stride_mag);
1353     const uint16x8_t mk90_1 = vld1q_u16(magnitude + stride_mag);
1354     uint16x8_t       mask2  = vceqq_u16(pc16, vdupq_n_u16(2));
1355     mask2                   = vandq_u16(mask2, vcgeq_u16(mc, mk90_0));
1356     mask2                   = vandq_u16(mask2, vcgeq_u16(mc, mk90_1));
1357
1358     // 135 degree
1359     const uint16x8_t mk135_0 = vld1q_u16(magnitude - stride_mag + 1);
1360     const uint16x8_t mk135_1 = vld1q_u16(magnitude + stride_mag - 1);
1361     uint16x8_t       mask3   = vceqq_u16(pc16, vdupq_n_u16(3));
1362     mask3                    = vandq_u16(mask3, vcgeq_u16(mc, mk135_0));
1363     mask3                    = vandq_u16(mask3, vcgeq_u16(mc, mk135_1));
1364
1365     // Merge masks
1366     mask0 = vorrq_u16(mask0, mask1);
1367     mask2 = vorrq_u16(mask2, mask3);
1368     mask0 = vorrq_u16(mask0, mask2);
1369
1370     mc = vbslq_u16(mask0, mc, vdupq_n_u16(0));
1371
1372     // mc > upper_thr
1373     mask0 = vcgtq_u16(mc, vdupq_n_u16(upper_thr));
1374
1375     // mc <= lower_thr
1376     mask1 = vcleq_u16(mc, vdupq_n_u16(lower_thr));
1377
1378     // mc <= upper_thr && mc > lower_thr
1379     mask2 = vcleq_u16(mc, vdupq_n_u16(upper_thr));
1380     mask2 = vandq_u16(mask2, vcgtq_u16(mc, vdupq_n_u16(lower_thr)));
1381
1382     mc = vbslq_u16(mask0, vdupq_n_u16(EDGE), mc);
1383     mc = vbslq_u16(mask1, vdupq_n_u16(NO_EDGE), mc);
1384     mc = vbslq_u16(mask2, vdupq_n_u16(MAYBE), mc);
1385
1386     vst1_u8(output, vmovn_u16(mc));
1387 }
1388
1389 inline uint16x4_t non_max_U32_helper(const uint32_t *input, const uint16x4_t pc, const uint32_t stride_mag, const int32_t lower_thr, const int32_t upper_thr)
1390 {
1391     // Phase for 4 pixel
1392     const uint32x4_t pc32 = vmovl_u16(pc);
1393
1394     // Get magnitude for 4 pixel
1395     uint32x4_t mc = vld1q_u32(input);
1396
1397     // Angle_quantized: 0 = 0°, 1 = 45°, 2 = 90°, 3 = 135°
1398     // 0 degree
1399     const uint32x4_t mk0_0 = vld1q_u32(input - 1);
1400     const uint32x4_t mk0_1 = vld1q_u32(input + 1);
1401     uint32x4_t       mask0 = vceqq_u32(pc32, vdupq_n_u32(0));
1402     mask0                  = vandq_u32(mask0, vcgeq_u32(mc, mk0_0));
1403     mask0                  = vandq_u32(mask0, vcgeq_u32(mc, mk0_1));
1404
1405     // 45 degree
1406     const uint32x4_t mk45_0 = vld1q_u32(input - stride_mag - 1);
1407     const uint32x4_t mk45_1 = vld1q_u32(input + stride_mag + 1);
1408     uint32x4_t       mask1  = vceqq_u32(pc32, vdupq_n_u32(1));
1409     mask1                   = vandq_u32(mask1, vcgeq_u32(mc, mk45_0));
1410     mask1                   = vandq_u32(mask1, vcgeq_u32(mc, mk45_1));
1411
1412     // 90 degree
1413     const uint32x4_t mk90_0 = vld1q_u32(input - stride_mag);
1414     const uint32x4_t mk90_1 = vld1q_u32(input + stride_mag);
1415     uint32x4_t       mask2  = vceqq_u32(pc32, vdupq_n_u32(2));
1416     mask2                   = vandq_u32(mask2, vcgeq_u32(mc, mk90_0));
1417     mask2                   = vandq_u32(mask2, vcgeq_u32(mc, mk90_1));
1418
1419     // 135 degree
1420     const uint32x4_t mk135_0 = vld1q_u32(input - stride_mag + 1);
1421     const uint32x4_t mk135_1 = vld1q_u32(input + stride_mag - 1);
1422     uint32x4_t       mask3   = vceqq_u32(pc32, vdupq_n_u32(3));
1423     mask3                    = vandq_u32(mask3, vcgeq_u32(mc, mk135_0));
1424     mask3                    = vandq_u32(mask3, vcgeq_u32(mc, mk135_1));
1425
1426     // Merge masks
1427     mask0 = vorrq_u32(mask0, mask1);
1428     mask2 = vorrq_u32(mask2, mask3);
1429     mask0 = vorrq_u32(mask0, mask2);
1430
1431     mc = vbslq_u32(mask0, mc, vdupq_n_u32(0));
1432
1433     // mc > upper_thr
1434     mask0 = vcgtq_u32(mc, vdupq_n_u32(upper_thr));
1435
1436     // mc <= lower_thr
1437     mask1 = vcleq_u32(mc, vdupq_n_u32(lower_thr));
1438
1439     // mc <= upper_thr && mc > lower_thr
1440     mask2 = vcleq_u32(mc, vdupq_n_u32(upper_thr));
1441     mask2 = vandq_u32(mask2, vcgtq_u32(mc, vdupq_n_u32(lower_thr)));
1442
1443     mc = vbslq_u32(mask0, vdupq_n_u32(EDGE), mc);
1444     mc = vbslq_u32(mask1, vdupq_n_u32(NO_EDGE), mc);
1445     mc = vbslq_u32(mask2, vdupq_n_u32(MAYBE), mc);
1446
1447     return vmovn_u32(mc);
1448 }
1449
1450 /* Computes non-maxima suppression and hysteresis when the gradient_size = 7
1451  *
1452  * @param[in]  magnitude_ptr Pointer to source image. Magnitude. Data type supported U32
1453  * @param[in]  phase_ptr     Pointer to source image. Quantized phase. Data type supported U8
1454  * @param[out] output_ptr    Pointer to destination image. Data type supported U8
1455  * @param[in]  stride_mag    Stride of magnitude image
1456  * @param[in]  lower_thr     Lower threshold used for the hysteresis
1457  * @param[in]  upper_thr     Upper threshold used for the hysteresis
1458  */
1459 void non_max_suppression_U32_U8_U8(const void *__restrict magnitude_ptr, const void *__restrict phase_ptr, void *__restrict output_ptr, const uint32_t stride_mag, const int32_t lower_thr,
1460                                    const int32_t upper_thr)
1461 {
1462     const auto magnitude = static_cast<const uint32_t *__restrict>(magnitude_ptr);
1463     const auto phase     = static_cast<const uint8_t *__restrict>(phase_ptr);
1464     const auto output    = static_cast<uint8_t *__restrict>(output_ptr);
1465
1466     // Get phase for 8 pixel
1467     const uint16x8_t pc16 = vmovl_u8(vld1_u8(phase));
1468
1469     // Compute non maxima suppression
1470     const uint16x4x2_t res =
1471     {
1472         {
1473             non_max_U32_helper(magnitude, vget_low_u16(pc16), stride_mag, lower_thr, upper_thr),
1474             non_max_U32_helper(magnitude + 4, vget_high_u16(pc16), stride_mag, lower_thr, upper_thr)
1475         }
1476     };
1477
1478     // Store result
1479     vst1_u8(output, vmovn_u16(vcombine_u16(res.val[0], res.val[1])));
1480 }
1481
1482 /* Computes edge tracing when is called by edge_trace_U8_U8 recursively
1483  *
1484  * @param[in]  input         Pointer to source image. Data type supported U8
1485  * @param[out] output        Pointer to destination image. Data type supported U8
1486  * @param[in]  input_stride  Stride of the input image
1487  * @param[in]  output_stride Stride of the output image
1488  */
1489 void edge_trace_recursive_U8_U8(uint8_t *__restrict input, uint8_t *__restrict output, const int32_t input_stride, const int32_t output_stride)
1490 {
1491     // Look for MAYBE pixels in 8 directions
1492     *output = EDGE;
1493
1494     // (-1, 0)
1495     uint8_t pixel = *(input - 1);
1496
1497     if(pixel == MAYBE)
1498     {
1499         // Touched a MAYBE point. MAYBE becomes EDGE
1500         *(input - 1) = EDGE;
1501
1502         edge_trace_recursive_U8_U8(input - 1, output - 1, input_stride, output_stride);
1503     }
1504
1505     // (+1, 0)
1506     pixel = *(input + 1);
1507
1508     if(pixel == MAYBE)
1509     {
1510         // Touched a MAYBE point. MAYBE becomes EDGE
1511         *(input + 1) = EDGE;
1512
1513         edge_trace_recursive_U8_U8(input + 1, output + 1, input_stride, output_stride);
1514     }
1515
1516     input -= input_stride;
1517     output -= output_stride;
1518
1519     // (-1, -1)
1520     pixel = *(input - 1);
1521
1522     if(pixel == MAYBE)
1523     {
1524         // Touched a MAYBE point. MAYBE becomes EDGE
1525         *(input - 1) = EDGE;
1526
1527         edge_trace_recursive_U8_U8(input - 1, output - 1, input_stride, output_stride);
1528     }
1529
1530     // (0, -1)
1531     pixel = *input;
1532
1533     if(pixel == MAYBE)
1534     {
1535         // Touched a MAYBE point. MAYBE becomes EDGE
1536         *input = EDGE;
1537
1538         edge_trace_recursive_U8_U8(input, output, input_stride, output_stride);
1539     }
1540
1541     // (+1, -1)
1542     pixel = *(input + 1);
1543
1544     if(pixel == MAYBE)
1545     {
1546         // Touched a MAYBE point. MAYBE becomes EDGE
1547         *(input + 1) = EDGE;
1548
1549         edge_trace_recursive_U8_U8(input + 1, output + 1, input_stride, output_stride);
1550     }
1551
1552     input += input_stride * 2;
1553     output += output_stride * 2;
1554
1555     // (-1, +1)
1556     pixel = *(input - 1);
1557
1558     if(pixel == MAYBE)
1559     {
1560         // Touched a MAYBE point. MAYBE becomes EDGE
1561         *(input - 1) = EDGE;
1562
1563         edge_trace_recursive_U8_U8(input - 1, output - 1, input_stride, output_stride);
1564     }
1565
1566     // (0, +1)
1567     pixel = *input;
1568
1569     if(pixel == MAYBE)
1570     {
1571         // Touched a MAYBE point. MAYBE becomes EDGE
1572         *input = EDGE;
1573
1574         edge_trace_recursive_U8_U8(input, output, input_stride, output_stride);
1575     }
1576
1577     // (+1, +1)
1578     pixel = *(input + 1);
1579
1580     if(pixel == MAYBE)
1581     {
1582         // Touched a MAYBE point. MAYBE becomes EDGE
1583         *(input + 1) = EDGE;
1584
1585         edge_trace_recursive_U8_U8(input + 1, output + 1, input_stride, output_stride);
1586     }
1587 }
1588
1589 /* Computes edge tracing
1590  *
1591  * @param[in]  input         Pointer to source image. Data type supported U8
1592  * @param[out] output        Pointer to destination image. Data type supported U8
1593  * @param[in]  input_stride  Stride of the input image
1594  * @param[in]  output_stride Stride of the output image
1595  */
1596 void edge_trace_U8_U8(uint8_t *__restrict input, uint8_t *__restrict output, const int32_t input_stride, const int32_t output_stride)
1597 {
1598     if(*input == NO_EDGE)
1599     {
1600         *output = NO_EDGE;
1601     }
1602     // Check if EDGE and not yet touched
1603     else if((*input == EDGE) && (*output == NO_EDGE))
1604     {
1605         edge_trace_recursive_U8_U8(input, output, input_stride, output_stride);
1606     }
1607 }
1608 } // namespace
1609
1610 NEGradientKernel::NEGradientKernel()
1611     : _func(nullptr), _gx(nullptr), _gy(nullptr), _magnitude(nullptr), _phase(nullptr)
1612 {
1613 }
1614
1615 void NEGradientKernel::configure(const ITensor *gx, const ITensor *gy, ITensor *magnitude, ITensor *phase, int32_t norm_type)
1616 {
1617     ARM_COMPUTE_ERROR_ON_NULLPTR(gx, gy, magnitude, phase);
1618
1619     set_shape_if_empty(*magnitude->info(), gx->info()->tensor_shape());
1620     set_shape_if_empty(*phase->info(), gx->info()->tensor_shape());
1621
1622     Format magnitude_format = gx->info()->data_type() == DataType::S16 ? Format::U16 : Format::U32;
1623     set_format_if_unknown(*magnitude->info(), magnitude_format);
1624     set_format_if_unknown(*phase->info(), Format::U8);
1625
1626     ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(gx, gy, magnitude, phase);
1627     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(gx, 1, DataType::S16, DataType::S32);
1628     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(gy, 1, DataType::S16, DataType::S32);
1629     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(magnitude, 1, DataType::U16, DataType::U32);
1630     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(phase, 1, DataType::U8);
1631     ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(gx, gy);
1632     ARM_COMPUTE_ERROR_ON_MSG(element_size_from_data_type(gx->info()->data_type()) != element_size_from_data_type(magnitude->info()->data_type()), "Magnitude must have the same element size as Gx and Gy");
1633
1634     _gx        = gx;
1635     _gy        = gy;
1636     _magnitude = magnitude;
1637     _phase     = phase;
1638
1639     if(_gx->info()->data_type() == DataType::S16)
1640     {
1641         if(norm_type == 1)
1642         {
1643             _func = &mag_phase_l1norm_S16_S16_U16_U8;
1644         }
1645         else
1646         {
1647             _func = &mag_phase_l2norm_S16_S16_U16_U8;
1648         }
1649     }
1650     else
1651     {
1652         if(norm_type == 1)
1653         {
1654             _func = &mag_phase_l1norm_S32_S32_U32_U8;
1655         }
1656         else
1657         {
1658             _func = &mag_phase_l2norm_S32_S32_U32_U8;
1659         }
1660     }
1661
1662     constexpr unsigned int num_elems_processed_per_iteration = 32;
1663
1664     // Configure kernel window
1665     Window win = calculate_max_window(*_gx->info(), Steps(num_elems_processed_per_iteration));
1666
1667     AccessWindowHorizontal gx_access(_gx->info(), 0, num_elems_processed_per_iteration);
1668     AccessWindowHorizontal gy_access(_gy->info(), 0, num_elems_processed_per_iteration);
1669     AccessWindowHorizontal mag_access(_magnitude->info(), 0, num_elems_processed_per_iteration);
1670     AccessWindowHorizontal phase_access(_phase->info(), 0, num_elems_processed_per_iteration);
1671
1672     update_window_and_padding(win, gx_access, gy_access, mag_access, phase_access);
1673
1674     mag_access.set_valid_region(win, _gx->info()->valid_region());
1675     phase_access.set_valid_region(win, _gx->info()->valid_region());
1676
1677     INEKernel::configure(win);
1678 }
1679
1680 void NEGradientKernel::run(const Window &window, const ThreadInfo &info)
1681 {
1682     ARM_COMPUTE_UNUSED(info);
1683     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1684     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
1685     ARM_COMPUTE_ERROR_ON(_func == nullptr);
1686     Iterator gx(_gx, window);
1687     Iterator gy(_gy, window);
1688     Iterator magnitude(_magnitude, window);
1689     Iterator phase(_phase, window);
1690
1691     execute_window_loop(window, [&](const Coordinates & id)
1692     {
1693         (*_func)(gx.ptr(), gy.ptr(), magnitude.ptr(), phase.ptr());
1694     },
1695     gx, gy, magnitude, phase);
1696 }
1697
1698 NEEdgeNonMaxSuppressionKernel::NEEdgeNonMaxSuppressionKernel()
1699     : _func(nullptr), _magnitude(nullptr), _phase(nullptr), _output(nullptr), _lower_thr(0), _upper_thr(0)
1700 {
1701 }
1702
1703 BorderSize NEEdgeNonMaxSuppressionKernel::border_size() const
1704 {
1705     return BorderSize(1);
1706 }
1707
1708 void NEEdgeNonMaxSuppressionKernel::configure(const ITensor *magnitude, const ITensor *phase, ITensor *output,
1709                                               int32_t upper_thr, int32_t lower_thr, bool border_undefined)
1710 {
1711     ARM_COMPUTE_ERROR_ON_NULLPTR(magnitude, phase, output);
1712
1713     set_shape_if_empty(*output->info(), magnitude->info()->tensor_shape());
1714
1715     set_format_if_unknown(*phase->info(), Format::U8);
1716     set_format_if_unknown(*output->info(), Format::U8);
1717
1718     ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(magnitude, phase, output);
1719     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(magnitude, 1, DataType::U16, DataType::U32);
1720     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(phase, 1, DataType::U8);
1721     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8);
1722     ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(phase, output);
1723
1724     _magnitude = magnitude;
1725     _phase     = phase;
1726     _output    = output;
1727
1728     switch(_magnitude->info()->data_type())
1729     {
1730         case DataType::U16:
1731             _func = &non_max_suppression_U16_U8_U8;
1732             break;
1733         case DataType::U32:
1734             _func = &non_max_suppression_U32_U8_U8;
1735             break;
1736         default:
1737             ARM_COMPUTE_ERROR("Unsupported data type!");
1738     }
1739
1740     // Set thresholds
1741     _lower_thr = lower_thr;
1742     _upper_thr = upper_thr;
1743
1744     constexpr unsigned int num_elems_processed_per_iteration = 8;
1745     constexpr unsigned int num_elems_read_per_iteration      = 10;
1746     constexpr unsigned int num_rows_read_per_iteration       = 3;
1747
1748     // Configure kernel window
1749     Window win = calculate_max_window(*_magnitude->info(), Steps(num_elems_processed_per_iteration), border_undefined, border_size());
1750
1751     AccessWindowRectangle  mag_access(_magnitude->info(), -border_size().left, -border_size().top, num_elems_read_per_iteration, num_rows_read_per_iteration);
1752     AccessWindowHorizontal phase_access(_phase->info(), 0, num_elems_processed_per_iteration);
1753     AccessWindowHorizontal output_access(_output->info(), 0, num_elems_processed_per_iteration);
1754
1755     update_window_and_padding(win, mag_access, phase_access, output_access);
1756
1757     output_access.set_valid_region(win, _magnitude->info()->valid_region(), border_undefined, border_size());
1758
1759     INEKernel::configure(win);
1760 }
1761
1762 void NEEdgeNonMaxSuppressionKernel::run(const Window &window, const ThreadInfo &info)
1763 {
1764     ARM_COMPUTE_UNUSED(info);
1765     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1766     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
1767     ARM_COMPUTE_ERROR_ON(_func == nullptr);
1768     Iterator magnitude(_magnitude, window);
1769     Iterator phase(_phase, window);
1770     Iterator output(_output, window);
1771
1772     const size_t input1_stride        = _magnitude->info()->strides_in_bytes()[1];
1773     const size_t input1_stride_ushort = input1_stride / data_size_from_type(_magnitude->info()->data_type());
1774
1775     execute_window_loop(window, [&](const Coordinates & id)
1776     {
1777         (*_func)(magnitude.ptr(), phase.ptr(), output.ptr(), input1_stride_ushort, _lower_thr, _upper_thr);
1778     },
1779     magnitude, phase, output);
1780 }
1781
1782 NEEdgeTraceKernel::NEEdgeTraceKernel()
1783     : _input(nullptr), _output(nullptr)
1784 {
1785 }
1786
1787 BorderSize NEEdgeTraceKernel::border_size() const
1788 {
1789     return BorderSize(1);
1790 }
1791
1792 bool NEEdgeTraceKernel::is_parallelisable() const
1793 {
1794     return false;
1795 }
1796
1797 void NEEdgeTraceKernel::configure(ITensor *input, ITensor *output)
1798 {
1799     ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
1800
1801     set_shape_if_empty(*output->info(), input->info()->tensor_shape());
1802
1803     set_format_if_unknown(*input->info(), Format::U8);
1804     set_format_if_unknown(*output->info(), Format::U8);
1805
1806     ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input, output);
1807     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8);
1808     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8);
1809     ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
1810
1811     _input  = input;
1812     _output = output;
1813
1814     constexpr unsigned int num_elems_processed_per_iteration = 1;
1815
1816     // Configure kernel window
1817     Window win = calculate_max_window(*_input->info(), Steps(num_elems_processed_per_iteration));
1818
1819     const ValidRegion &input_valid_region  = input->info()->valid_region();
1820     const ValidRegion &output_valid_region = output->info()->valid_region();
1821
1822     // Reads can occur within the valid region of the input + border
1823     AccessWindowStatic input_access(input->info(),
1824                                     input_valid_region.anchor[0] - border_size().left,
1825                                     input_valid_region.anchor[1] - border_size().top,
1826                                     input_valid_region.anchor[0] + input_valid_region.shape[0] + border_size().right,
1827                                     input_valid_region.anchor[1] + input_valid_region.shape[1] + border_size().bottom);
1828
1829     // Writes can occur within the valid region of the output + border
1830     AccessWindowStatic output_access(output->info(),
1831                                      output_valid_region.anchor[0] - border_size().left,
1832                                      output_valid_region.anchor[1] - border_size().top,
1833                                      output_valid_region.anchor[0] + output_valid_region.shape[0] + border_size().right,
1834                                      output_valid_region.anchor[1] + output_valid_region.shape[1] + border_size().bottom);
1835
1836     update_window_and_padding(win, input_access, output_access);
1837
1838     output_access.set_valid_region(win, _input->info()->valid_region());
1839
1840     INEKernel::configure(win);
1841 }
1842
1843 void NEEdgeTraceKernel::run(const Window &window, const ThreadInfo &info)
1844 {
1845     ARM_COMPUTE_UNUSED(info);
1846     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1847     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
1848     Iterator input(_input, window);
1849     Iterator output(_output, window);
1850
1851     const size_t input_stride  = _input->info()->strides_in_bytes()[1];
1852     const size_t output_stride = _output->info()->strides_in_bytes()[1];
1853
1854     execute_window_loop(window, [&](const Coordinates & id)
1855     {
1856         edge_trace_U8_U8(input.ptr(), output.ptr(), input_stride, output_stride);
1857     },
1858     input, output);
1859 }