arm_compute v17.04
[platform/upstream/armcl.git] / src / core / NEON / kernels / NECannyEdgeKernel.cpp
1 /*
2  * Copyright (c) 2016, 2017 ARM Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "arm_compute/core/NEON/kernels/NECannyEdgeKernel.h"
25
26 #include "arm_compute/core/AccessWindowStatic.h"
27 #include "arm_compute/core/Error.h"
28 #include "arm_compute/core/Helpers.h"
29 #include "arm_compute/core/ITensor.h"
30 #include "arm_compute/core/TensorInfo.h"
31 #include "arm_compute/core/Types.h"
32 #include "arm_compute/core/Utils.h"
33 #include "arm_compute/core/Validate.h"
34
35 #include <arm_neon.h>
36 #include <cstddef>
37 #include <cstdint>
38 #include <tuple>
39
40 using namespace arm_compute;
41
42 namespace arm_compute
43 {
44 class Coordinates;
45 } // namespace arm_compute
46
47 namespace
48 {
49 constexpr int NO_EDGE = 0;
50 constexpr int EDGE    = 255;
51 constexpr int MAYBE   = 127;
52 } // namespace
53
54 #ifdef ARM_COMPUTE_ENABLE_FP16
55 namespace fp16
56 {
57 inline uint8x8_t phase_quantization(const float32x4x2_t &gx, const float32x4x2_t &gy)
58 {
59     // Constant use for evaluating score1 and score3
60     static const float32x4_t const45 = vdupq_n_f32(0.70710678118655f);
61     static const float32x4_t zero    = vdupq_n_f32(0.0f);
62     static const float32x4_t one     = vdupq_n_f32(1.0f);
63     static const float32x4_t two     = vdupq_n_f32(2.0f);
64     static const float32x4_t three   = vdupq_n_f32(3.0f);
65
66     // Score0: (1, 0)
67     const float32x4x2_t score0 =
68     {
69         vabsq_f32(gx.val[0]),
70         vabsq_f32(gx.val[1])
71     };
72
73     // Score2: ( 0, 1 )
74     const float32x4x2_t score2 =
75     {
76         vabsq_f32(gy.val[0]),
77         vabsq_f32(gy.val[1])
78     };
79
80     // Score1 and Score3: ( sqrt(2) / 2, sqrt(2) / 2 ) - ( -sqrt(2) / 2, sqrt(2) / 2 )
81     float32x4x2_t score1 =
82     {
83         vmulq_f32(gy.val[0], const45),
84         vmulq_f32(gy.val[1], const45)
85     };
86
87     float32x4x2_t score3 = score1;
88
89     score1.val[0] = vmlaq_f32(score1.val[0], gx.val[0], const45);
90     score1.val[1] = vmlaq_f32(score1.val[1], gx.val[1], const45);
91     score3.val[0] = vmlsq_f32(score3.val[0], gx.val[0], const45);
92     score3.val[1] = vmlsq_f32(score3.val[1], gx.val[1], const45);
93
94     score1.val[0] = vabsq_f32(score1.val[0]);
95     score1.val[1] = vabsq_f32(score1.val[1]);
96     score3.val[0] = vabsq_f32(score3.val[0]);
97     score3.val[1] = vabsq_f32(score3.val[1]);
98
99     float32x4x2_t phase =
100     {
101         zero,
102         zero
103     };
104
105     float32x4x2_t old_score = score0;
106
107     // score1 > old_score?
108     uint32x4x2_t mask =
109     {
110         vcgtq_f32(score1.val[0], old_score.val[0]),
111         vcgtq_f32(score1.val[1], old_score.val[1])
112     };
113
114     phase.val[0]     = vbslq_f32(mask.val[0], one, phase.val[0]);
115     phase.val[1]     = vbslq_f32(mask.val[1], one, phase.val[1]);
116     old_score.val[0] = vbslq_f32(mask.val[0], score1.val[0], old_score.val[0]);
117     old_score.val[1] = vbslq_f32(mask.val[1], score1.val[1], old_score.val[1]);
118
119     // score2 > old_score?
120     mask.val[0] = vcgtq_f32(score2.val[0], old_score.val[0]);
121     mask.val[1] = vcgtq_f32(score2.val[1], old_score.val[1]);
122
123     phase.val[0]     = vbslq_f32(mask.val[0], two, phase.val[0]);
124     phase.val[1]     = vbslq_f32(mask.val[1], two, phase.val[1]);
125     old_score.val[0] = vbslq_f32(mask.val[0], score2.val[0], old_score.val[0]);
126     old_score.val[1] = vbslq_f32(mask.val[1], score2.val[1], old_score.val[1]);
127
128     // score3 > old_score?
129     mask.val[0] = vcgtq_f32(score3.val[0], old_score.val[0]);
130     mask.val[1] = vcgtq_f32(score3.val[1], old_score.val[1]);
131
132     phase.val[0]     = vbslq_f32(mask.val[0], three, phase.val[0]);
133     phase.val[1]     = vbslq_f32(mask.val[1], three, phase.val[1]);
134     old_score.val[0] = vbslq_f32(mask.val[0], score3.val[0], old_score.val[0]);
135     old_score.val[1] = vbslq_f32(mask.val[1], score3.val[1], old_score.val[1]);
136
137     // Convert from float32x4_t to uint8x8_t
138     return vmovn_u16(vcombine_u16(vmovn_u32(vcvtq_u32_f32(phase.val[0])),
139                                   vmovn_u32(vcvtq_u32_f32(phase.val[1]))));
140 }
141
142 inline uint8x8_t phase_quantization(float16x8_t gx, float16x8_t gy)
143 {
144     // Constant use for evaluating score1 and score3
145     static const float16x8_t const45 = vdupq_n_f16(0.70710678118655f);
146     static const float16x8_t zero    = vdupq_n_f16(0.0f);
147     static const float16x8_t one     = vdupq_n_f16(1.0f);
148     static const float16x8_t two     = vdupq_n_f16(2.0f);
149     static const float16x8_t three   = vdupq_n_f16(3.0f);
150
151     // Score0: (1, 0)
152     const float16x8_t score0 = vabsq_f16(gx);
153
154     // Score2: ( 0, 1 )
155     const float16x8_t score2 = vabsq_f16(gy);
156
157     // Score1 and Score3: ( sqrt(2) / 2, sqrt(2) / 2 ) - ( -sqrt(2) / 2, sqrt(2) / 2 )
158     float16x8_t score1 = vmulq_f16(gy, const45);
159     float16x8_t score3 = score1;
160
161     score1 = vfmaq_f16(score1, gx, const45);
162     score3 = vfmsq_f16(score3, gx, const45);
163
164     score1 = vabsq_f16(score1);
165     score3 = vabsq_f16(score3);
166
167     float16x8_t phase     = zero;
168     float16x8_t old_score = score0;
169
170     // score1 > old_score?
171     uint16x8_t mask = vcgtq_f16(score1, old_score);
172
173     phase     = vbslq_f16(mask, one, phase);
174     old_score = vbslq_f16(mask, score1, old_score);
175
176     // score2 > old_score?
177     mask = vcgtq_f16(score2, old_score);
178
179     phase     = vbslq_f16(mask, two, phase);
180     old_score = vbslq_f16(mask, score2, old_score);
181
182     // score3 > old_score?
183     mask = vcgtq_f16(score3, old_score);
184
185     phase = vbslq_f16(mask, three, phase);
186
187     // Convert from float16x8_t to uint8x8_t
188     return vmovn_u16(vcvtq_u16_f16(phase));
189 }
190
191 /** Computes the gradient phase if gradient_size = 3 or 5. The output is quantized.
192  *         0 = 0°, 1 = 45°, 2 = 90°, 3 = 135°
193  *
194  * @param[in] gx Gx component
195  * @param[in] gy Gy component
196  *
197  * @return quantized phase for 8 pixels
198  */
199 inline uint8x8_t phase_quantization_S16_S16(int16x8_t gx, int16x8_t gy)
200 {
201     return phase_quantization(vcvtq_f16_s16(gx), vcvtq_f16_s16(gy));
202 }
203
204 /** Computes the gradient phase if gradient_size = 7. The output is quantized.
205  *         0 = 0°, 1 = 45°, 2 = 90°, 3 = 135°
206  *
207  * @param[in] gx Gx component
208  * @param[in] gy Gy component
209  *
210  * @return quantized phase for 8 pixels
211  */
212 inline uint8x8_t phase_quantization_S32_S32(const int32x4x2_t &gx, const int32x4x2_t &gy)
213 {
214     // Convert to float
215     const float32x4x2_t gx_f32 =
216     {
217         vcvtq_f32_s32(gx.val[0]),
218         vcvtq_f32_s32(gx.val[1])
219     };
220
221     const float32x4x2_t gy_f32 =
222     {
223         vcvtq_f32_s32(gy.val[0]),
224         vcvtq_f32_s32(gy.val[1])
225     };
226
227     return phase_quantization(gx_f32, gy_f32);
228 }
229
230 /** Computes the magnitude using the L1-norm type if gradient_size = 3 or 5
231  *
232  * @param[in] gx Gx component
233  * @param[in] gy Gy component
234  *
235  * @return magnitude for 8 pixels
236  */
237 inline uint16x8_t mag_l1_S16_S16(int16x8_t gx, int16x8_t gy)
238 {
239     return vaddq_u16(vreinterpretq_u16_s16(vabsq_s16(gx)),
240                      vreinterpretq_u16_s16(vabsq_s16(gy)));
241 }
242
243 /** Computes the magnitude using the L1-norm type if gradient_size = 7
244  *
245  * @param[in] gx Gx component
246  * @param[in] gy Gy component
247  *
248  * @return magnitude for 8 pixels
249  */
250 inline uint32x4x2_t mag_l1_S32_S32(const int32x4x2_t &gx, const int32x4x2_t &gy)
251 {
252     const uint32x4x2_t gx_abs =
253     {
254         vreinterpretq_u32_s32(vabsq_s32(gx.val[0])),
255         vreinterpretq_u32_s32(vabsq_s32(gx.val[1]))
256     };
257
258     const uint32x4x2_t gy_abs =
259     {
260         vreinterpretq_u32_s32(vabsq_s32(gy.val[0])),
261         vreinterpretq_u32_s32(vabsq_s32(gy.val[1]))
262     };
263
264     const uint32x4x2_t out =
265     {
266         vaddq_u32(gx_abs.val[0], gy_abs.val[0]),
267         vaddq_u32(gx_abs.val[1], gy_abs.val[1])
268     };
269
270     return out;
271 }
272
273 inline float32x4x2_t mag_l2(const float32x4x2_t &gx, const float32x4x2_t &gy)
274 {
275     // x^2 ...
276     float32x4x2_t mag =
277     {
278         vmulq_f32(gx.val[0], gx.val[0]),
279         vmulq_f32(gx.val[1], gx.val[1])
280     };
281
282     // ... + y^2
283     mag.val[0] = vmlaq_f32(mag.val[0], gy.val[0], gy.val[0]);
284     mag.val[1] = vmlaq_f32(mag.val[1], gy.val[1], gy.val[1]);
285
286     // sqrt(...)
287     mag.val[0] = vmulq_f32(vrsqrteq_f32(mag.val[0]), mag.val[0]);
288     mag.val[1] = vmulq_f32(vrsqrteq_f32(mag.val[1]), mag.val[1]);
289
290     return mag;
291 }
292
293 inline float16x8_t mag_l2(float16x8_t gx, float16x8_t gy)
294 {
295     // x^2 ...
296     float16x8_t mag = vmulq_f16(gx, gx);
297
298     // ... + y^2
299     mag = vfmaq_f16(mag, gy, gy);
300
301     // sqrt(...)
302     mag = vmulq_f16(vrsqrteq_f16(mag), mag);
303
304     return mag;
305 }
306
307 /** Computes the magnitude using L2-norm if gradient_size = 3 or 5
308  *
309  * @param[in] gx Gx component
310  * @param[in] gy Gy component
311  *
312  * @return magnitude for 8 pixels
313  */
314 inline uint16x8_t mag_l2_S16_S16(int16x8_t gx, int16x8_t gy)
315 {
316     /* Compute magnitude using L2 normalization */
317     const float16x8_t gx2 = vcvtq_f16_s16(gx);
318     const float16x8_t gy2 = vcvtq_f16_s16(gy);
319     const float16x8_t mag = mag_l2(gx2, gy2);
320
321     /* Store magnitude - Convert to uint16x8 */
322     return vcvtq_u16_f16(mag);
323 }
324
325 /** Computes the magnitude using L2-norm if gradient_size = 7
326  *
327  * @param[in] gx Gx component
328  * @param[in] gy Gy component
329  *
330  * @return magnitude for 8 pixels
331  */
332 inline uint32x4x2_t mag_l2_S32_S32(const int32x4x2_t &gx, const int32x4x2_t &gy)
333 {
334     // Compute magnitude using L2 normalization
335     float32x4x2_t gx2 =
336     {
337         vcvtq_f32_s32(gx.val[0]),
338         vcvtq_f32_s32(gx.val[1])
339     };
340
341     float32x4x2_t gy2 =
342     {
343         vcvtq_f32_s32(gy.val[0]),
344         vcvtq_f32_s32(gy.val[1])
345     };
346
347     const float32x4x2_t mag = mag_l2(gx2, gy2);
348     const uint32x4x2_t  mag32 =
349     {
350         vcvtq_u32_f32(mag.val[0]),
351         vcvtq_u32_f32(mag.val[1])
352     };
353
354     return mag32;
355 }
356
357 /** Gradient function used when the gradient size = 3 or 5 and when the norm_type = L1-norm
358  *
359  * @param[in]  in1_ptr  Pointer to source image. Gx image. Data type supported S16
360  * @param[in]  in2_ptr  Pointer to source image. Gy image. Data type supported S16
361  * @param[out] out1_ptr Pointer to destination image. Magnitude. Data type supported U16
362  * @param[out] out2_ptr Pointer to destination image. Quantized phase. Data type supported U8
363  */
364 void mag_phase_l1norm_S16_S16_U16_U8(const void *__restrict in1_ptr, const void *__restrict in2_ptr, void *__restrict out1_ptr, void *__restrict out2_ptr)
365 {
366     const auto in1  = static_cast<const int16_t *__restrict>(in1_ptr);
367     const auto in2  = static_cast<const int16_t *__restrict>(in2_ptr);
368     const auto out1 = static_cast<uint16_t *__restrict>(out1_ptr);
369     const auto out2 = static_cast<uint8_t *__restrict>(out2_ptr);
370
371     const int16x8x4_t gx =
372     {
373         vld1q_s16(in1),
374         vld1q_s16(in1 + 8),
375         vld1q_s16(in1 + 16),
376         vld1q_s16(in1 + 24)
377     };
378
379     const int16x8x4_t gy =
380     {
381         vld1q_s16(in2),
382         vld1q_s16(in2 + 8),
383         vld1q_s16(in2 + 16),
384         vld1q_s16(in2 + 24)
385     };
386
387     // Compute and store phase
388     vst1_u8(out2 + 0, phase_quantization_S16_S16(gx.val[0], gy.val[0]));
389     vst1_u8(out2 + 8, phase_quantization_S16_S16(gx.val[1], gy.val[1]));
390     vst1_u8(out2 + 16, phase_quantization_S16_S16(gx.val[2], gy.val[2]));
391     vst1_u8(out2 + 24, phase_quantization_S16_S16(gx.val[3], gy.val[3]));
392
393     // Compute ans store magnitude using L1 normalization
394     vst1q_u16(out1 + 0, mag_l1_S16_S16(gx.val[0], gy.val[0]));
395     vst1q_u16(out1 + 8, mag_l1_S16_S16(gx.val[1], gy.val[1]));
396     vst1q_u16(out1 + 16, mag_l1_S16_S16(gx.val[2], gy.val[2]));
397     vst1q_u16(out1 + 24, mag_l1_S16_S16(gx.val[3], gy.val[3]));
398 }
399
400 /** Gradient function used when the gradient size = 3 or 5 and when the norm_type = L2-norm
401  *
402  * @param[in]  in1_ptr  Pointer to source image. Gx image. Data type supported S16
403  * @param[in]  in2_ptr  Pointer to source image. Gy image. Data type supported S16
404  * @param[out] out1_ptr Pointer to destination image. Magnitude. Data type supported U16
405  * @param[out] out2_ptr Pointer to destination image. Quantized phase. Data type supported U8
406  */
407 void mag_phase_l2norm_S16_S16_U16_U8(const void *__restrict in1_ptr, const void *__restrict in2_ptr, void *__restrict out1_ptr, void *__restrict out2_ptr)
408 {
409     const auto in1  = static_cast<const int16_t *__restrict>(in1_ptr);
410     const auto in2  = static_cast<const int16_t *__restrict>(in2_ptr);
411     const auto out1 = static_cast<uint16_t *__restrict>(out1_ptr);
412     const auto out2 = static_cast<uint8_t *__restrict>(out2_ptr);
413
414     const int16x8x4_t gx =
415     {
416         vld1q_s16(in1),
417         vld1q_s16(in1 + 8),
418         vld1q_s16(in1 + 16),
419         vld1q_s16(in1 + 24)
420     };
421
422     const int16x8x4_t gy =
423     {
424         vld1q_s16(in2),
425         vld1q_s16(in2 + 8),
426         vld1q_s16(in2 + 16),
427         vld1q_s16(in2 + 24)
428     };
429
430     // Compute and store phase
431     vst1_u8(out2 + 0, phase_quantization_S16_S16(gx.val[0], gy.val[0]));
432     vst1_u8(out2 + 8, phase_quantization_S16_S16(gx.val[1], gy.val[1]));
433     vst1_u8(out2 + 16, phase_quantization_S16_S16(gx.val[2], gy.val[2]));
434     vst1_u8(out2 + 24, phase_quantization_S16_S16(gx.val[3], gy.val[3]));
435
436     // Compute and store magnitude using L2 normalization
437     vst1q_u16(out1 + 0, mag_l2_S16_S16(gx.val[0], gy.val[0]));
438     vst1q_u16(out1 + 8, mag_l2_S16_S16(gx.val[1], gy.val[1]));
439     vst1q_u16(out1 + 16, mag_l2_S16_S16(gx.val[2], gy.val[2]));
440     vst1q_u16(out1 + 24, mag_l2_S16_S16(gx.val[3], gy.val[3]));
441 }
442
443 /** Gradient function used when the gradient size = 7 and when the norm_type = L1-norm
444  *
445  * @param[in]  in1_ptr  Pointer to source image. Gx image. Data type supported S32
446  * @param[in]  in2_ptr  Pointer to source image. Gy image. Data type supported S32
447  * @param[out] out1_ptr Pointer to destination image. Magnitude. Data type supported U32
448  * @param[out] out2_ptr Pointer to destination image. Quantized phase. Data type supported U8
449  */
450 void mag_phase_l1norm_S32_S32_U32_U8(const void *__restrict in1_ptr, const void *__restrict in2_ptr, void *__restrict out1_ptr, void *__restrict out2_ptr)
451 {
452     auto in1  = static_cast<const int32_t *__restrict>(in1_ptr);
453     auto in2  = static_cast<const int32_t *__restrict>(in2_ptr);
454     auto out1 = static_cast<uint32_t *__restrict>(out1_ptr);
455     auto out2 = static_cast<uint8_t *__restrict>(out2_ptr);
456
457     // Process low and high part
458     for(size_t i = 0; i < 2; ++i, in1 += 16, in2 += 16, out1 += 16, out2 += 16)
459     {
460         const int32x4x2_t gx0 =
461         {
462             vld1q_s32(in1 + 0),
463             vld1q_s32(in1 + 4)
464         };
465
466         const int32x4x2_t gx1 =
467         {
468             vld1q_s32(in1 + 8),
469             vld1q_s32(in1 + 12)
470         };
471
472         const int32x4x2_t gy0 =
473         {
474             vld1q_s32(in2 + 0),
475             vld1q_s32(in2 + 4)
476         };
477
478         const int32x4x2_t gy1 =
479         {
480             vld1q_s32(in2 + 8),
481             vld1q_s32(in2 + 12)
482         };
483
484         // Compute and store phase
485         vst1_u8(out2 + 0, phase_quantization_S32_S32(gx0, gy0));
486         vst1_u8(out2 + 8, phase_quantization_S32_S32(gx1, gy1));
487
488         // Compute magnitude using L1 normalization
489         const uint32x4x2_t mag0 = mag_l1_S32_S32(gx0, gy0);
490         const uint32x4x2_t mag1 = mag_l1_S32_S32(gx1, gy1);
491
492         // Store magnitude
493         vst1q_u32(out1 + 0, mag0.val[0]);
494         vst1q_u32(out1 + 4, mag0.val[1]);
495         vst1q_u32(out1 + 8, mag1.val[0]);
496         vst1q_u32(out1 + 12, mag1.val[1]);
497     }
498 }
499
500 /** Gradient function used when the gradient size = 7 and when the norm_type = L2-norm
501  *
502  * @param[in]  in1_ptr  Pointer to source image. Gx image. Data type supported S32
503  * @param[in]  in2_ptr  Pointer to source image. Gy image. Data type supported S32
504  * @param[out] out1_ptr Pointer to destination image. Magnitude. Data type supported U32
505  * @param[out] out2_ptr Pointer to destination image. Quantized phase. Data type supported U8
506  */
507 void mag_phase_l2norm_S32_S32_U32_U8(const void *__restrict in1_ptr, const void *__restrict in2_ptr, void *__restrict out1_ptr, void *__restrict out2_ptr)
508 {
509     auto in1  = static_cast<const int32_t *__restrict>(in1_ptr);
510     auto in2  = static_cast<const int32_t *__restrict>(in2_ptr);
511     auto out1 = static_cast<uint32_t *__restrict>(out1_ptr);
512     auto out2 = static_cast<uint8_t *__restrict>(out2_ptr);
513
514     // Process low and high part
515     for(size_t i = 0; i < 2; ++i, in1 += 16, in2 += 16, out1 += 16, out2 += 16)
516     {
517         const int32x4x2_t gx0 =
518         {
519             vld1q_s32(in1 + 0),
520             vld1q_s32(in1 + 4)
521         };
522
523         const int32x4x2_t gx1 =
524         {
525             vld1q_s32(in1 + 8),
526             vld1q_s32(in1 + 12)
527         };
528
529         const int32x4x2_t gy0 =
530         {
531             vld1q_s32(in2 + 0),
532             vld1q_s32(in2 + 4)
533         };
534
535         const int32x4x2_t gy1 =
536         {
537             vld1q_s32(in2 + 8),
538             vld1q_s32(in2 + 12)
539         };
540
541         // Compute and store phase
542         vst1_u8(out2 + 0, phase_quantization_S32_S32(gx0, gy0));
543         vst1_u8(out2 + 8, phase_quantization_S32_S32(gx1, gy1));
544
545         // Compute magnitude using L2 normalization
546         const uint32x4x2_t mag0 = mag_l2_S32_S32(gx0, gy0);
547         const uint32x4x2_t mag1 = mag_l2_S32_S32(gx1, gy1);
548
549         // Store magnitude
550         vst1q_u32(out1 + 0, mag0.val[0]);
551         vst1q_u32(out1 + 4, mag0.val[1]);
552         vst1q_u32(out1 + 8, mag1.val[0]);
553         vst1q_u32(out1 + 12, mag1.val[1]);
554     }
555 }
556
557 inline uint16x4_t non_max_U32_helper(const uint32_t *in, const uint16x4_t pc, const uint32_t stride_mag, const int32_t lower_thr, const int32_t upper_thr)
558 {
559     // Phase for 4 pixel
560     const uint32x4_t pc32 = vmovl_u16(pc);
561
562     // Get magnitude for 4 pixel
563     uint32x4_t mc = vld1q_u32(in);
564
565     // Angle_quantized: 0 = 0°, 1 = 45°, 2 = 90°, 3 = 135°
566     // 0 degree
567     const uint32x4_t mk0_0 = vld1q_u32(in - 1);
568     const uint32x4_t mk0_1 = vld1q_u32(in + 1);
569     uint32x4_t       mask0 = vceqq_u32(pc32, vdupq_n_u32(0));
570     mask0                  = vandq_u32(mask0, vcgeq_u32(mc, mk0_0));
571     mask0                  = vandq_u32(mask0, vcgeq_u32(mc, mk0_1));
572
573     // 45 degree
574     const uint32x4_t mk45_0 = vld1q_u32(in - stride_mag - 1);
575     const uint32x4_t mk45_1 = vld1q_u32(in + stride_mag + 1);
576     uint32x4_t       mask1  = vceqq_u32(pc32, vdupq_n_u32(1));
577     mask1                   = vandq_u32(mask1, vcgeq_u32(mc, mk45_0));
578     mask1                   = vandq_u32(mask1, vcgeq_u32(mc, mk45_1));
579
580     // 90 degree
581     const uint32x4_t mk90_0 = vld1q_u32(in - stride_mag);
582     const uint32x4_t mk90_1 = vld1q_u32(in + stride_mag);
583     uint32x4_t       mask2  = vceqq_u32(pc32, vdupq_n_u32(2));
584     mask2                   = vandq_u32(mask2, vcgeq_u32(mc, mk90_0));
585     mask2                   = vandq_u32(mask2, vcgeq_u32(mc, mk90_1));
586
587     // 135 degree
588     const uint32x4_t mk135_0 = vld1q_u32(in - stride_mag + 1);
589     const uint32x4_t mk135_1 = vld1q_u32(in + stride_mag - 1);
590     uint32x4_t       mask3   = vceqq_u32(pc32, vdupq_n_u32(3));
591     mask3                    = vandq_u32(mask3, vcgeq_u32(mc, mk135_0));
592     mask3                    = vandq_u32(mask3, vcgeq_u32(mc, mk135_1));
593
594     // Merge masks
595     mask0 = vorrq_u32(mask0, mask1);
596     mask2 = vorrq_u32(mask2, mask3);
597     mask0 = vorrq_u32(mask0, mask2);
598
599     mc = vbslq_u32(mask0, mc, vdupq_n_u32(0));
600
601     // mc > upper_thr
602     mask0 = vcgtq_u32(mc, vdupq_n_u32(upper_thr));
603
604     // mc <= lower_thr
605     mask1 = vcleq_u32(mc, vdupq_n_u32(lower_thr));
606
607     // mc <= upper_thr && mc > lower_thr
608     mask2 = vcleq_u32(mc, vdupq_n_u32(upper_thr));
609     mask2 = vandq_u32(mask2, vcgtq_u32(mc, vdupq_n_u32(lower_thr)));
610
611     mc = vbslq_u32(mask0, vdupq_n_u32(EDGE), mc);
612     mc = vbslq_u32(mask1, vdupq_n_u32(NO_EDGE), mc);
613     mc = vbslq_u32(mask2, vdupq_n_u32(MAYBE), mc);
614
615     return vmovn_u32(mc);
616 }
617
618 /** Computes edge tracing when is called by edge_trace_U8_U8 recursively
619  *
620  * @param[in]  in         Pointer to source image. Data type supported U8
621  * @param[out] out        Pointer to destination image. Data type supported U8
622  * @param[in]  in_stride  Stride of the input image
623  * @param[in]  out_stride Stride of the output image
624  */
625 void edge_trace_recursive_U8_U8(uint8_t *__restrict in, uint8_t *__restrict out, const int32_t in_stride, const int32_t out_stride)
626 {
627     // Look for MAYBE pixels in 8 directions
628     *out = EDGE;
629
630     // (-1, 0)
631     uint8_t pixel = *(in - 1);
632
633     if(pixel == MAYBE)
634     {
635         // Touched a MAYBE point. MAYBE becomes EDGE
636         *(in - 1) = EDGE;
637
638         edge_trace_recursive_U8_U8(in - 1, out - 1, in_stride, out_stride);
639     }
640
641     // (+1, 0)
642     pixel = *(in + 1);
643
644     if(pixel == MAYBE)
645     {
646         // Touched a MAYBE point. MAYBE becomes EDGE
647         *(in + 1) = EDGE;
648
649         edge_trace_recursive_U8_U8(in + 1, out + 1, in_stride, out_stride);
650     }
651
652     in -= in_stride;
653     out -= out_stride;
654
655     // (-1, -1)
656     pixel = *(in - 1);
657
658     if(pixel == MAYBE)
659     {
660         // Touched a MAYBE point. MAYBE becomes EDGE
661         *(in - 1) = EDGE;
662
663         edge_trace_recursive_U8_U8(in - 1, out - 1, in_stride, out_stride);
664     }
665
666     // (0, -1)
667     pixel = *in;
668
669     if(pixel == MAYBE)
670     {
671         // Touched a MAYBE point. MAYBE becomes EDGE
672         *in = EDGE;
673
674         edge_trace_recursive_U8_U8(in, out, in_stride, out_stride);
675     }
676
677     // (+1, -1)
678     pixel = *(in + 1);
679
680     if(pixel == MAYBE)
681     {
682         // Touched a MAYBE point. MAYBE becomes EDGE
683         *(in + 1) = EDGE;
684
685         edge_trace_recursive_U8_U8(in + 1, out + 1, in_stride, out_stride);
686     }
687
688     in += in_stride * 2;
689     out += out_stride * 2;
690
691     // (-1, +1)
692     pixel = *(in - 1);
693
694     if(pixel == MAYBE)
695     {
696         // Touched a MAYBE point. MAYBE becomes EDGE
697         *(in - 1) = EDGE;
698
699         edge_trace_recursive_U8_U8(in - 1, out - 1, in_stride, out_stride);
700     }
701
702     // (0, +1)
703     pixel = *in;
704
705     if(pixel == MAYBE)
706     {
707         // Touched a MAYBE point. MAYBE becomes EDGE
708         *in = EDGE;
709
710         edge_trace_recursive_U8_U8(in, out, in_stride, out_stride);
711     }
712
713     // (+1, +1)
714     pixel = *(in + 1);
715
716     if(pixel == MAYBE)
717     {
718         // Touched a MAYBE point. MAYBE becomes EDGE
719         *(in + 1) = EDGE;
720
721         edge_trace_recursive_U8_U8(in + 1, out + 1, in_stride, out_stride);
722     }
723 }
724 } // namespace fp16
725
726 void NEGradientFP16Kernel::configure(const ITensor *gx, const ITensor *gy, ITensor *magnitude, ITensor *phase, int32_t norm_type)
727 {
728     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(gx, 1, DataType::S16, DataType::S32);
729     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(gy, 1, DataType::S16, DataType::S32);
730     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(magnitude, 1, DataType::U16, DataType::U32);
731     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(phase, 1, DataType::U8);
732     ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(gx, gy);
733     ARM_COMPUTE_ERROR_ON_MSG(element_size_from_data_type(gx->info()->data_type()) != element_size_from_data_type(magnitude->info()->data_type()), "Magnitude must have the same element size as Gx and Gy");
734
735     _gx        = gx;
736     _gy        = gy;
737     _magnitude = magnitude;
738     _phase     = phase;
739
740     if(_gx->info()->data_type() == DataType::S16)
741     {
742         if(norm_type == 1)
743         {
744             _func = &fp16::mag_phase_l1norm_S16_S16_U16_U8;
745         }
746         else
747         {
748             _func = &fp16::mag_phase_l2norm_S16_S16_U16_U8;
749         }
750     }
751     else
752     {
753         if(norm_type == 1)
754         {
755             _func = &fp16::mag_phase_l1norm_S32_S32_U32_U8;
756         }
757         else
758         {
759             _func = &fp16::mag_phase_l2norm_S32_S32_U32_U8;
760         }
761     }
762
763     constexpr unsigned int num_elems_processed_per_iteration = 32;
764
765     // Configure kernel window
766     Window win = calculate_max_window(*_gx->info(), Steps(num_elems_processed_per_iteration));
767
768     AccessWindowHorizontal gx_access(_gx->info(), 0, num_elems_processed_per_iteration);
769     AccessWindowHorizontal gy_access(_gy->info(), 0, num_elems_processed_per_iteration);
770     AccessWindowHorizontal mag_access(_magnitude->info(), 0, num_elems_processed_per_iteration);
771     AccessWindowHorizontal phase_access(_phase->info(), 0, num_elems_processed_per_iteration);
772
773     update_window_and_padding(win, gx_access, gy_access, mag_access, phase_access);
774
775     mag_access.set_valid_region(win, _gx->info()->valid_region());
776     phase_access.set_valid_region(win, _gx->info()->valid_region());
777
778     INEKernel::configure(win);
779 }
780 #endif
781
782 namespace
783 {
784 inline uint8x8_t phase_quantization(const float32x4x2_t &gx, const float32x4x2_t &gy)
785 {
786     // Constant use for evaluating score1 and score3
787     static const float32x4_t const45 = vdupq_n_f32(0.70710678118655f);
788     static const float32x4_t zero    = vdupq_n_f32(0.0f);
789     static const float32x4_t one     = vdupq_n_f32(1.0f);
790     static const float32x4_t two     = vdupq_n_f32(2.0f);
791     static const float32x4_t three   = vdupq_n_f32(3.0f);
792
793     // Score0: (1, 0)
794     const float32x4x2_t score0 =
795     {
796         {
797             vabsq_f32(gx.val[0]),
798             vabsq_f32(gx.val[1])
799         }
800     };
801
802     // Score2: ( 0, 1 )
803     const float32x4x2_t score2 =
804     {
805         {
806             vabsq_f32(gy.val[0]),
807             vabsq_f32(gy.val[1])
808         }
809     };
810
811     // Score1 and Score3: ( sqrt(2) / 2, sqrt(2) / 2 ) - ( -sqrt(2) / 2, sqrt(2) / 2 )
812     float32x4x2_t score1 =
813     {
814         {
815             vmulq_f32(gy.val[0], const45),
816             vmulq_f32(gy.val[1], const45)
817         }
818     };
819
820     float32x4x2_t score3 = score1;
821
822     score1.val[0] = vmlaq_f32(score1.val[0], gx.val[0], const45);
823     score1.val[1] = vmlaq_f32(score1.val[1], gx.val[1], const45);
824     score3.val[0] = vmlsq_f32(score3.val[0], gx.val[0], const45);
825     score3.val[1] = vmlsq_f32(score3.val[1], gx.val[1], const45);
826
827     score1.val[0] = vabsq_f32(score1.val[0]);
828     score1.val[1] = vabsq_f32(score1.val[1]);
829     score3.val[0] = vabsq_f32(score3.val[0]);
830     score3.val[1] = vabsq_f32(score3.val[1]);
831
832     float32x4x2_t phase =
833     {
834         {
835             zero,
836             zero
837         }
838     };
839
840     float32x4x2_t old_score = score0;
841
842     // score1 > old_score?
843     uint32x4x2_t mask =
844     {
845         {
846             vcgtq_f32(score1.val[0], old_score.val[0]),
847             vcgtq_f32(score1.val[1], old_score.val[1])
848         }
849     };
850
851     phase.val[0]     = vbslq_f32(mask.val[0], one, phase.val[0]);
852     phase.val[1]     = vbslq_f32(mask.val[1], one, phase.val[1]);
853     old_score.val[0] = vbslq_f32(mask.val[0], score1.val[0], old_score.val[0]);
854     old_score.val[1] = vbslq_f32(mask.val[1], score1.val[1], old_score.val[1]);
855
856     // score2 > old_score?
857     mask.val[0] = vcgtq_f32(score2.val[0], old_score.val[0]);
858     mask.val[1] = vcgtq_f32(score2.val[1], old_score.val[1]);
859
860     phase.val[0]     = vbslq_f32(mask.val[0], two, phase.val[0]);
861     phase.val[1]     = vbslq_f32(mask.val[1], two, phase.val[1]);
862     old_score.val[0] = vbslq_f32(mask.val[0], score2.val[0], old_score.val[0]);
863     old_score.val[1] = vbslq_f32(mask.val[1], score2.val[1], old_score.val[1]);
864
865     // score3 > old_score?
866     mask.val[0] = vcgtq_f32(score3.val[0], old_score.val[0]);
867     mask.val[1] = vcgtq_f32(score3.val[1], old_score.val[1]);
868
869     phase.val[0]     = vbslq_f32(mask.val[0], three, phase.val[0]);
870     phase.val[1]     = vbslq_f32(mask.val[1], three, phase.val[1]);
871     old_score.val[0] = vbslq_f32(mask.val[0], score3.val[0], old_score.val[0]);
872     old_score.val[1] = vbslq_f32(mask.val[1], score3.val[1], old_score.val[1]);
873
874     // Convert from float32x4_t to uint8x8_t
875     return vmovn_u16(vcombine_u16(vmovn_u32(vcvtq_u32_f32(phase.val[0])),
876                                   vmovn_u32(vcvtq_u32_f32(phase.val[1]))));
877 }
878
879 /* Computes the gradient phase if gradient_size = 3 or 5. The output is quantized.
880  * 0 = 0°, 1 = 45°, 2 = 90°, 3 = 135°
881  *
882  * @param[in] gx Gx component
883  * @param[in] gy Gy component
884  *
885  * @return quantized phase for 8 pixels
886  */
887 inline uint8x8_t phase_quantization_S16_S16(int16x8_t gx, int16x8_t gy)
888 {
889     // Convert to float
890     const float32x4x2_t gx_f32 =
891     {
892         {
893             vcvtq_f32_s32(vmovl_s16(vget_low_s16(gx))),
894             vcvtq_f32_s32(vmovl_s16(vget_high_s16(gx)))
895         }
896     };
897
898     const float32x4x2_t gy_f32 =
899     {
900         {
901             vcvtq_f32_s32(vmovl_s16(vget_low_s16(gy))),
902             vcvtq_f32_s32(vmovl_s16(vget_high_s16(gy)))
903         }
904     };
905
906     return phase_quantization(gx_f32, gy_f32);
907 }
908
909 /* Computes the gradient phase if gradient_size = 7. The output is quantized.
910  * 0 = 0°, 1 = 45°, 2 = 90°, 3 = 135°
911  *
912  * @param[in] gx Gx component
913  * @param[in] gy Gy component
914  *
915  * @return quantized phase for 8 pixels
916  */
917 inline uint8x8_t phase_quantization_S32_S32(const int32x4x2_t &gx, const int32x4x2_t &gy)
918 {
919     // Convert to float
920     const float32x4x2_t gx_f32 =
921     {
922         {
923             vcvtq_f32_s32(gx.val[0]),
924             vcvtq_f32_s32(gx.val[1])
925         }
926     };
927
928     const float32x4x2_t gy_f32 =
929     {
930         {
931             vcvtq_f32_s32(gy.val[0]),
932             vcvtq_f32_s32(gy.val[1])
933         }
934     };
935
936     return phase_quantization(gx_f32, gy_f32);
937 }
938
939 /* Computes the magnitude using the L1-norm type if gradient_size = 3 or 5
940  *
941  * @param[in] gx Gx component
942  * @param[in] gy Gy component
943  *
944  * @return magnitude for 8 pixels
945  */
946 inline uint16x8_t mag_l1_S16_S16(int16x8_t gx, int16x8_t gy)
947 {
948     return vaddq_u16(vreinterpretq_u16_s16(vabsq_s16(gx)),
949                      vreinterpretq_u16_s16(vabsq_s16(gy)));
950 }
951
952 /* Computes the magnitude using the L1-norm type if gradient_size = 7
953  *
954  * @param[in] gx Gx component
955  * @param[in] gy Gy component
956  *
957  * @return magnitude for 8 pixels
958  */
959 inline uint32x4x2_t mag_l1_S32_S32(const int32x4x2_t &gx, const int32x4x2_t &gy)
960 {
961     const uint32x4x2_t gx_abs =
962     {
963         {
964             vreinterpretq_u32_s32(vabsq_s32(gx.val[0])),
965             vreinterpretq_u32_s32(vabsq_s32(gx.val[1]))
966         }
967     };
968
969     const uint32x4x2_t gy_abs =
970     {
971         {
972             vreinterpretq_u32_s32(vabsq_s32(gy.val[0])),
973             vreinterpretq_u32_s32(vabsq_s32(gy.val[1]))
974         }
975     };
976
977     const uint32x4x2_t output =
978     {
979         {
980             vaddq_u32(gx_abs.val[0], gy_abs.val[0]),
981             vaddq_u32(gx_abs.val[1], gy_abs.val[1])
982         }
983     };
984
985     return output;
986 }
987
988 inline float32x4x2_t mag_l2(const float32x4x2_t &gx, const float32x4x2_t &gy)
989 {
990     // x^2 ...
991     float32x4x2_t magnitude =
992     {
993         {
994             vmulq_f32(gx.val[0], gx.val[0]),
995             vmulq_f32(gx.val[1], gx.val[1])
996         }
997     };
998
999     // ... + y^2
1000     magnitude.val[0] = vmlaq_f32(magnitude.val[0], gy.val[0], gy.val[0]);
1001     magnitude.val[1] = vmlaq_f32(magnitude.val[1], gy.val[1], gy.val[1]);
1002
1003     // sqrt(...)
1004     magnitude.val[0] = vmulq_f32(vrsqrteq_f32(magnitude.val[0]), magnitude.val[0]);
1005     magnitude.val[1] = vmulq_f32(vrsqrteq_f32(magnitude.val[1]), magnitude.val[1]);
1006
1007     return magnitude;
1008 }
1009
1010 /* Computes the magnitude using L2-norm if gradient_size = 3 or 5
1011  *
1012  * @param[in] gx Gx component
1013  * @param[in] gy Gy component
1014  *
1015  * @return magnitude for 8 pixels
1016  */
1017 inline uint16x8_t mag_l2_S16_S16(int16x8_t gx, int16x8_t gy)
1018 {
1019     // Compute magnitude using L2 normalization
1020     const float32x4x2_t gx2 =
1021     {
1022         {
1023             vcvtq_f32_s32(vmovl_s16(vget_low_s16(gx))),
1024             vcvtq_f32_s32(vmovl_s16(vget_high_s16(gx)))
1025         }
1026     };
1027
1028     const float32x4x2_t gy2 =
1029     {
1030         {
1031             vcvtq_f32_s32(vmovl_s16(vget_low_s16(gy))),
1032             vcvtq_f32_s32(vmovl_s16(vget_high_s16(gy)))
1033         }
1034     };
1035
1036     const float32x4x2_t magnitude = mag_l2(gx2, gy2);
1037
1038     // Store magnitude - Convert to uint16x8
1039     return vcombine_u16(vmovn_u32(vcvtq_u32_f32(magnitude.val[0])),
1040                         vmovn_u32(vcvtq_u32_f32(magnitude.val[1])));
1041 }
1042
1043 /* Computes the magnitude using L2-norm if gradient_size = 7
1044  *
1045  * @param[in] gx Gx component
1046  * @param[in] gy Gy component
1047  *
1048  * @return magnitude for 8 pixels
1049  */
1050 inline uint32x4x2_t mag_l2_S32_S32(const int32x4x2_t &gx, const int32x4x2_t &gy)
1051 {
1052     // Compute magnitude using L2 normalization
1053     float32x4x2_t gx2 =
1054     {
1055         {
1056             vcvtq_f32_s32(gx.val[0]),
1057             vcvtq_f32_s32(gx.val[1])
1058         }
1059     };
1060
1061     float32x4x2_t gy2 =
1062     {
1063         {
1064             vcvtq_f32_s32(gy.val[0]),
1065             vcvtq_f32_s32(gy.val[1])
1066         }
1067     };
1068
1069     const float32x4x2_t magnitude = mag_l2(gx2, gy2);
1070     const uint32x4x2_t  mag32 =
1071     {
1072         {
1073             vcvtq_u32_f32(magnitude.val[0]),
1074             vcvtq_u32_f32(magnitude.val[1])
1075         }
1076     };
1077
1078     return mag32;
1079 }
1080
1081 /* Gradient function used when the gradient size = 3 or 5 and when the norm_type = L1-norm
1082  *
1083  * @param[in]  gx_ptr        Pointer to source image. Gx image. Data type supported S16
1084  * @param[in]  gy_ptr        Pointer to source image. Gy image. Data type supported S16
1085  * @param[out] magnitude_ptr Pointer to destination image. Magnitude. Data type supported U16
1086  * @param[out] phase_ptr     Pointer to destination image. Quantized phase. Data type supported U8
1087  */
1088 void mag_phase_l1norm_S16_S16_U16_U8(const void *__restrict gx_ptr, const void *__restrict gy_ptr, void *__restrict magnitude_ptr, void *__restrict phase_ptr)
1089 {
1090     const auto gx        = static_cast<const int16_t *__restrict>(gx_ptr);
1091     const auto gy        = static_cast<const int16_t *__restrict>(gy_ptr);
1092     const auto magnitude = static_cast<uint16_t *__restrict>(magnitude_ptr);
1093     const auto phase     = static_cast<uint8_t *__restrict>(phase_ptr);
1094
1095     const int16x8x4_t gx_val =
1096     {
1097         {
1098             vld1q_s16(gx),
1099             vld1q_s16(gx + 8),
1100             vld1q_s16(gx + 16),
1101             vld1q_s16(gx + 24)
1102         }
1103     };
1104
1105     const int16x8x4_t gy_val =
1106     {
1107         {
1108             vld1q_s16(gy),
1109             vld1q_s16(gy + 8),
1110             vld1q_s16(gy + 16),
1111             vld1q_s16(gy + 24)
1112         }
1113     };
1114
1115     // Compute and store phase
1116     vst1_u8(phase + 0, phase_quantization_S16_S16(gx_val.val[0], gy_val.val[0]));
1117     vst1_u8(phase + 8, phase_quantization_S16_S16(gx_val.val[1], gy_val.val[1]));
1118     vst1_u8(phase + 16, phase_quantization_S16_S16(gx_val.val[2], gy_val.val[2]));
1119     vst1_u8(phase + 24, phase_quantization_S16_S16(gx_val.val[3], gy_val.val[3]));
1120
1121     // Compute ans store magnitude using L1 normalization
1122     vst1q_u16(magnitude + 0, mag_l1_S16_S16(gx_val.val[0], gy_val.val[0]));
1123     vst1q_u16(magnitude + 8, mag_l1_S16_S16(gx_val.val[1], gy_val.val[1]));
1124     vst1q_u16(magnitude + 16, mag_l1_S16_S16(gx_val.val[2], gy_val.val[2]));
1125     vst1q_u16(magnitude + 24, mag_l1_S16_S16(gx_val.val[3], gy_val.val[3]));
1126 }
1127
1128 /* Gradient function used when the gradient size = 3 or 5 and when the norm_type = L2-norm
1129  *
1130  * @param[in]  gx_ptr        Pointer to source image. Gx image. Data type supported S16
1131  * @param[in]  gy_ptr        Pointer to source image. Gy image. Data type supported S16
1132  * @param[out] magnitude_ptr Pointer to destination image. Magnitude. Data type supported U16
1133  * @param[out] phase_ptr     Pointer to destination image. Quantized phase. Data type supported U8
1134  */
1135 void mag_phase_l2norm_S16_S16_U16_U8(const void *__restrict gx_ptr, const void *__restrict gy_ptr, void *__restrict magnitude_ptr, void *__restrict phase_ptr)
1136 {
1137     const auto gx        = static_cast<const int16_t *__restrict>(gx_ptr);
1138     const auto gy        = static_cast<const int16_t *__restrict>(gy_ptr);
1139     const auto magnitude = static_cast<uint16_t *__restrict>(magnitude_ptr);
1140     const auto phase     = static_cast<uint8_t *__restrict>(phase_ptr);
1141
1142     const int16x8x4_t gx_val =
1143     {
1144         {
1145             vld1q_s16(gx),
1146             vld1q_s16(gx + 8),
1147             vld1q_s16(gx + 16),
1148             vld1q_s16(gx + 24)
1149         }
1150     };
1151
1152     const int16x8x4_t gy_val =
1153     {
1154         {
1155             vld1q_s16(gy),
1156             vld1q_s16(gy + 8),
1157             vld1q_s16(gy + 16),
1158             vld1q_s16(gy + 24)
1159         }
1160     };
1161
1162     // Compute and store phase
1163     vst1_u8(phase + 0, phase_quantization_S16_S16(gx_val.val[0], gy_val.val[0]));
1164     vst1_u8(phase + 8, phase_quantization_S16_S16(gx_val.val[1], gy_val.val[1]));
1165     vst1_u8(phase + 16, phase_quantization_S16_S16(gx_val.val[2], gy_val.val[2]));
1166     vst1_u8(phase + 24, phase_quantization_S16_S16(gx_val.val[3], gy_val.val[3]));
1167
1168     // Compute and store magnitude using L2 normalization
1169     vst1q_u16(magnitude + 0, mag_l2_S16_S16(gx_val.val[0], gy_val.val[0]));
1170     vst1q_u16(magnitude + 8, mag_l2_S16_S16(gx_val.val[1], gy_val.val[1]));
1171     vst1q_u16(magnitude + 16, mag_l2_S16_S16(gx_val.val[2], gy_val.val[2]));
1172     vst1q_u16(magnitude + 24, mag_l2_S16_S16(gx_val.val[3], gy_val.val[3]));
1173 }
1174
1175 /* Gradient function used when the gradient size = 7 and when the norm_type = L1-norm
1176  *
1177  * @param[in]  gx_ptr        Pointer to source image. Gx image. Data type supported S32
1178  * @param[in]  gy_ptr        Pointer to source image. Gy image. Data type supported S32
1179  * @param[out] magnitude_ptr Pointer to destination image. Magnitude. Data type supported U32
1180  * @param[out] phase_ptr     Pointer to destination image. Quantized phase. Data type support U8
1181  */
1182 void mag_phase_l1norm_S32_S32_U32_U8(const void *__restrict gx_ptr, const void *__restrict gy_ptr, void *__restrict magnitude_ptr, void *__restrict phase_ptr)
1183 {
1184     auto gx        = static_cast<const int32_t *__restrict>(gx_ptr);
1185     auto gy        = static_cast<const int32_t *__restrict>(gy_ptr);
1186     auto magnitude = static_cast<uint32_t *__restrict>(magnitude_ptr);
1187     auto phase     = static_cast<uint8_t *__restrict>(phase_ptr);
1188
1189     // Process low and high part
1190     for(size_t i = 0; i < 2; ++i, gx += 16, gy += 16, magnitude += 16, phase += 16)
1191     {
1192         const int32x4x2_t gx0 =
1193         {
1194             {
1195                 vld1q_s32(gx + 0),
1196                 vld1q_s32(gx + 4)
1197             }
1198         };
1199
1200         const int32x4x2_t gx1 =
1201         {
1202             {
1203                 vld1q_s32(gx + 8),
1204                 vld1q_s32(gx + 12)
1205             }
1206         };
1207
1208         const int32x4x2_t gy0 =
1209         {
1210             {
1211                 vld1q_s32(gy + 0),
1212                 vld1q_s32(gy + 4)
1213             }
1214         };
1215
1216         const int32x4x2_t gy1 =
1217         {
1218             {
1219                 vld1q_s32(gy + 8),
1220                 vld1q_s32(gy + 12)
1221             }
1222         };
1223
1224         // Compute and store phase
1225         vst1_u8(phase + 0, phase_quantization_S32_S32(gx0, gy0));
1226         vst1_u8(phase + 8, phase_quantization_S32_S32(gx1, gy1));
1227
1228         // Compute magnitude using L1 normalization
1229         const uint32x4x2_t mag0 = mag_l1_S32_S32(gx0, gy0);
1230         const uint32x4x2_t mag1 = mag_l1_S32_S32(gx1, gy1);
1231
1232         // Store magnitude
1233         vst1q_u32(magnitude + 0, mag0.val[0]);
1234         vst1q_u32(magnitude + 4, mag0.val[1]);
1235         vst1q_u32(magnitude + 8, mag1.val[0]);
1236         vst1q_u32(magnitude + 12, mag1.val[1]);
1237     }
1238 }
1239
1240 /* Gradient function used when the gradient size = 7 and when the norm_type = L2-norm
1241  *
1242  * @param[in]  gx_ptr        Pointer to source image. Gx image. Data type supported S32
1243  * @param[in]  gy_ptr        Pointer to source image. Gy image. Data type supported S32
1244  * @param[out] magnitude_ptr Pointer to destination image. Magnitude. Data type supported U32
1245  * @param[out] phase_ptr     Pointer to destination image. Quantized phase. Data type supported U8
1246  */
1247 void mag_phase_l2norm_S32_S32_U32_U8(const void *__restrict gx_ptr, const void *__restrict gy_ptr, void *__restrict magnitude_ptr, void *__restrict phase_ptr)
1248 {
1249     auto gx        = static_cast<const int32_t *__restrict>(gx_ptr);
1250     auto gy        = static_cast<const int32_t *__restrict>(gy_ptr);
1251     auto magnitude = static_cast<uint32_t *__restrict>(magnitude_ptr);
1252     auto phase     = static_cast<uint8_t *__restrict>(phase_ptr);
1253
1254     // Process low and high part
1255     for(size_t i = 0; i < 2; ++i, gx += 16, gy += 16, magnitude += 16, phase += 16)
1256     {
1257         const int32x4x2_t gx0 =
1258         {
1259             {
1260                 vld1q_s32(gx + 0),
1261                 vld1q_s32(gx + 4)
1262             }
1263         };
1264
1265         const int32x4x2_t gx1 =
1266         {
1267             {
1268                 vld1q_s32(gx + 8),
1269                 vld1q_s32(gx + 12)
1270             }
1271         };
1272
1273         const int32x4x2_t gy0 =
1274         {
1275             {
1276                 vld1q_s32(gy + 0),
1277                 vld1q_s32(gy + 4)
1278             }
1279         };
1280
1281         const int32x4x2_t gy1 =
1282         {
1283             {
1284                 vld1q_s32(gy + 8),
1285                 vld1q_s32(gy + 12)
1286             }
1287         };
1288
1289         // Compute and store phase
1290         vst1_u8(phase + 0, phase_quantization_S32_S32(gx0, gy0));
1291         vst1_u8(phase + 8, phase_quantization_S32_S32(gx1, gy1));
1292
1293         // Compute magnitude using L2 normalization
1294         const uint32x4x2_t mag0 = mag_l2_S32_S32(gx0, gy0);
1295         const uint32x4x2_t mag1 = mag_l2_S32_S32(gx1, gy1);
1296
1297         // Store magnitude
1298         vst1q_u32(magnitude + 0, mag0.val[0]);
1299         vst1q_u32(magnitude + 4, mag0.val[1]);
1300         vst1q_u32(magnitude + 8, mag1.val[0]);
1301         vst1q_u32(magnitude + 12, mag1.val[1]);
1302     }
1303 }
1304
1305 /* Computes non-maxima suppression and hysteresis when the gradient size = 3 or 5
1306  *
1307  * @param[in]  magnitude_ptr Pointer to source image. Magnitude. Data type supported U16
1308  * @param[in]  phase_ptr     Pointer to source image. Quantized phase. Data type supported U8
1309  * @param[out] output_ptr    Pointer to output image. Data type supported U8
1310  * @param[in]  stride_mag    Stride of magnitude image
1311  * @param[in]  lower_thr     Lower threshold used for the hysteresis
1312  * @param[in]  upper_thr     Upper threshold used for the hysteresis
1313  */
1314 void non_max_suppression_U16_U8_U8(const void *__restrict magnitude_ptr, const void *__restrict phase_ptr, void *__restrict output_ptr, const uint32_t stride_mag, const int32_t lower_thr,
1315                                    const int32_t upper_thr)
1316 {
1317     const auto magnitude = static_cast<const uint16_t *__restrict>(magnitude_ptr);
1318     const auto phase     = static_cast<const uint8_t *__restrict>(phase_ptr);
1319     const auto output    = static_cast<uint8_t *__restrict>(output_ptr);
1320
1321     // Get magnitude and phase of the centre pixels
1322     uint16x8_t mc = vld1q_u16(magnitude);
1323
1324     // Angle_quantized: 0 = 0°, 1 = 45°, 2 = 90°, 3 = 135°
1325     const uint16x8_t pc16 = vmovl_u8(vld1_u8(phase));
1326
1327     // 0 degree
1328     const uint16x8_t mk0_0 = vld1q_u16(magnitude - 1);
1329     const uint16x8_t mk0_1 = vld1q_u16(magnitude + 1);
1330     uint16x8_t       mask0 = vceqq_u16(pc16, vdupq_n_u16(0));
1331     mask0                  = vandq_u16(mask0, vcgeq_u16(mc, mk0_0));
1332     mask0                  = vandq_u16(mask0, vcgeq_u16(mc, mk0_1));
1333
1334     // 45 degree
1335     const uint16x8_t mk45_0 = vld1q_u16(magnitude - stride_mag - 1);
1336     const uint16x8_t mk45_1 = vld1q_u16(magnitude + stride_mag + 1);
1337     uint16x8_t       mask1  = vceqq_u16(pc16, vdupq_n_u16(1));
1338     mask1                   = vandq_u16(mask1, vcgeq_u16(mc, mk45_0));
1339     mask1                   = vandq_u16(mask1, vcgeq_u16(mc, mk45_1));
1340
1341     // 90 degree
1342     const uint16x8_t mk90_0 = vld1q_u16(magnitude - stride_mag);
1343     const uint16x8_t mk90_1 = vld1q_u16(magnitude + stride_mag);
1344     uint16x8_t       mask2  = vceqq_u16(pc16, vdupq_n_u16(2));
1345     mask2                   = vandq_u16(mask2, vcgeq_u16(mc, mk90_0));
1346     mask2                   = vandq_u16(mask2, vcgeq_u16(mc, mk90_1));
1347
1348     // 135 degree
1349     const uint16x8_t mk135_0 = vld1q_u16(magnitude - stride_mag + 1);
1350     const uint16x8_t mk135_1 = vld1q_u16(magnitude + stride_mag - 1);
1351     uint16x8_t       mask3   = vceqq_u16(pc16, vdupq_n_u16(3));
1352     mask3                    = vandq_u16(mask3, vcgeq_u16(mc, mk135_0));
1353     mask3                    = vandq_u16(mask3, vcgeq_u16(mc, mk135_1));
1354
1355     // Merge masks
1356     mask0 = vorrq_u16(mask0, mask1);
1357     mask2 = vorrq_u16(mask2, mask3);
1358     mask0 = vorrq_u16(mask0, mask2);
1359
1360     mc = vbslq_u16(mask0, mc, vdupq_n_u16(0));
1361
1362     // mc > upper_thr
1363     mask0 = vcgtq_u16(mc, vdupq_n_u16(upper_thr));
1364
1365     // mc <= lower_thr
1366     mask1 = vcleq_u16(mc, vdupq_n_u16(lower_thr));
1367
1368     // mc <= upper_thr && mc > lower_thr
1369     mask2 = vcleq_u16(mc, vdupq_n_u16(upper_thr));
1370     mask2 = vandq_u16(mask2, vcgtq_u16(mc, vdupq_n_u16(lower_thr)));
1371
1372     mc = vbslq_u16(mask0, vdupq_n_u16(EDGE), mc);
1373     mc = vbslq_u16(mask1, vdupq_n_u16(NO_EDGE), mc);
1374     mc = vbslq_u16(mask2, vdupq_n_u16(MAYBE), mc);
1375
1376     vst1_u8(output, vmovn_u16(mc));
1377 }
1378
1379 inline uint16x4_t non_max_U32_helper(const uint32_t *input, const uint16x4_t pc, const uint32_t stride_mag, const int32_t lower_thr, const int32_t upper_thr)
1380 {
1381     // Phase for 4 pixel
1382     const uint32x4_t pc32 = vmovl_u16(pc);
1383
1384     // Get magnitude for 4 pixel
1385     uint32x4_t mc = vld1q_u32(input);
1386
1387     // Angle_quantized: 0 = 0°, 1 = 45°, 2 = 90°, 3 = 135°
1388     // 0 degree
1389     const uint32x4_t mk0_0 = vld1q_u32(input - 1);
1390     const uint32x4_t mk0_1 = vld1q_u32(input + 1);
1391     uint32x4_t       mask0 = vceqq_u32(pc32, vdupq_n_u32(0));
1392     mask0                  = vandq_u32(mask0, vcgeq_u32(mc, mk0_0));
1393     mask0                  = vandq_u32(mask0, vcgeq_u32(mc, mk0_1));
1394
1395     // 45 degree
1396     const uint32x4_t mk45_0 = vld1q_u32(input - stride_mag - 1);
1397     const uint32x4_t mk45_1 = vld1q_u32(input + stride_mag + 1);
1398     uint32x4_t       mask1  = vceqq_u32(pc32, vdupq_n_u32(1));
1399     mask1                   = vandq_u32(mask1, vcgeq_u32(mc, mk45_0));
1400     mask1                   = vandq_u32(mask1, vcgeq_u32(mc, mk45_1));
1401
1402     // 90 degree
1403     const uint32x4_t mk90_0 = vld1q_u32(input - stride_mag);
1404     const uint32x4_t mk90_1 = vld1q_u32(input + stride_mag);
1405     uint32x4_t       mask2  = vceqq_u32(pc32, vdupq_n_u32(2));
1406     mask2                   = vandq_u32(mask2, vcgeq_u32(mc, mk90_0));
1407     mask2                   = vandq_u32(mask2, vcgeq_u32(mc, mk90_1));
1408
1409     // 135 degree
1410     const uint32x4_t mk135_0 = vld1q_u32(input - stride_mag + 1);
1411     const uint32x4_t mk135_1 = vld1q_u32(input + stride_mag - 1);
1412     uint32x4_t       mask3   = vceqq_u32(pc32, vdupq_n_u32(3));
1413     mask3                    = vandq_u32(mask3, vcgeq_u32(mc, mk135_0));
1414     mask3                    = vandq_u32(mask3, vcgeq_u32(mc, mk135_1));
1415
1416     // Merge masks
1417     mask0 = vorrq_u32(mask0, mask1);
1418     mask2 = vorrq_u32(mask2, mask3);
1419     mask0 = vorrq_u32(mask0, mask2);
1420
1421     mc = vbslq_u32(mask0, mc, vdupq_n_u32(0));
1422
1423     // mc > upper_thr
1424     mask0 = vcgtq_u32(mc, vdupq_n_u32(upper_thr));
1425
1426     // mc <= lower_thr
1427     mask1 = vcleq_u32(mc, vdupq_n_u32(lower_thr));
1428
1429     // mc <= upper_thr && mc > lower_thr
1430     mask2 = vcleq_u32(mc, vdupq_n_u32(upper_thr));
1431     mask2 = vandq_u32(mask2, vcgtq_u32(mc, vdupq_n_u32(lower_thr)));
1432
1433     mc = vbslq_u32(mask0, vdupq_n_u32(EDGE), mc);
1434     mc = vbslq_u32(mask1, vdupq_n_u32(NO_EDGE), mc);
1435     mc = vbslq_u32(mask2, vdupq_n_u32(MAYBE), mc);
1436
1437     return vmovn_u32(mc);
1438 }
1439
1440 /* Computes non-maxima suppression and hysteresis when the gradient_size = 7
1441  *
1442  * @param[in]  magnitude_ptr Pointer to source image. Magnitude. Data type supported U32
1443  * @param[in]  phase_ptr     Pointer to source image. Quantized phase. Data type supported U8
1444  * @param[out] output_ptr    Pointer to destination image. Data type supported U8
1445  * @param[in]  stride_mag    Stride of magnitude image
1446  * @param[in]  lower_thr     Lower threshold used for the hysteresis
1447  * @param[in]  upper_thr     Upper threshold used for the hysteresis
1448  */
1449 void non_max_suppression_U32_U8_U8(const void *__restrict magnitude_ptr, const void *__restrict phase_ptr, void *__restrict output_ptr, const uint32_t stride_mag, const int32_t lower_thr,
1450                                    const int32_t upper_thr)
1451 {
1452     const auto magnitude = static_cast<const uint32_t *__restrict>(magnitude_ptr);
1453     const auto phase     = static_cast<const uint8_t *__restrict>(phase_ptr);
1454     const auto output    = static_cast<uint8_t *__restrict>(output_ptr);
1455
1456     // Get phase for 8 pixel
1457     const uint16x8_t pc16 = vmovl_u8(vld1_u8(phase));
1458
1459     // Compute non maxima suppression
1460     const uint16x4x2_t res =
1461     {
1462         {
1463             non_max_U32_helper(magnitude, vget_low_u16(pc16), stride_mag, lower_thr, upper_thr),
1464             non_max_U32_helper(magnitude + 4, vget_high_u16(pc16), stride_mag, lower_thr, upper_thr)
1465         }
1466     };
1467
1468     // Store result
1469     vst1_u8(output, vmovn_u16(vcombine_u16(res.val[0], res.val[1])));
1470 }
1471
1472 /* Computes edge tracing when is called by edge_trace_U8_U8 recursively
1473  *
1474  * @param[in]  input         Pointer to source image. Data type supported U8
1475  * @param[out] output        Pointer to destination image. Data type supported U8
1476  * @param[in]  input_stride  Stride of the input image
1477  * @param[in]  output_stride Stride of the output image
1478  */
1479 void edge_trace_recursive_U8_U8(uint8_t *__restrict input, uint8_t *__restrict output, const int32_t input_stride, const int32_t output_stride)
1480 {
1481     // Look for MAYBE pixels in 8 directions
1482     *output = EDGE;
1483
1484     // (-1, 0)
1485     uint8_t pixel = *(input - 1);
1486
1487     if(pixel == MAYBE)
1488     {
1489         // Touched a MAYBE point. MAYBE becomes EDGE
1490         *(input - 1) = EDGE;
1491
1492         edge_trace_recursive_U8_U8(input - 1, output - 1, input_stride, output_stride);
1493     }
1494
1495     // (+1, 0)
1496     pixel = *(input + 1);
1497
1498     if(pixel == MAYBE)
1499     {
1500         // Touched a MAYBE point. MAYBE becomes EDGE
1501         *(input + 1) = EDGE;
1502
1503         edge_trace_recursive_U8_U8(input + 1, output + 1, input_stride, output_stride);
1504     }
1505
1506     input -= input_stride;
1507     output -= output_stride;
1508
1509     // (-1, -1)
1510     pixel = *(input - 1);
1511
1512     if(pixel == MAYBE)
1513     {
1514         // Touched a MAYBE point. MAYBE becomes EDGE
1515         *(input - 1) = EDGE;
1516
1517         edge_trace_recursive_U8_U8(input - 1, output - 1, input_stride, output_stride);
1518     }
1519
1520     // (0, -1)
1521     pixel = *input;
1522
1523     if(pixel == MAYBE)
1524     {
1525         // Touched a MAYBE point. MAYBE becomes EDGE
1526         *input = EDGE;
1527
1528         edge_trace_recursive_U8_U8(input, output, input_stride, output_stride);
1529     }
1530
1531     // (+1, -1)
1532     pixel = *(input + 1);
1533
1534     if(pixel == MAYBE)
1535     {
1536         // Touched a MAYBE point. MAYBE becomes EDGE
1537         *(input + 1) = EDGE;
1538
1539         edge_trace_recursive_U8_U8(input + 1, output + 1, input_stride, output_stride);
1540     }
1541
1542     input += input_stride * 2;
1543     output += output_stride * 2;
1544
1545     // (-1, +1)
1546     pixel = *(input - 1);
1547
1548     if(pixel == MAYBE)
1549     {
1550         // Touched a MAYBE point. MAYBE becomes EDGE
1551         *(input - 1) = EDGE;
1552
1553         edge_trace_recursive_U8_U8(input - 1, output - 1, input_stride, output_stride);
1554     }
1555
1556     // (0, +1)
1557     pixel = *input;
1558
1559     if(pixel == MAYBE)
1560     {
1561         // Touched a MAYBE point. MAYBE becomes EDGE
1562         *input = EDGE;
1563
1564         edge_trace_recursive_U8_U8(input, output, input_stride, output_stride);
1565     }
1566
1567     // (+1, +1)
1568     pixel = *(input + 1);
1569
1570     if(pixel == MAYBE)
1571     {
1572         // Touched a MAYBE point. MAYBE becomes EDGE
1573         *(input + 1) = EDGE;
1574
1575         edge_trace_recursive_U8_U8(input + 1, output + 1, input_stride, output_stride);
1576     }
1577 }
1578
1579 /* Computes edge tracing
1580  *
1581  * @param[in]  input         Pointer to source image. Data type supported U8
1582  * @param[out] output        Pointer to destination image. Data type supported U8
1583  * @param[in]  input_stride  Stride of the input image
1584  * @param[in]  output_stride Stride of the output image
1585  */
1586 void edge_trace_U8_U8(uint8_t *__restrict input, uint8_t *__restrict output, const int32_t input_stride, const int32_t output_stride)
1587 {
1588     if(*input == NO_EDGE)
1589     {
1590         *output = NO_EDGE;
1591     }
1592     // Check if EDGE and not yet touched
1593     else if((*input == EDGE) && (*output == NO_EDGE))
1594     {
1595         edge_trace_recursive_U8_U8(input, output, input_stride, output_stride);
1596     }
1597 }
1598 } // namespace
1599
1600 NEGradientKernel::NEGradientKernel()
1601     : _func(nullptr), _gx(nullptr), _gy(nullptr), _magnitude(nullptr), _phase(nullptr)
1602 {
1603 }
1604
1605 void NEGradientKernel::configure(const ITensor *gx, const ITensor *gy, ITensor *magnitude, ITensor *phase, int32_t norm_type)
1606 {
1607     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(gx, 1, DataType::S16, DataType::S32);
1608     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(gy, 1, DataType::S16, DataType::S32);
1609     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(magnitude, 1, DataType::U16, DataType::U32);
1610     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(phase, 1, DataType::U8);
1611     ARM_COMPUTE_ERROR_ON_MSG(element_size_from_data_type(gx->info()->data_type()) != element_size_from_data_type(gy->info()->data_type()), "Gx and Gy must have the same element size");
1612     ARM_COMPUTE_ERROR_ON_MSG(element_size_from_data_type(gx->info()->data_type()) != element_size_from_data_type(magnitude->info()->data_type()), "Magnitude must have the same element size as Gx and Gy");
1613
1614     _gx        = gx;
1615     _gy        = gy;
1616     _magnitude = magnitude;
1617     _phase     = phase;
1618
1619     if(_gx->info()->data_type() == DataType::S16)
1620     {
1621         if(norm_type == 1)
1622         {
1623             _func = &mag_phase_l1norm_S16_S16_U16_U8;
1624         }
1625         else
1626         {
1627             _func = &mag_phase_l2norm_S16_S16_U16_U8;
1628         }
1629     }
1630     else
1631     {
1632         if(norm_type == 1)
1633         {
1634             _func = &mag_phase_l1norm_S32_S32_U32_U8;
1635         }
1636         else
1637         {
1638             _func = &mag_phase_l2norm_S32_S32_U32_U8;
1639         }
1640     }
1641
1642     constexpr unsigned int num_elems_processed_per_iteration = 32;
1643
1644     // Configure kernel window
1645     Window win = calculate_max_window(*_gx->info(), Steps(num_elems_processed_per_iteration));
1646
1647     AccessWindowHorizontal gx_access(_gx->info(), 0, num_elems_processed_per_iteration);
1648     AccessWindowHorizontal gy_access(_gy->info(), 0, num_elems_processed_per_iteration);
1649     AccessWindowHorizontal mag_access(_magnitude->info(), 0, num_elems_processed_per_iteration);
1650     AccessWindowHorizontal phase_access(_phase->info(), 0, num_elems_processed_per_iteration);
1651
1652     update_window_and_padding(win, gx_access, gy_access, mag_access, phase_access);
1653
1654     mag_access.set_valid_region(win, _gx->info()->valid_region());
1655     phase_access.set_valid_region(win, _gx->info()->valid_region());
1656
1657     INEKernel::configure(win);
1658 }
1659
1660 void NEGradientKernel::run(const Window &window)
1661 {
1662     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1663     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
1664     ARM_COMPUTE_ERROR_ON(_func == nullptr);
1665     Iterator gx(_gx, window);
1666     Iterator gy(_gy, window);
1667     Iterator magnitude(_magnitude, window);
1668     Iterator phase(_phase, window);
1669
1670     execute_window_loop(window, [&](const Coordinates & id)
1671     {
1672         (*_func)(gx.ptr(), gy.ptr(), magnitude.ptr(), phase.ptr());
1673     },
1674     gx, gy, magnitude, phase);
1675 }
1676
1677 NEEdgeNonMaxSuppressionKernel::NEEdgeNonMaxSuppressionKernel()
1678     : _func(nullptr), _magnitude(nullptr), _phase(nullptr), _output(nullptr), _lower_thr(0), _upper_thr(0)
1679 {
1680 }
1681
1682 BorderSize NEEdgeNonMaxSuppressionKernel::border_size() const
1683 {
1684     return BorderSize(1);
1685 }
1686
1687 void NEEdgeNonMaxSuppressionKernel::configure(const ITensor *magnitude, const ITensor *phase, ITensor *output,
1688                                               int32_t upper_thr, int32_t lower_thr, bool border_undefined)
1689 {
1690     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(magnitude, 1, DataType::U16, DataType::U32);
1691     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(phase, 1, DataType::U8);
1692     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8);
1693
1694     _magnitude = magnitude;
1695     _phase     = phase;
1696     _output    = output;
1697
1698     switch(_magnitude->info()->data_type())
1699     {
1700         case DataType::U16:
1701             _func = &non_max_suppression_U16_U8_U8;
1702             break;
1703         case DataType::U32:
1704             _func = &non_max_suppression_U32_U8_U8;
1705             break;
1706         default:
1707             ARM_COMPUTE_ERROR("Unsupported data type!");
1708     }
1709
1710     // Set thresholds
1711     _lower_thr = lower_thr;
1712     _upper_thr = upper_thr;
1713
1714     constexpr unsigned int num_elems_processed_per_iteration = 8;
1715     constexpr unsigned int num_elems_read_per_iteration      = 10;
1716     constexpr unsigned int num_rows_read_per_iteration       = 3;
1717
1718     // Configure kernel window
1719     Window win = calculate_max_window(*_magnitude->info(), Steps(num_elems_processed_per_iteration), border_undefined, border_size());
1720
1721     AccessWindowRectangle  mag_access(_magnitude->info(), -border_size().left, -border_size().top, num_elems_read_per_iteration, num_rows_read_per_iteration);
1722     AccessWindowHorizontal phase_access(_phase->info(), 0, num_elems_processed_per_iteration);
1723     AccessWindowHorizontal output_access(_output->info(), 0, num_elems_processed_per_iteration);
1724
1725     update_window_and_padding(win, mag_access, phase_access, output_access);
1726
1727     output_access.set_valid_region(win, _magnitude->info()->valid_region(), border_undefined, border_size());
1728
1729     INEKernel::configure(win);
1730 }
1731
1732 void NEEdgeNonMaxSuppressionKernel::run(const Window &window)
1733 {
1734     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1735     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
1736     ARM_COMPUTE_ERROR_ON(_func == nullptr);
1737     Iterator magnitude(_magnitude, window);
1738     Iterator phase(_phase, window);
1739     Iterator output(_output, window);
1740
1741     const size_t input1_stride        = _magnitude->info()->strides_in_bytes()[1];
1742     const size_t input1_stride_ushort = input1_stride / data_size_from_type(_magnitude->info()->data_type());
1743
1744     execute_window_loop(window, [&](const Coordinates & id)
1745     {
1746         (*_func)(magnitude.ptr(), phase.ptr(), output.ptr(), input1_stride_ushort, _lower_thr, _upper_thr);
1747     },
1748     magnitude, phase, output);
1749 }
1750
1751 NEEdgeTraceKernel::NEEdgeTraceKernel()
1752     : _input(nullptr), _output(nullptr)
1753 {
1754 }
1755
1756 BorderSize NEEdgeTraceKernel::border_size() const
1757 {
1758     return BorderSize(1);
1759 }
1760
1761 bool NEEdgeTraceKernel::is_parallelisable() const
1762 {
1763     return false;
1764 }
1765
1766 void NEEdgeTraceKernel::configure(ITensor *input, ITensor *output)
1767 {
1768     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8);
1769     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8);
1770
1771     _input  = input;
1772     _output = output;
1773
1774     constexpr unsigned int num_elems_processed_per_iteration = 1;
1775
1776     // Configure kernel window
1777     Window win = calculate_max_window(*_input->info(), Steps(num_elems_processed_per_iteration));
1778
1779     const ValidRegion &input_valid_region  = input->info()->valid_region();
1780     const ValidRegion &output_valid_region = output->info()->valid_region();
1781
1782     // Reads can occur within the valid region of the input + border
1783     AccessWindowStatic input_access(input->info(),
1784                                     input_valid_region.anchor[0] - border_size().left,
1785                                     input_valid_region.anchor[1] - border_size().top,
1786                                     input_valid_region.anchor[0] + input_valid_region.shape[0] + border_size().right,
1787                                     input_valid_region.anchor[1] + input_valid_region.shape[1] + border_size().bottom);
1788
1789     // Writes can occur within the valid region of the output + border
1790     AccessWindowStatic output_access(output->info(),
1791                                      output_valid_region.anchor[0] - border_size().left,
1792                                      output_valid_region.anchor[1] - border_size().top,
1793                                      output_valid_region.anchor[0] + output_valid_region.shape[0] + border_size().right,
1794                                      output_valid_region.anchor[1] + output_valid_region.shape[1] + border_size().bottom);
1795
1796     update_window_and_padding(win, input_access, output_access);
1797
1798     output_access.set_valid_region(win, _input->info()->valid_region());
1799
1800     INEKernel::configure(win);
1801 }
1802
1803 void NEEdgeTraceKernel::run(const Window &window)
1804 {
1805     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1806     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
1807     Iterator input(_input, window);
1808     Iterator output(_output, window);
1809
1810     const size_t input_stride  = _input->info()->strides_in_bytes()[1];
1811     const size_t output_stride = _output->info()->strides_in_bytes()[1];
1812
1813     execute_window_loop(window, [&](const Coordinates & id)
1814     {
1815         edge_trace_U8_U8(input.ptr(), output.ptr(), input_stride, output_stride);
1816     },
1817     input, output);
1818 }