7336f5dc4bc48705f0a083b5d3f5cebda9d8ed7d
[platform/upstream/armcl.git] / src / core / NEON / kernels / NEMagnitudePhaseKernel.cpp
1 /*
2  * Copyright (c) 2016, 2017 ARM Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "arm_compute/core/NEON/kernels/NEMagnitudePhaseKernel.h"
25
26 #include "arm_compute/core/Error.h"
27 #include "arm_compute/core/Helpers.h"
28 #include "arm_compute/core/IAccessWindow.h"
29 #include "arm_compute/core/ITensor.h"
30 #include "arm_compute/core/Validate.h"
31
32 #include <arm_neon.h>
33 #include <cstdint>
34
35 using namespace arm_compute;
36
37 namespace arm_compute
38 {
39 class Coordinates;
40 } // namespace arm_compute
41
42 namespace
43 {
44 // Defines for computing atan2
45 constexpr float SCALE_FACTOR = 0.7111111111111111f;
46 constexpr float PI           = 3.141592653589793f;
47 constexpr float SCALE_180    = 180.0f / PI;
48 constexpr float SCALE_360    = SCALE_180 * SCALE_FACTOR;
49 constexpr float PI_4         = 0.7853981633974483f;
50 constexpr float COEFF1       = 0.0663f;
51 constexpr float COEFF2       = 0.2447f;
52 } // namespace
53
54 #ifdef ARM_COMPUTE_ENABLE_FP16
55 namespace fp16
56 {
57 inline float16x8_t inv(float16x8_t x)
58 {
59     const float16x8_t estimate = vrecpeq_f16(x);
60     return vmulq_f16(estimate, vrecpsq_f16(x, estimate));
61 }
62
63 inline float16x8_t atan2_fast(float16x8_t gx, float16x8_t gy, float16x8_t scale)
64 {
65     static const float16x8_t one     = vdupq_n_f16(1.0f);
66     static const float16x8_t ninety  = vdupq_n_f16(90.f * SCALE_FACTOR);
67     static const float16x8_t epsilon = vdupq_n_f16(1e-9f);
68     static const float16x8_t piover4 = vdupq_n_f16(PI_4);
69     static const float16x8_t coeff1  = vdupq_n_f16(COEFF1);
70     static const float16x8_t coeff2  = vdupq_n_f16(COEFF2);
71
72     const float16x8_t abs_gx = vabsq_f16(gx);
73     const float16x8_t abs_gy = vabsq_f16(gy);
74     const float16x8_t tmin   = vminq_f16(abs_gx, abs_gy);
75     const float16x8_t tmax   = vmaxq_f16(abs_gx, abs_gy);
76
77     // z = min(x, y) / max(x, y)
78     const float16x8_t z    = vmulq_f16(tmin, inv(vaddq_f16(tmax, epsilon)));
79     const float16x8_t absz = vabsq_f16(z);
80
81     //                   = x * [pi/4 + (1 - |x|) * (0.2447 + 0.0663 * |x|)]
82     float16x8_t arctan = vmulq_f16(z, vfmaq_f16(piover4,
83                                                 vsubq_f16(one, absz),
84                                                 vfmaq_f16(coeff2, coeff1, absz)));
85
86     // Radians to degrees conversion with applied a scale factor in order to have the result [0, 255]
87     arctan = vmulq_f16(arctan, scale);
88
89     /* If z > 1, result = 90 - result */
90     return vbslq_f16(vcgeq_f16(abs_gx, abs_gy), arctan, vsubq_f16(ninety, arctan));
91 }
92
93 inline float16x8_t atan2_0_360(float16x8_t gx, float16x8_t gy)
94 {
95     static const float16x8_t scale      = vdupq_n_f16(SCALE_360);
96     static const float16x8_t threesixty = vdupq_n_f16(360.0f * SCALE_FACTOR);
97     static const float16x8_t zero       = vdupq_n_f16(0.0f);
98     static const float16x8_t oneeighty  = vdupq_n_f16(180.0f * SCALE_FACTOR);
99
100     float16x8_t arctan = atan2_fast(gx, gy, scale);
101
102     // Choose correct quadrant
103     arctan = vbslq_f16(vcltq_f16(gx, zero), vsubq_f16(oneeighty, arctan), arctan);
104     arctan = vbslq_f16(vcltq_f16(gy, zero), vsubq_f16(threesixty, arctan), arctan);
105
106     return arctan;
107 }
108
109 inline float16x8_t atan2_0_180(float16x8_t gx, float16x8_t gy)
110 {
111     static const float16x8_t scale      = vdupq_n_f16(SCALE_180);
112     static const float16x8_t threesixty = vdupq_n_f16(360.0f * SCALE_FACTOR);
113     static const float16x8_t oneeighty  = vdupq_n_f16(180.0f * SCALE_FACTOR);
114     static const float16x8_t zero       = vdupq_n_f16(0.0f);
115
116     float16x8_t arctan = atan2_fast(gx, gy, scale);
117
118     // Choose correct quadrant
119     arctan = vbslq_f16(vcltq_f16(gx, zero), vsubq_f16(oneeighty, arctan), arctan);
120     arctan = vbslq_f16(vcltq_f16(gy, zero), vsubq_f16(threesixty, arctan), arctan);
121     arctan = vbslq_f16(vcgtq_f16(arctan, oneeighty), vsubq_f16(arctan, oneeighty), arctan);
122
123     return arctan;
124 }
125
126 inline float32x4_t invsqrtv(float32x4_t x)
127 {
128     float32x4_t sqrt_reciprocal = vrsqrteq_f32(x);
129
130     sqrt_reciprocal = vmulq_f32(vrsqrtsq_f32(vmulq_f32(x, sqrt_reciprocal), sqrt_reciprocal),
131                                 sqrt_reciprocal);
132     sqrt_reciprocal = vmulq_f32(vrsqrtsq_f32(vmulq_f32(x, sqrt_reciprocal), sqrt_reciprocal),
133                                 sqrt_reciprocal);
134
135     return sqrt_reciprocal;
136 }
137
138 inline float32x4_t sqrtv(float32x4_t x)
139 {
140     float32x4_t res = vdupq_n_f32(0.5f);
141     return vmlaq_f32(res, x, invsqrtv(x));
142 }
143
144 inline int16x8_t magnitude_l1(int16x8_t input1, int16x8_t input2)
145 {
146     return vqaddq_s16(vabsq_s16(input1), vabsq_s16(input2));
147 }
148
149 inline int16x8_t magnitude_l2(int16x8_t input1, int16x8_t input2)
150 {
151     const int32x4x2_t square_x =
152     {
153         vmull_s16(vget_low_s16(input1), vget_low_s16(input1)),
154         vmull_s16(vget_high_s16(input1), vget_high_s16(input1))
155     };
156
157     const int32x4x2_t square_y =
158     {
159         vmull_s16(vget_low_s16(input2), vget_low_s16(input2)),
160         vmull_s16(vget_high_s16(input2), vget_high_s16(input2))
161     };
162
163     const uint32x4x2_t sum =
164     {
165         vaddq_u32(vreinterpretq_u32_s32(square_x.val[0]),
166         vreinterpretq_u32_s32(square_y.val[0])),
167         vaddq_u32(vreinterpretq_u32_s32(square_x.val[1]),
168         vreinterpretq_u32_s32(square_y.val[1]))
169     };
170
171     const float32x4x2_t res =
172     {
173         sqrtv(vcvtq_f32_u32(sum.val[0])),
174         sqrtv(vcvtq_f32_u32(sum.val[1]))
175     };
176
177     return vcombine_s16(vqmovn_s32(vcvtq_s32_f32(res.val[0])),
178                         vqmovn_s32(vcvtq_s32_f32(res.val[1])));
179 }
180
181 inline uint8x8_t phase_signed(int16x8_t input1, int16x8_t input2)
182 {
183     static const float16x8_t zeropointfive = vdupq_n_f16(0.5f);
184
185     const float16x8_t inputx_f16 = vcvtq_f16_s16(input1);
186     const float16x8_t inputy_f16 = vcvtq_f16_s16(input2);
187
188     // Compute fast atan2
189     const float16x8_t angle = atan2_0_360(inputx_f16, inputy_f16);
190
191     return vqmovun_s16(vcvtq_s16_f16(vaddq_f16(angle, zeropointfive)));
192 }
193
194 inline uint8x8_t phase_unsigned(int16x8_t input1, int16x8_t input2)
195 {
196     static const float16x8_t zeropointfive = vdupq_n_f16(0.5f);
197
198     const float16x8_t inputx_f16 = vcvtq_f16_s16(input1);
199     const float16x8_t inputy_f16 = vcvtq_f16_s16(input2);
200
201     // Compute fast atan2
202     const float16x8_t angle = atan2_0_180(inputx_f16, inputy_f16);
203
204     return vqmovun_s16(vcvtq_s16_f16(vaddq_f16(angle, zeropointfive)));
205 }
206
207 template <MagnitudeType mag_type>
208 inline int16x8x2_t compute_magnitude(const int16x8x2_t &in0, const int16x8x2_t &gx);
209
210 template <>
211 inline int16x8x2_t compute_magnitude<MagnitudeType::L2NORM>(const int16x8x2_t &in0, const int16x8x2_t &gx)
212 {
213     const int16x8x2_t mag =
214     {
215         magnitude_l2(in0.val[0], gx.val[0]),
216         magnitude_l2(in0.val[1], gx.val[1])
217     };
218
219     return mag;
220 }
221
222 template <>
223 inline int16x8x2_t compute_magnitude<MagnitudeType::L1NORM>(const int16x8x2_t &in0, const int16x8x2_t &gx)
224 {
225     const int16x8x2_t mag =
226     {
227         magnitude_l1(in0.val[0], gx.val[0]),
228         magnitude_l1(in0.val[1], gx.val[1])
229     };
230
231     return mag;
232 }
233
234 template <PhaseType phase_type>
235 inline uint8x16_t compute_phase(const int16x8x2_t &in0, const int16x8x2_t &gx);
236
237 template <>
238 inline uint8x16_t compute_phase<PhaseType::SIGNED>(const int16x8x2_t &in0, const int16x8x2_t &gx)
239 {
240     return vcombine_u8(phase_signed(in0.val[0], gx.val[0]),
241                        phase_signed(in0.val[1], gx.val[1]));
242 }
243
244 template <>
245 inline uint8x16_t compute_phase<PhaseType::UNSIGNED>(const int16x8x2_t &in0, const int16x8x2_t &gx)
246 {
247     return vcombine_u8(phase_unsigned(in0.val[0], gx.val[0]),
248                        phase_unsigned(in0.val[1], gx.val[1]));
249 }
250 } // namespace fp16
251
252 template <MagnitudeType mag_type, PhaseType phase_type>
253 NEMagnitudePhaseFP16Kernel<mag_type, phase_type>::NEMagnitudePhaseFP16Kernel()
254     : _func(nullptr), _gx(nullptr), _gy(nullptr), _magnitude(nullptr), _phase(nullptr)
255 {
256 }
257
258 template <MagnitudeType mag_type, PhaseType phase_type>
259 void NEMagnitudePhaseFP16Kernel<mag_type, phase_type>::configure(const ITensor *gx, const ITensor *gy, ITensor *magnitude, ITensor *phase)
260 {
261     ARM_COMPUTE_ERROR_ON_FORMAT_NOT_IN(gx, Format::S16);
262     ARM_COMPUTE_ERROR_ON_FORMAT_NOT_IN(gy, Format::S16);
263     ARM_COMPUTE_ERROR_ON((nullptr == magnitude) && (nullptr == phase));
264
265     const bool run_mag   = magnitude != nullptr;
266     const bool run_phase = phase != nullptr;
267
268     if(run_mag)
269     {
270         ARM_COMPUTE_ERROR_ON_FORMAT_NOT_IN(magnitude, Format::S16);
271     }
272
273     if(run_phase)
274     {
275         ARM_COMPUTE_ERROR_ON_FORMAT_NOT_IN(phase, Format::U8);
276     }
277
278     _gx        = gx;
279     _gy        = gy;
280     _magnitude = magnitude;
281     _phase     = phase;
282
283     if(run_mag && run_phase)
284     {
285         /* Run magnitude and phase */
286         _func = &NEMagnitudePhaseFP16Kernel<mag_type, phase_type>::magnitude_phase;
287     }
288     else if(run_mag)
289     {
290         /* Run magnitude */
291         _func = &NEMagnitudePhaseFP16Kernel<mag_type, phase_type>::magnitude;
292     }
293     else if(run_phase)
294     {
295         /* Run phase */
296         _func = &NEMagnitudePhaseFP16Kernel<mag_type, phase_type>::phase;
297     }
298     else
299     {
300         ARM_COMPUTE_ERROR("At least one output must be NOT NULL");
301     }
302
303     const unsigned int processed_elements = 16;
304
305     // Configure kernel window
306     Window                 win = calculate_max_window(*gx->info(), Steps(processed_elements));
307     AccessWindowHorizontal magnitude_access(magnitude == nullptr ? nullptr : magnitude->info(), 0, processed_elements);
308     AccessWindowHorizontal phase_access(phase == nullptr ? nullptr : phase->info(), 0, processed_elements);
309
310     update_window_and_padding(win,
311                               AccessWindowHorizontal(gx->info(), 0, processed_elements),
312                               AccessWindowHorizontal(gy->info(), 0, processed_elements),
313                               magnitude_access,
314                               phase_access);
315
316     ValidRegion valid_region = intersect_valid_regions(gx->info()->valid_region(),
317                                                        gy->info()->valid_region());
318
319     magnitude_access.set_valid_region(win, valid_region);
320     phase_access.set_valid_region(win, valid_region);
321
322     INEKernel::configure(win);
323 }
324
325 template <MagnitudeType mag_type, PhaseType phase_type>
326 void NEMagnitudePhaseFP16Kernel<mag_type, phase_type>::magnitude(const Window &window)
327 {
328     Iterator gx(_gx, window);
329     Iterator gy(_gy, window);
330     Iterator magnitude(_magnitude, window);
331
332     execute_window_loop(window, [&](const Coordinates & id)
333     {
334         const int16x8x2_t input1 =
335         {
336             vld1q_s16(reinterpret_cast<int16_t *>(gx.ptr())),
337             vld1q_s16(reinterpret_cast<int16_t *>(gx.ptr()) + 8)
338         };
339
340         const int16x8x2_t input2 =
341         {
342             vld1q_s16(reinterpret_cast<int16_t *>(gy.ptr())),
343             vld1q_s16(reinterpret_cast<int16_t *>(gy.ptr()) + 8)
344         };
345
346         // Compute and store magnitude
347         const int16x8x2_t mag = fp16::compute_magnitude<mag_type>(input1, input2);
348
349         /* Store magnitude */
350         vst1q_s16(reinterpret_cast<int16_t *>(magnitude.ptr()), mag.val[0]);
351         vst1q_s16(reinterpret_cast<int16_t *>(magnitude.ptr()) + 8, mag.val[1]);
352     },
353     gx, gy, magnitude);
354 }
355
356 template <MagnitudeType mag_type, PhaseType phase_type>
357 void NEMagnitudePhaseFP16Kernel<mag_type, phase_type>::phase(const Window &window)
358 {
359     Iterator gx(_gx, window);
360     Iterator gy(_gy, window);
361     Iterator phase(_phase, window);
362
363     execute_window_loop(window, [&](const Coordinates & id)
364     {
365         const int16x8x2_t input1 =
366         {
367             vld1q_s16(reinterpret_cast<int16_t *>(gx.ptr())),
368             vld1q_s16(reinterpret_cast<int16_t *>(gx.ptr()) + 8)
369         };
370
371         const int16x8x2_t input2 =
372         {
373             vld1q_s16(reinterpret_cast<int16_t *>(gy.ptr())),
374             vld1q_s16(reinterpret_cast<int16_t *>(gy.ptr()) + 8)
375         };
376
377         // Compute and store phase
378         vst1q_u8(phase.ptr(), fp16::compute_phase<phase_type>(input1, input2));
379     },
380     gx, gy, phase);
381 }
382
383 template <MagnitudeType mag_type, PhaseType phase_type>
384 void NEMagnitudePhaseFP16Kernel<mag_type, phase_type>::magnitude_phase(const Window &window)
385 {
386     Iterator gx(_gx, window);
387     Iterator gy(_gy, window);
388     Iterator magnitude(_magnitude, window);
389     Iterator phase(_phase, window);
390
391     execute_window_loop(window, [&](const Coordinates & id)
392     {
393         const int16x8x2_t input1 =
394         {
395             vld1q_s16(reinterpret_cast<int16_t *>(gx.ptr())),
396             vld1q_s16(reinterpret_cast<int16_t *>(gx.ptr()) + 8)
397         };
398
399         const int16x8x2_t input2 =
400         {
401             vld1q_s16(reinterpret_cast<int16_t *>(gy.ptr())),
402             vld1q_s16(reinterpret_cast<int16_t *>(gy.ptr()) + 8)
403         };
404
405         // Compute and store magnitude
406         const int16x8x2_t mag = fp16::compute_magnitude<mag_type>(input1, input2);
407
408         vst1q_s16(reinterpret_cast<int16_t *>(magnitude.ptr()), mag.val[0]);
409         vst1q_s16(reinterpret_cast<int16_t *>(magnitude.ptr()) + 8, mag.val[1]);
410
411         // Compute and store phase
412         vst1q_u8(phase.ptr(), fp16::compute_phase<phase_type>(input1, input2));
413     },
414     gx, gy, magnitude, phase);
415 }
416
417 template <MagnitudeType mag_type, PhaseType phase_type>
418 void NEMagnitudePhaseFP16Kernel<mag_type, phase_type>::run(const Window &window)
419 {
420     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
421     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
422     ARM_COMPUTE_ERROR_ON(_func == nullptr);
423
424     (this->*_func)(window);
425 }
426
427 template class arm_compute::NEMagnitudePhaseFP16Kernel<MagnitudeType::L1NORM, PhaseType::SIGNED>;
428 template class arm_compute::NEMagnitudePhaseFP16Kernel<MagnitudeType::L2NORM, PhaseType::SIGNED>;
429 template class arm_compute::NEMagnitudePhaseFP16Kernel<MagnitudeType::L1NORM, PhaseType::UNSIGNED>;
430 template class arm_compute::NEMagnitudePhaseFP16Kernel<MagnitudeType::L2NORM, PhaseType::UNSIGNED>;
431 #endif
432
433 namespace
434 {
435 inline float32x4_t inv(float32x4_t x)
436 {
437     float32x4_t result = vrecpeq_f32(x);
438     result             = vmulq_f32(vrecpsq_f32(x, result), result);
439     return result;
440 }
441
442 inline float32x4_t atan2_0_360(float32x4_t gx, float32x4_t gy)
443 {
444     const float32x4_t zero       = vdupq_n_f32(0.0f);
445     const float32x4_t epsilon    = vdupq_n_f32(1e-9f);
446     const float32x4_t piover4    = vdupq_n_f32(PI_4);
447     const float32x4_t coeff1     = vdupq_n_f32(COEFF1);
448     const float32x4_t coeff2     = vdupq_n_f32(COEFF2);
449     const float32x4_t ninety     = vdupq_n_f32(90.0f * SCALE_FACTOR);
450     const float32x4_t oneeighty  = vdupq_n_f32(180.0f * SCALE_FACTOR);
451     const float32x4_t threesixty = vdupq_n_f32(360.0f * SCALE_FACTOR);
452     const float32x4_t scale      = vdupq_n_f32(SCALE_360);
453
454     float32x4_t abs_gx = vabsq_f32(gx);
455     float32x4_t abs_gy = vabsq_f32(gy);
456     float32x4_t tmin   = vminq_f32(abs_gx, abs_gy);
457     float32x4_t tmax   = vmaxq_f32(abs_gx, abs_gy);
458     float32x4_t z      = vmulq_f32(tmin, inv(vaddq_f32(tmax, epsilon)));
459     float32x4_t absz   = vabsq_f32(z);
460     float32x4_t term   = vmulq_f32(z, vsubq_f32(vdupq_n_f32(1.0f), absz));
461
462     /* Compute y = pi/4 * x - x*(abs(x)-1)*(0.2447+0.0663 * abs(x) */
463     float32x4_t result = vaddq_f32(coeff2, vmulq_f32(absz, coeff1));
464     result             = vmulq_f32(result, term);
465     result             = vmlaq_f32(result, piover4, z);
466
467     /* Radians to degrees conversion with applied a scale factor in order to have the result [0, 255]  */
468     result = vmulq_f32(result, scale);
469
470     /* If z > 1, result = 90 - result */
471     result = vbslq_f32(vcgeq_f32(abs_gx, abs_gy), result, vsubq_f32(ninety, result));
472
473     /* Choose correct quadrant */
474     result = vbslq_f32(vcltq_f32(gx, zero), vsubq_f32(oneeighty, result), result);
475     result = vbslq_f32(vcltq_f32(gy, zero), vsubq_f32(threesixty, result), result);
476
477     return result;
478 }
479
480 inline float32x4_t atan2_0_180(float32x4_t gx, float32x4_t gy)
481 {
482     const float32x4_t zero       = vdupq_n_f32(0.0f);
483     const float32x4_t epsilon    = vdupq_n_f32(1e-9f); // epsilon used to avoiding division by 0
484     const float32x4_t piover4    = vdupq_n_f32(PI_4);
485     const float32x4_t coeff1     = vdupq_n_f32(COEFF1);
486     const float32x4_t coeff2     = vdupq_n_f32(COEFF2);
487     const float32x4_t ninety     = vdupq_n_f32(90.0f);
488     const float32x4_t oneeighty  = vdupq_n_f32(180.0f);
489     const float32x4_t threesixty = vdupq_n_f32(360.0f);
490     const float32x4_t scale      = vdupq_n_f32(SCALE_180);
491
492     float32x4_t abs_gx = vabsq_f32(gx);
493     float32x4_t abs_gy = vabsq_f32(gy);
494     float32x4_t tmin   = vminq_f32(abs_gx, abs_gy);
495     float32x4_t tmax   = vmaxq_f32(abs_gx, abs_gy);
496     float32x4_t z      = vmulq_f32(tmin, inv(vaddq_f32(tmax, epsilon)));
497     float32x4_t absz   = vabsq_f32(z);
498
499     /* Compute y = pi/4 * z - z*(abs(z)-1)*(0.2447+0.0663 * abs(z) */
500     float32x4_t term   = vmulq_f32(z, vsubq_f32(vdupq_n_f32(1.0f), absz));
501     float32x4_t result = vaddq_f32(coeff2, vmulq_f32(absz, coeff1));
502     result             = vmulq_f32(result, term);
503     result             = vmlaq_f32(result, piover4, z);
504
505     /* Radians to degrees conversion */
506     result = vmulq_f32(result, scale);
507
508     /* If z > 1, result = 90 - result */
509     result = vbslq_f32(vcgeq_f32(abs_gx, abs_gy), result, vsubq_f32(ninety, result));
510
511     /* Choose correct quadrant */
512     result = vbslq_f32(vcltq_f32(gx, zero), vsubq_f32(oneeighty, result), result);
513     result = vbslq_f32(vcltq_f32(gy, zero), vsubq_f32(threesixty, result), result);
514     result = vbslq_f32(vcgtq_f32(result, oneeighty), vsubq_f32(result, oneeighty), result);
515
516     return result;
517 }
518
519 inline float32x4_t invsqrtv(float32x4_t x)
520 {
521     float32x4_t sqrt_reciprocal = vrsqrteq_f32(x);
522
523     sqrt_reciprocal = vmulq_f32(vrsqrtsq_f32(vmulq_f32(x, sqrt_reciprocal), sqrt_reciprocal),
524                                 sqrt_reciprocal);
525     sqrt_reciprocal = vmulq_f32(vrsqrtsq_f32(vmulq_f32(x, sqrt_reciprocal), sqrt_reciprocal),
526                                 sqrt_reciprocal);
527
528     return sqrt_reciprocal;
529 }
530
531 inline float32x4_t sqrtv(float32x4_t x)
532 {
533     float32x4_t res = vdupq_n_f32(0.5f);
534     return vmlaq_f32(res, x, invsqrtv(x));
535 }
536
537 inline int16x8_t magnitude_l2(int16x8_t input1, int16x8_t input2)
538 {
539     const int32x4x2_t square_x =
540     {
541         {
542             vmull_s16(vget_low_s16(input1), vget_low_s16(input1)),
543             vmull_s16(vget_high_s16(input1), vget_high_s16(input1))
544         }
545     };
546
547     const int32x4x2_t square_y =
548     {
549         {
550             vmull_s16(vget_low_s16(input2), vget_low_s16(input2)),
551             vmull_s16(vget_high_s16(input2), vget_high_s16(input2))
552         }
553     };
554
555     const uint32x4x2_t sum =
556     {
557         {
558             vaddq_u32(vreinterpretq_u32_s32(square_x.val[0]), vreinterpretq_u32_s32(square_y.val[0])),
559             vaddq_u32(vreinterpretq_u32_s32(square_x.val[1]), vreinterpretq_u32_s32(square_y.val[1]))
560         }
561     };
562
563     const float32x4x2_t res =
564     {
565         {
566             sqrtv(vcvtq_f32_u32(sum.val[0])),
567             sqrtv(vcvtq_f32_u32(sum.val[1]))
568         }
569     };
570
571     return vcombine_s16(vqmovn_s32(vcvtq_s32_f32(res.val[0])),
572                         vqmovn_s32(vcvtq_s32_f32(res.val[1])));
573 }
574
575 inline int16x8_t magnitude_l1(int16x8_t input1, int16x8_t input2)
576 {
577     int16x8_t gx_abs = vabsq_s16(input1);
578     int16x8_t gy_abs = vabsq_s16(input2);
579
580     /* Saturating add */
581     return vqaddq_s16(gx_abs, gy_abs);
582 }
583
584 inline uint8x8_t phase_signed(int16x8_t input1, int16x8_t input2)
585 {
586     const float32x4_t zeropointfive = vdupq_n_f32(0.5f);
587
588     float32x4_t inputx_f32_high = vcvtq_f32_s32(vmovl_s16(vget_high_s16(input1)));
589     float32x4_t inputx_f32_low  = vcvtq_f32_s32(vmovl_s16(vget_low_s16(input1)));
590     float32x4_t inputy_f32_high = vcvtq_f32_s32(vmovl_s16(vget_high_s16(input2)));
591     float32x4_t inputy_f32_low  = vcvtq_f32_s32(vmovl_s16(vget_low_s16(input2)));
592
593     /* Compute fast atan2 */
594     float32x4_t angle_high = atan2_0_360(inputx_f32_high, inputy_f32_high);
595     float32x4_t angle_low  = atan2_0_360(inputx_f32_low, inputy_f32_low);
596
597     angle_high = vaddq_f32(angle_high, zeropointfive);
598     angle_low  = vaddq_f32(angle_low, zeropointfive);
599
600     return vmovn_u16(vcombine_u16(vqmovun_s32(vcvtq_s32_f32(angle_low)),
601                                   vqmovun_s32(vcvtq_s32_f32(angle_high))));
602 }
603
604 inline uint8x8_t phase_unsigned(int16x8_t input1, int16x8_t input2)
605 {
606     const float32x4_t zeropointfive = vdupq_n_f32(0.5f);
607
608     float32x4_t inputx_f32_high = vcvtq_f32_s32(vmovl_s16(vget_high_s16(input1)));
609     float32x4_t inputx_f32_low  = vcvtq_f32_s32(vmovl_s16(vget_low_s16(input1)));
610     float32x4_t inputy_f32_high = vcvtq_f32_s32(vmovl_s16(vget_high_s16(input2)));
611     float32x4_t inputy_f32_low  = vcvtq_f32_s32(vmovl_s16(vget_low_s16(input2)));
612
613     /* Compute fast atan2 */
614     float32x4_t angle_high = atan2_0_180(inputx_f32_high, inputy_f32_high);
615     float32x4_t angle_low  = atan2_0_180(inputx_f32_low, inputy_f32_low);
616
617     angle_high = vaddq_f32(angle_high, zeropointfive);
618     angle_low  = vaddq_f32(angle_low, zeropointfive);
619
620     return vmovn_u16(vcombine_u16(vqmovun_s32(vcvtq_s32_f32(angle_low)),
621                                   vqmovun_s32(vcvtq_s32_f32(angle_high))));
622 }
623 } // namespace
624
625 template <MagnitudeType mag_type, PhaseType phase_type>
626 NEMagnitudePhaseKernel<mag_type, phase_type>::NEMagnitudePhaseKernel()
627     : _func(nullptr), _gx(nullptr), _gy(nullptr), _magnitude(nullptr), _phase(nullptr)
628 {
629 }
630
631 template <MagnitudeType mag_type, PhaseType phase_type>
632 void NEMagnitudePhaseKernel<mag_type, phase_type>::configure(const ITensor *gx, const ITensor *gy, ITensor *magnitude, ITensor *phase)
633 {
634     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(gx, 1, DataType::S16);
635     ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(gy, 1, DataType::S16);
636     ARM_COMPUTE_ERROR_ON((nullptr == magnitude) && (nullptr == phase));
637
638     const bool run_mag   = magnitude != nullptr;
639     const bool run_phase = phase != nullptr;
640
641     if(run_mag)
642     {
643         ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(magnitude, 1, DataType::S16);
644     }
645
646     if(run_phase)
647     {
648         ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(phase, 1, DataType::U8);
649     }
650
651     _gx        = gx;
652     _gy        = gy;
653     _magnitude = magnitude;
654     _phase     = phase;
655
656     if(run_mag && run_phase)
657     {
658         /* Run magnitude and phase */
659         _func = &NEMagnitudePhaseKernel<mag_type, phase_type>::magnitude_phase;
660     }
661     else
662     {
663         if(run_mag)
664         {
665             /* Run magnitude */
666             _func = &NEMagnitudePhaseKernel<mag_type, phase_type>::magnitude;
667         }
668         else if(run_phase)
669         {
670             /* Run phase */
671             _func = &NEMagnitudePhaseKernel<mag_type, phase_type>::phase;
672         }
673         else
674         {
675             ARM_COMPUTE_ERROR("At least one output must be NOT NULL");
676         }
677     }
678
679     const unsigned int processed_elements = 16;
680
681     // Configure kernel window
682     Window                 win = calculate_max_window(*gx->info(), Steps(processed_elements));
683     AccessWindowHorizontal magnitude_access(magnitude == nullptr ? nullptr : magnitude->info(), 0, processed_elements);
684     AccessWindowHorizontal phase_access(phase == nullptr ? nullptr : phase->info(), 0, processed_elements);
685
686     update_window_and_padding(win,
687                               AccessWindowHorizontal(gx->info(), 0, processed_elements),
688                               AccessWindowHorizontal(gy->info(), 0, processed_elements),
689                               magnitude_access,
690                               phase_access);
691
692     ValidRegion valid_region = intersect_valid_regions(gx->info()->valid_region(),
693                                                        gy->info()->valid_region());
694
695     magnitude_access.set_valid_region(win, valid_region);
696     phase_access.set_valid_region(win, valid_region);
697
698     INEKernel::configure(win);
699 }
700
701 template <MagnitudeType mag_type, PhaseType phase_type>
702 void NEMagnitudePhaseKernel<mag_type, phase_type>::magnitude(const Window &window)
703 {
704     Iterator gx(_gx, window);
705     Iterator gy(_gy, window);
706     Iterator magnitude(_magnitude, window);
707
708     execute_window_loop(window, [&](const Coordinates & id)
709     {
710         const int16x8x2_t input1 =
711         {
712             {
713                 vld1q_s16(reinterpret_cast<int16_t *>(gx.ptr())),
714                 vld1q_s16(reinterpret_cast<int16_t *>(gx.ptr()) + 8)
715             }
716         };
717
718         const int16x8x2_t input2 =
719         {
720             {
721                 vld1q_s16(reinterpret_cast<int16_t *>(gy.ptr())),
722                 vld1q_s16(reinterpret_cast<int16_t *>(gy.ptr()) + 8)
723             }
724         };
725
726         /* Compute magnitude */
727         int16x8x2_t mag{ {} };
728
729         if(MagnitudeType::L2NORM == mag_type)
730         {
731             mag.val[0] = magnitude_l2(input1.val[0], input2.val[0]);
732             mag.val[1] = magnitude_l2(input1.val[1], input2.val[1]);
733         }
734         else
735         {
736             mag.val[0] = magnitude_l1(input1.val[0], input2.val[0]);
737             mag.val[1] = magnitude_l1(input1.val[1], input2.val[1]);
738         }
739
740         /* Store magnitude */
741         vst1q_s16(reinterpret_cast<int16_t *>(magnitude.ptr()), mag.val[0]);
742         vst1q_s16(reinterpret_cast<int16_t *>(magnitude.ptr()) + 8, mag.val[1]);
743     },
744     gx, gy, magnitude);
745 }
746
747 template <MagnitudeType mag_type, PhaseType phase_type>
748 void NEMagnitudePhaseKernel<mag_type, phase_type>::phase(const Window &window)
749 {
750     Iterator gx(_gx, window);
751     Iterator gy(_gy, window);
752     Iterator phase(_phase, window);
753
754     execute_window_loop(window, [&](const Coordinates & id)
755     {
756         const int16x8x2_t input1 =
757         {
758             {
759                 vld1q_s16(reinterpret_cast<int16_t *>(gx.ptr())),
760                 vld1q_s16(reinterpret_cast<int16_t *>(gx.ptr()) + 8)
761             }
762         };
763
764         const int16x8x2_t input2 =
765         {
766             {
767                 vld1q_s16(reinterpret_cast<int16_t *>(gy.ptr())),
768                 vld1q_s16(reinterpret_cast<int16_t *>(gy.ptr()) + 8)
769             }
770         };
771
772         /* Compute phase */
773         uint8x8x2_t vphase{ {} };
774
775         if(PhaseType::SIGNED == phase_type)
776         {
777             vphase.val[0] = phase_signed(input1.val[0], input2.val[0]);
778             vphase.val[1] = phase_signed(input1.val[1], input2.val[1]);
779         }
780         else
781         {
782             vphase.val[0] = phase_unsigned(input1.val[0], input2.val[0]);
783             vphase.val[1] = phase_unsigned(input1.val[1], input2.val[1]);
784         }
785
786         /* Store phase */
787         vst1q_u8(phase.ptr(), vcombine_u8(vphase.val[0], vphase.val[1]));
788     },
789     gx, gy, phase);
790 }
791
792 template <MagnitudeType mag_type, PhaseType phase_type>
793 void NEMagnitudePhaseKernel<mag_type, phase_type>::magnitude_phase(const Window &window)
794 {
795     Iterator gx(_gx, window);
796     Iterator gy(_gy, window);
797     Iterator magnitude(_magnitude, window);
798     Iterator phase(_phase, window);
799
800     execute_window_loop(window, [&](const Coordinates & id)
801     {
802         const int16x8x2_t input1 =
803         {
804             {
805                 vld1q_s16(reinterpret_cast<int16_t *>(gx.ptr())),
806                 vld1q_s16(reinterpret_cast<int16_t *>(gx.ptr()) + 8)
807             }
808         };
809
810         const int16x8x2_t input2 =
811         {
812             {
813                 vld1q_s16(reinterpret_cast<int16_t *>(gy.ptr())),
814                 vld1q_s16(reinterpret_cast<int16_t *>(gy.ptr()) + 8)
815             }
816         };
817
818         /* Compute magnitude */
819         int16x8x2_t mag{ {} };
820
821         if(MagnitudeType::L2NORM == mag_type)
822         {
823             mag.val[0] = magnitude_l2(input1.val[0], input2.val[0]);
824             mag.val[1] = magnitude_l2(input1.val[1], input2.val[1]);
825         }
826         else
827         {
828             mag.val[0] = magnitude_l1(input1.val[0], input2.val[0]);
829             mag.val[1] = magnitude_l1(input1.val[1], input2.val[1]);
830         }
831
832         /* Store magnitude */
833         vst1q_s16(reinterpret_cast<int16_t *>(magnitude.ptr()), mag.val[0]);
834         vst1q_s16(reinterpret_cast<int16_t *>(magnitude.ptr()) + 8, mag.val[1]);
835
836         /* Compute phase */
837         uint8x8x2_t vphase{ {} };
838
839         if(PhaseType::SIGNED == phase_type)
840         {
841             vphase.val[0] = phase_signed(input1.val[0], input2.val[0]);
842             vphase.val[1] = phase_signed(input1.val[1], input2.val[1]);
843         }
844         else
845         {
846             vphase.val[0] = phase_unsigned(input1.val[0], input2.val[0]);
847             vphase.val[1] = phase_unsigned(input1.val[1], input2.val[1]);
848         }
849
850         /* Store phase */
851         vst1q_u8(phase.ptr(), vcombine_u8(vphase.val[0], vphase.val[1]));
852     },
853     gx, gy, magnitude, phase);
854 }
855
856 template <MagnitudeType mag_type, PhaseType phase_type>
857 void NEMagnitudePhaseKernel<mag_type, phase_type>::run(const Window &window)
858 {
859     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
860     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
861     ARM_COMPUTE_ERROR_ON(_func == nullptr);
862
863     (this->*_func)(window);
864 }
865
866 template class arm_compute::NEMagnitudePhaseKernel<MagnitudeType::L1NORM, PhaseType::SIGNED>;
867 template class arm_compute::NEMagnitudePhaseKernel<MagnitudeType::L2NORM, PhaseType::SIGNED>;
868 template class arm_compute::NEMagnitudePhaseKernel<MagnitudeType::L1NORM, PhaseType::UNSIGNED>;
869 template class arm_compute::NEMagnitudePhaseKernel<MagnitudeType::L2NORM, PhaseType::UNSIGNED>;