/*
 * Copyright (c) 2016, 2017 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24 #include "arm_compute/core/NEON/kernels/NEMagnitudePhaseKernel.h"
26 #include "arm_compute/core/Error.h"
27 #include "arm_compute/core/Helpers.h"
28 #include "arm_compute/core/IAccessWindow.h"
29 #include "arm_compute/core/ITensor.h"
30 #include "arm_compute/core/Validate.h"
35 using namespace arm_compute;
40 } // namespace arm_compute
// Compile-time constants shared by the fast atan2 approximations.

// Maps degrees onto the U8 phase range: a full turn (360 degrees) becomes 256.
constexpr float SCALE_FACTOR = 256.0f / 360.0f;
constexpr float PI           = 3.141592653589793f;
// Radians -> degrees.
constexpr float SCALE_180 = 180.0f / PI;
// Radians -> degrees already multiplied by SCALE_FACTOR.
constexpr float SCALE_360 = SCALE_180 * SCALE_FACTOR;
constexpr float PI_4      = PI / 4.0f;
// Polynomial coefficients of atan(z) ~= z * (pi/4 + (1 - |z|) * (COEFF2 + COEFF1 * |z|)), |z| <= 1.
constexpr float COEFF1 = 0.0663f;
constexpr float COEFF2 = 0.2447f;
54 #ifdef ARM_COMPUTE_ENABLE_FP16
/**
 * Reciprocal of eight half-precision lanes.
 *
 * Starts from the hardware estimate (vrecpeq_f16) and applies one Newton-Raphson
 * step: vrecpsq_f16(x, e) yields (2 - x * e), and e * (2 - x * e) refines 1 / x.
 */
inline float16x8_t inv(float16x8_t x)
{
    const float16x8_t initial_guess = vrecpeq_f16(x);
    const float16x8_t nr_factor     = vrecpsq_f16(x, initial_guess);

    return vmulq_f16(initial_guess, nr_factor);
}
/**
 * Fast approximation of the angle of (|gx|, |gy|), i.e. folded into the first quadrant.
 *
 * With z = min(|gx|, |gy|) / max(|gx|, |gy|) in [0, 1], atan(z) is approximated by
 * z * [pi/4 + (1 - |z|) * (COEFF2 + COEFF1 * |z|)] and multiplied by @p scale.
 * When |gy| > |gx| the ratio is inverted, so the result becomes (90 * SCALE_FACTOR) - angle.
 * The result is therefore only unit-consistent when @p scale is SCALE_360
 * (radians -> degrees * SCALE_FACTOR). Quadrant selection is left to the callers.
 *
 * @param gx    Horizontal gradient lanes.
 * @param gy    Vertical gradient lanes.
 * @param scale Per-lane factor applied to the angle in radians.
 *
 * @return First-quadrant angle, in [0, 90 * SCALE_FACTOR] when @p scale is SCALE_360.
 */
inline float16x8_t atan2_fast(float16x8_t gx, float16x8_t gy, float16x8_t scale)
{
    // Plain (non-static) locals: a function-local static with a non-constant initializer
    // carries a thread-safe-initialization guard check on every call of this hot helper.
    const float16x8_t one    = vdupq_n_f16(1.0f);
    const float16x8_t ninety = vdupq_n_f16(90.f * SCALE_FACTOR);
    // Guards the 0 / 0 case (gx == gy == 0). It must be representable in half precision:
    // 1e-9f (the FP32 value) rounds to 0 in FP16, whose smallest subnormal is ~6e-8.
    // 2^-14 is the smallest normal FP16 value and is below half an ulp of 1.0, so it
    // leaves any non-zero integer-valued tmax unchanged.
    const float16x8_t epsilon = vdupq_n_f16(6.103515625e-05f);
    const float16x8_t piover4 = vdupq_n_f16(PI_4);
    const float16x8_t coeff1  = vdupq_n_f16(COEFF1);
    const float16x8_t coeff2  = vdupq_n_f16(COEFF2);

    const float16x8_t abs_gx = vabsq_f16(gx);
    const float16x8_t abs_gy = vabsq_f16(gy);
    const float16x8_t tmin   = vminq_f16(abs_gx, abs_gy);
    const float16x8_t tmax   = vmaxq_f16(abs_gx, abs_gy);

    // z = min(x, y) / max(x, y), in [0, 1]
    const float16x8_t z    = vmulq_f16(tmin, inv(vaddq_f16(tmax, epsilon)));
    const float16x8_t absz = vabsq_f16(z);

    // arctan(z) ~= z * [pi/4 + (1 - |z|) * (0.2447 + 0.0663 * |z|)]
    float16x8_t arctan = vmulq_f16(z, vfmaq_f16(piover4,
                                                vsubq_f16(one, absz),
                                                vfmaq_f16(coeff2, coeff1, absz)));

    // Convert from radians with the caller-supplied scale
    arctan = vmulq_f16(arctan, scale);

    // When |gy| > |gx|, z holds |gx| / |gy|: use atan(a / b) = 90 - atan(b / a)
    return vbslq_f16(vcgeq_f16(abs_gx, abs_gy), arctan, vsubq_f16(ninety, arctan));
}
/**
 * Signed phase of eight (gx, gy) half-precision lanes, expressed in degrees * SCALE_FACTOR.
 *
 * atan2_fast gives the first-quadrant angle in [0, 90 * SCALE_FACTOR]. It is then moved
 * to the quadrant of (gx, gy), so the result covers [0, 360 * SCALE_FACTOR].
 */
inline float16x8_t atan2_0_360(float16x8_t gx, float16x8_t gy)
{
    // Non-static: avoids a static-initialization guard check on every call in the inner loop
    const float16x8_t scale      = vdupq_n_f16(SCALE_360);
    const float16x8_t threesixty = vdupq_n_f16(360.0f * SCALE_FACTOR);
    const float16x8_t zero       = vdupq_n_f16(0.0f);
    const float16x8_t oneeighty  = vdupq_n_f16(180.0f * SCALE_FACTOR);

    float16x8_t arctan = atan2_fast(gx, gy, scale);

    // Choose correct quadrant: gx < 0 mirrors around 180, gy < 0 mirrors around 360
    arctan = vbslq_f16(vcltq_f16(gx, zero), vsubq_f16(oneeighty, arctan), arctan);
    arctan = vbslq_f16(vcltq_f16(gy, zero), vsubq_f16(threesixty, arctan), arctan);

    return arctan;
}
/**
 * Unsigned phase of eight (gx, gy) half-precision lanes, in plain degrees within [0, 180].
 *
 * atan2_fast folds |gy| > |gx| lanes around 90 * SCALE_FACTOR, so it is called with
 * SCALE_360 to keep its output in consistent "degrees * SCALE_FACTOR" units. The result
 * is then converted to plain degrees by multiplying with 1 / SCALE_FACTOR (360 / 256).
 * Passing SCALE_180 directly would mix plain degrees with the scaled 90/180/360 constants.
 * For example, (gx, gy) = (0, 1) would give 64 instead of 90.
 * The quadrant and fold constants are plain degrees, as in the FP32 atan2_0_180.
 */
inline float16x8_t atan2_0_180(float16x8_t gx, float16x8_t gy)
{
    const float16x8_t scale      = vdupq_n_f16(SCALE_360);
    const float16x8_t to_degrees = vdupq_n_f16(1.0f / SCALE_FACTOR);
    const float16x8_t threesixty = vdupq_n_f16(360.0f);
    const float16x8_t oneeighty  = vdupq_n_f16(180.0f);
    const float16x8_t zero       = vdupq_n_f16(0.0f);

    // First-quadrant angle in plain degrees, [0, 90]
    float16x8_t arctan = vmulq_f16(atan2_fast(gx, gy, scale), to_degrees);

    // Choose correct quadrant: result in [0, 360]
    arctan = vbslq_f16(vcltq_f16(gx, zero), vsubq_f16(oneeighty, arctan), arctan);
    arctan = vbslq_f16(vcltq_f16(gy, zero), vsubq_f16(threesixty, arctan), arctan);

    // Unsigned phase ignores direction: fold (180, 360] back onto (0, 180]
    arctan = vbslq_f16(vcgtq_f16(arctan, oneeighty), vsubq_f16(arctan, oneeighty), arctan);

    return arctan;
}
/**
 * 1 / sqrt(x) for four single-precision lanes.
 *
 * Refines the hardware estimate (vrsqrteq_f32) with two Newton-Raphson steps.
 * vrsqrtsq_f32(a, b) computes (3 - a * b) / 2, so r * vrsqrtsq_f32(x * r, r) refines r.
 */
inline float32x4_t invsqrtv(float32x4_t x)
{
    float32x4_t r = vrsqrteq_f32(x);

    for(int step = 0; step < 2; ++step)
    {
        r = vmulq_f32(vrsqrtsq_f32(vmulq_f32(x, r), r), r);
    }

    return r;
}
/**
 * Returns sqrt(x) + 0.5 for four single-precision lanes (computed as 0.5 + x * invsqrtv(x)).
 *
 * The +0.5 bias is intentional: callers convert the result with the truncating
 * vcvtq_s32_f32, so the bias turns that conversion into round-to-nearest.
 */
inline float32x4_t sqrtv(float32x4_t x)
{
    const float32x4_t rounding_bias = vdupq_n_f32(0.5f);
    const float32x4_t rsqrt         = invsqrtv(x);

    return vmlaq_f32(rounding_bias, x, rsqrt);
}
/**
 * L1 magnitude |input1| + |input2| per lane, saturated to the int16 range.
 *
 * vqabsq_s16 (saturating) is used instead of vabsq_s16: the latter wraps for
 * INT16_MIN (|-32768| -> -32768), which would produce a negative magnitude.
 */
inline int16x8_t magnitude_l1(int16x8_t input1, int16x8_t input2)
{
    return vqaddq_s16(vqabsq_s16(input1), vqabsq_s16(input2));
}
/**
 * L2 magnitude sqrt(input1^2 + input2^2) per lane, rounded and saturated to int16.
 *
 * Squares are widened to 32 bits. The sum of two int16 squares (at most 2^31) is
 * accumulated as uint32 so it cannot overflow. sqrtv() adds 0.5, so the truncating
 * float->int conversion rounds to nearest. vqmovn_s32 then saturates to int16.
 */
inline int16x8_t magnitude_l2(int16x8_t input1, int16x8_t input2)
{
    // Sum of squares of four lanes, reinterpreted as unsigned
    const auto sum_of_squares = [](int16x4_t a, int16x4_t b) -> uint32x4_t
    {
        const uint32x4_t a_sq = vreinterpretq_u32_s32(vmull_s16(a, a));
        const uint32x4_t b_sq = vreinterpretq_u32_s32(vmull_s16(b, b));
        return vaddq_u32(a_sq, b_sq);
    };

    const uint32x4_t sum_low  = sum_of_squares(vget_low_s16(input1), vget_low_s16(input2));
    const uint32x4_t sum_high = sum_of_squares(vget_high_s16(input1), vget_high_s16(input2));

    const int32x4_t norm_low  = vcvtq_s32_f32(sqrtv(vcvtq_f32_u32(sum_low)));
    const int32x4_t norm_high = vcvtq_s32_f32(sqrtv(vcvtq_f32_u32(sum_high)));

    return vcombine_s16(vqmovn_s32(norm_low), vqmovn_s32(norm_high));
}
/**
 * Signed phase of eight (gx, gy) int16 lanes as U8, computed in half precision.
 *
 * @param input1 Horizontal gradients.
 * @param input2 Vertical gradients.
 *
 * @return Angle from atan2_0_360 (degrees * SCALE_FACTOR), rounded to nearest and
 *         saturated to [0, 255].
 */
inline uint8x8_t phase_signed(int16x8_t input1, int16x8_t input2)
{
    // Non-static: avoids a static-initialization guard check on every call in the inner loop
    const float16x8_t zeropointfive = vdupq_n_f16(0.5f);

    const float16x8_t inputx_f16 = vcvtq_f16_s16(input1);
    const float16x8_t inputy_f16 = vcvtq_f16_s16(input2);

    // Compute fast atan2
    const float16x8_t angle = atan2_0_360(inputx_f16, inputy_f16);

    // +0.5 turns the truncating conversion into round-to-nearest; vqmovun_s16 saturates to [0, 255].
    // NOTE(review): an angle that rounds to 256 (a full turn) saturates to 255 here, whereas the
    // FP32 phase_signed narrows with the non-saturating vmovn_u16 and wraps it to 0 - confirm intent.
    return vqmovun_s16(vcvtq_s16_f16(vaddq_f16(angle, zeropointfive)));
}
/**
 * Unsigned phase of eight (gx, gy) int16 lanes as U8, computed in half precision.
 *
 * @param input1 Horizontal gradients.
 * @param input2 Vertical gradients.
 *
 * @return Angle from atan2_0_180, rounded to nearest and saturated to [0, 255].
 */
inline uint8x8_t phase_unsigned(int16x8_t input1, int16x8_t input2)
{
    // Non-static: avoids a static-initialization guard check on every call in the inner loop
    const float16x8_t zeropointfive = vdupq_n_f16(0.5f);

    const float16x8_t inputx_f16 = vcvtq_f16_s16(input1);
    const float16x8_t inputy_f16 = vcvtq_f16_s16(input2);

    // Compute fast atan2
    const float16x8_t angle = atan2_0_180(inputx_f16, inputy_f16);

    // +0.5 turns the truncating conversion into round-to-nearest
    return vqmovun_s16(vcvtq_s16_f16(vaddq_f16(angle, zeropointfive)));
}
/**
 * Magnitude of 16 gradient pairs held as two 8-lane halves (val[0], val[1]).
 * Only the L1NORM and L2NORM specialisations below are defined.
 *
 * @param in0 First gradient (the kernels pass the x gradient).
 * @param gx  Second gradient (the kernels pass the y gradient here, despite the name).
 */
template <MagnitudeType mag_type>
inline int16x8x2_t compute_magnitude(const int16x8x2_t &in0, const int16x8x2_t &gx);

/** L2 specialisation: sqrt(x^2 + y^2) per lane. */
template <>
inline int16x8x2_t compute_magnitude<MagnitudeType::L2NORM>(const int16x8x2_t &in0, const int16x8x2_t &gx)
{
    int16x8x2_t mag;
    mag.val[0] = magnitude_l2(in0.val[0], gx.val[0]);
    mag.val[1] = magnitude_l2(in0.val[1], gx.val[1]);
    return mag;
}

/** L1 specialisation: |x| + |y| per lane. */
template <>
inline int16x8x2_t compute_magnitude<MagnitudeType::L1NORM>(const int16x8x2_t &in0, const int16x8x2_t &gx)
{
    int16x8x2_t mag;
    mag.val[0] = magnitude_l1(in0.val[0], gx.val[0]);
    mag.val[1] = magnitude_l1(in0.val[1], gx.val[1]);
    return mag;
}
/**
 * U8 phase of 16 gradient pairs held as two 8-lane halves (val[0], val[1]).
 * Only the SIGNED and UNSIGNED specialisations below are defined.
 *
 * @param in0 First gradient (the kernels pass the x gradient).
 * @param gx  Second gradient (the kernels pass the y gradient here, despite the name).
 */
template <PhaseType phase_type>
inline uint8x16_t compute_phase(const int16x8x2_t &in0, const int16x8x2_t &gx);

/** Signed specialisation: full-turn phase via phase_signed. */
template <>
inline uint8x16_t compute_phase<PhaseType::SIGNED>(const int16x8x2_t &in0, const int16x8x2_t &gx)
{
    const uint8x8_t low_half  = phase_signed(in0.val[0], gx.val[0]);
    const uint8x8_t high_half = phase_signed(in0.val[1], gx.val[1]);
    return vcombine_u8(low_half, high_half);
}

/** Unsigned specialisation: half-turn phase via phase_unsigned. */
template <>
inline uint8x16_t compute_phase<PhaseType::UNSIGNED>(const int16x8x2_t &in0, const int16x8x2_t &gx)
{
    const uint8x8_t low_half  = phase_unsigned(in0.val[0], gx.val[0]);
    const uint8x8_t high_half = phase_unsigned(in0.val[1], gx.val[1]);
    return vcombine_u8(low_half, high_half);
}
252 template <MagnitudeType mag_type, PhaseType phase_type>
253 NEMagnitudePhaseFP16Kernel<mag_type, phase_type>::NEMagnitudePhaseFP16Kernel()
254 : _func(nullptr), _gx(nullptr), _gy(nullptr), _magnitude(nullptr), _phase(nullptr)
258 template <MagnitudeType mag_type, PhaseType phase_type>
259 void NEMagnitudePhaseFP16Kernel<mag_type, phase_type>::configure(const ITensor *gx, const ITensor *gy, ITensor *magnitude, ITensor *phase)
261 ARM_COMPUTE_ERROR_ON_FORMAT_NOT_IN(gx, Format::S16);
262 ARM_COMPUTE_ERROR_ON_FORMAT_NOT_IN(gy, Format::S16);
263 ARM_COMPUTE_ERROR_ON((nullptr == magnitude) && (nullptr == phase));
265 const bool run_mag = magnitude != nullptr;
266 const bool run_phase = phase != nullptr;
270 ARM_COMPUTE_ERROR_ON_FORMAT_NOT_IN(magnitude, Format::S16);
275 ARM_COMPUTE_ERROR_ON_FORMAT_NOT_IN(phase, Format::U8);
280 _magnitude = magnitude;
283 if(run_mag && run_phase)
285 /* Run magnitude and phase */
286 _func = &NEMagnitudePhaseFP16Kernel<mag_type, phase_type>::magnitude_phase;
291 _func = &NEMagnitudePhaseFP16Kernel<mag_type, phase_type>::magnitude;
296 _func = &NEMagnitudePhaseFP16Kernel<mag_type, phase_type>::phase;
300 ARM_COMPUTE_ERROR("At least one output must be NOT NULL");
303 const unsigned int processed_elements = 16;
305 // Configure kernel window
306 Window win = calculate_max_window(*gx->info(), Steps(processed_elements));
307 AccessWindowHorizontal magnitude_access(magnitude == nullptr ? nullptr : magnitude->info(), 0, processed_elements);
308 AccessWindowHorizontal phase_access(phase == nullptr ? nullptr : phase->info(), 0, processed_elements);
310 update_window_and_padding(win,
311 AccessWindowHorizontal(gx->info(), 0, processed_elements),
312 AccessWindowHorizontal(gy->info(), 0, processed_elements),
316 ValidRegion valid_region = intersect_valid_regions(gx->info()->valid_region(),
317 gy->info()->valid_region());
319 magnitude_access.set_valid_region(win, valid_region);
320 phase_access.set_valid_region(win, valid_region);
322 INEKernel::configure(win);
/** Loop body writing only the S16 magnitude, 16 elements per window step. */
template <MagnitudeType mag_type, PhaseType phase_type>
void NEMagnitudePhaseFP16Kernel<mag_type, phase_type>::magnitude(const Window &window)
{
    Iterator gx(_gx, window);
    Iterator gy(_gy, window);
    Iterator magnitude(_magnitude, window);

    execute_window_loop(window, [&](const Coordinates &)
    {
        const auto gx_ptr  = reinterpret_cast<const int16_t *>(gx.ptr());
        const auto gy_ptr  = reinterpret_cast<const int16_t *>(gy.ptr());
        const auto mag_ptr = reinterpret_cast<int16_t *>(magnitude.ptr());

        const int16x8x2_t input1 = { { vld1q_s16(gx_ptr), vld1q_s16(gx_ptr + 8) } };
        const int16x8x2_t input2 = { { vld1q_s16(gy_ptr), vld1q_s16(gy_ptr + 8) } };

        const int16x8x2_t mag = fp16::compute_magnitude<mag_type>(input1, input2);

        vst1q_s16(mag_ptr, mag.val[0]);
        vst1q_s16(mag_ptr + 8, mag.val[1]);
    },
    gx, gy, magnitude);
}
/** Loop body writing only the U8 phase, 16 elements per window step. */
template <MagnitudeType mag_type, PhaseType phase_type>
void NEMagnitudePhaseFP16Kernel<mag_type, phase_type>::phase(const Window &window)
{
    Iterator gx(_gx, window);
    Iterator gy(_gy, window);
    Iterator phase(_phase, window);

    execute_window_loop(window, [&](const Coordinates &)
    {
        const auto gx_ptr = reinterpret_cast<const int16_t *>(gx.ptr());
        const auto gy_ptr = reinterpret_cast<const int16_t *>(gy.ptr());

        const int16x8x2_t input1 = { { vld1q_s16(gx_ptr), vld1q_s16(gx_ptr + 8) } };
        const int16x8x2_t input2 = { { vld1q_s16(gy_ptr), vld1q_s16(gy_ptr + 8) } };

        const uint8x16_t angles = fp16::compute_phase<phase_type>(input1, input2);
        vst1q_u8(phase.ptr(), angles);
    },
    gx, gy, phase);
}
/** Loop body writing both S16 magnitude and U8 phase, 16 elements per window step. */
template <MagnitudeType mag_type, PhaseType phase_type>
void NEMagnitudePhaseFP16Kernel<mag_type, phase_type>::magnitude_phase(const Window &window)
{
    Iterator gx(_gx, window);
    Iterator gy(_gy, window);
    Iterator magnitude(_magnitude, window);
    Iterator phase(_phase, window);

    execute_window_loop(window, [&](const Coordinates &)
    {
        const auto gx_ptr  = reinterpret_cast<const int16_t *>(gx.ptr());
        const auto gy_ptr  = reinterpret_cast<const int16_t *>(gy.ptr());
        const auto mag_ptr = reinterpret_cast<int16_t *>(magnitude.ptr());

        const int16x8x2_t input1 = { { vld1q_s16(gx_ptr), vld1q_s16(gx_ptr + 8) } };
        const int16x8x2_t input2 = { { vld1q_s16(gy_ptr), vld1q_s16(gy_ptr + 8) } };

        // Magnitude first, then phase, both from the same loaded gradients
        const int16x8x2_t mag = fp16::compute_magnitude<mag_type>(input1, input2);
        vst1q_s16(mag_ptr, mag.val[0]);
        vst1q_s16(mag_ptr + 8, mag.val[1]);

        vst1q_u8(phase.ptr(), fp16::compute_phase<phase_type>(input1, input2));
    },
    gx, gy, magnitude, phase);
}
417 template <MagnitudeType mag_type, PhaseType phase_type>
418 void NEMagnitudePhaseFP16Kernel<mag_type, phase_type>::run(const Window &window)
420 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
421 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
422 ARM_COMPUTE_ERROR_ON(_func == nullptr);
424 (this->*_func)(window);
427 template class arm_compute::NEMagnitudePhaseFP16Kernel<MagnitudeType::L1NORM, PhaseType::SIGNED>;
428 template class arm_compute::NEMagnitudePhaseFP16Kernel<MagnitudeType::L2NORM, PhaseType::SIGNED>;
429 template class arm_compute::NEMagnitudePhaseFP16Kernel<MagnitudeType::L1NORM, PhaseType::UNSIGNED>;
430 template class arm_compute::NEMagnitudePhaseFP16Kernel<MagnitudeType::L2NORM, PhaseType::UNSIGNED>;
/**
 * Reciprocal of four single-precision lanes.
 *
 * Uses the hardware estimate refined by one Newton-Raphson step
 * (vrecpsq_f32(x, r) == 2 - x * r).
 */
inline float32x4_t inv(float32x4_t x)
{
    const float32x4_t estimate = vrecpeq_f32(x);
    return vmulq_f32(vrecpsq_f32(x, estimate), estimate);
}
/**
 * Signed phase of four (gx, gy) lanes, expressed in degrees * SCALE_FACTOR.
 *
 * With z = min(|gx|, |gy|) / max(|gx|, |gy|), atan(z) is approximated by
 * pi/4 * z + z * (1 - |z|) * (COEFF2 + COEFF1 * |z|). The angle is then moved into the
 * quadrant of (gx, gy), so the result covers [0, 360 * SCALE_FACTOR].
 */
inline float32x4_t atan2_0_360(float32x4_t gx, float32x4_t gy)
{
    const float32x4_t zero = vdupq_n_f32(0.0f);
    const float32x4_t one  = vdupq_n_f32(1.0f);

    // Ratio in [0, 1]; epsilon avoids dividing by zero when gx == gy == 0
    const float32x4_t ax    = vabsq_f32(gx);
    const float32x4_t ay    = vabsq_f32(gy);
    const float32x4_t denom = vaddq_f32(vmaxq_f32(ax, ay), vdupq_n_f32(1e-9f));
    const float32x4_t z     = vmulq_f32(vminq_f32(ax, ay), inv(denom));
    const float32x4_t az    = vabsq_f32(z);

    // atan(z) ~= pi/4 * z + z * (1 - |z|) * (COEFF2 + COEFF1 * |z|)
    const float32x4_t damping = vmulq_f32(z, vsubq_f32(one, az));
    const float32x4_t poly    = vaddq_f32(vdupq_n_f32(COEFF2), vmulq_f32(az, vdupq_n_f32(COEFF1)));
    float32x4_t       angle   = vmlaq_f32(vmulq_f32(poly, damping), vdupq_n_f32(PI_4), z);

    // Radians -> degrees scaled by SCALE_FACTOR so that a full turn maps to 256
    angle = vmulq_f32(angle, vdupq_n_f32(SCALE_360));

    // If |gy| > |gx| the ratio was inverted: angle = 90 - atan(|gx| / |gy|)
    angle = vbslq_f32(vcgeq_f32(ax, ay), angle, vsubq_f32(vdupq_n_f32(90.0f * SCALE_FACTOR), angle));

    // Move the first-quadrant angle into the quadrant of (gx, gy)
    angle = vbslq_f32(vcltq_f32(gx, zero), vsubq_f32(vdupq_n_f32(180.0f * SCALE_FACTOR), angle), angle);
    angle = vbslq_f32(vcltq_f32(gy, zero), vsubq_f32(vdupq_n_f32(360.0f * SCALE_FACTOR), angle), angle);

    return angle;
}
/**
 * Unsigned phase of four (gx, gy) lanes, in plain degrees within [0, 180].
 *
 * Uses the same polynomial approximation as atan2_0_360 but converts to plain degrees.
 * After quadrant selection, angles above 180 are folded back by subtracting 180.
 */
inline float32x4_t atan2_0_180(float32x4_t gx, float32x4_t gy)
{
    const float32x4_t zero       = vdupq_n_f32(0.0f);
    const float32x4_t ninety     = vdupq_n_f32(90.0f);
    const float32x4_t oneeighty  = vdupq_n_f32(180.0f);
    const float32x4_t threesixty = vdupq_n_f32(360.0f);

    // Ratio in [0, 1]; epsilon avoids dividing by zero when gx == gy == 0
    const float32x4_t ax    = vabsq_f32(gx);
    const float32x4_t ay    = vabsq_f32(gy);
    const float32x4_t denom = vaddq_f32(vmaxq_f32(ax, ay), vdupq_n_f32(1e-9f));
    const float32x4_t z     = vmulq_f32(vminq_f32(ax, ay), inv(denom));
    const float32x4_t az    = vabsq_f32(z);

    // atan(z) ~= pi/4 * z + z * (1 - |z|) * (COEFF2 + COEFF1 * |z|)
    const float32x4_t damping = vmulq_f32(z, vsubq_f32(vdupq_n_f32(1.0f), az));
    const float32x4_t poly    = vaddq_f32(vdupq_n_f32(COEFF2), vmulq_f32(az, vdupq_n_f32(COEFF1)));
    float32x4_t       deg     = vmlaq_f32(vmulq_f32(poly, damping), vdupq_n_f32(PI_4), z);

    // Radians -> degrees
    deg = vmulq_f32(deg, vdupq_n_f32(SCALE_180));

    // If |gy| > |gx| the ratio was inverted: angle = 90 - atan(|gx| / |gy|)
    deg = vbslq_f32(vcgeq_f32(ax, ay), deg, vsubq_f32(ninety, deg));

    // Quadrant of (gx, gy): result in [0, 360]
    deg = vbslq_f32(vcltq_f32(gx, zero), vsubq_f32(oneeighty, deg), deg);
    deg = vbslq_f32(vcltq_f32(gy, zero), vsubq_f32(threesixty, deg), deg);

    // Fold (180, 360] onto (0, 180]
    deg = vbslq_f32(vcgtq_f32(deg, oneeighty), vsubq_f32(deg, oneeighty), deg);

    return deg;
}
/**
 * 1 / sqrt(x) for four single-precision lanes.
 *
 * Refines the hardware estimate with two Newton-Raphson steps: vrsqrtsq_f32(a, b)
 * computes (3 - a * b) / 2, so r * vrsqrtsq_f32(x * r, r) is a better estimate than r.
 */
inline float32x4_t invsqrtv(float32x4_t x)
{
    const float32x4_t r0 = vrsqrteq_f32(x);
    const float32x4_t r1 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(x, r0), r0), r0);
    const float32x4_t r2 = vmulq_f32(vrsqrtsq_f32(vmulq_f32(x, r1), r1), r1);

    return r2;
}
/**
 * Returns sqrt(x) + 0.5 for four single-precision lanes (0.5 + x * invsqrtv(x)).
 *
 * The +0.5 bias makes the callers' truncating vcvtq_s32_f32 round to nearest.
 */
inline float32x4_t sqrtv(float32x4_t x)
{
    return vmlaq_f32(vdupq_n_f32(0.5f), x, invsqrtv(x));
}
/**
 * L2 magnitude sqrt(input1^2 + input2^2) per lane, rounded and saturated to int16.
 *
 * Squares are widened to 32 bits. The sum of two int16 squares (at most 2^31) is added as
 * uint32 to avoid signed overflow. sqrtv() includes a +0.5 bias, so the truncating
 * conversion rounds to nearest.
 */
inline int16x8_t magnitude_l2(int16x8_t input1, int16x8_t input2)
{
    const int16x4_t x_lo = vget_low_s16(input1);
    const int16x4_t x_hi = vget_high_s16(input1);
    const int16x4_t y_lo = vget_low_s16(input2);
    const int16x4_t y_hi = vget_high_s16(input2);

    const uint32x4_t sq_sum_lo = vaddq_u32(vreinterpretq_u32_s32(vmull_s16(x_lo, x_lo)),
                                           vreinterpretq_u32_s32(vmull_s16(y_lo, y_lo)));
    const uint32x4_t sq_sum_hi = vaddq_u32(vreinterpretq_u32_s32(vmull_s16(x_hi, x_hi)),
                                           vreinterpretq_u32_s32(vmull_s16(y_hi, y_hi)));

    const float32x4_t root_lo = sqrtv(vcvtq_f32_u32(sq_sum_lo));
    const float32x4_t root_hi = sqrtv(vcvtq_f32_u32(sq_sum_hi));

    // vqmovn_s32 saturates values above INT16_MAX
    const int16x4_t mag_lo = vqmovn_s32(vcvtq_s32_f32(root_lo));
    const int16x4_t mag_hi = vqmovn_s32(vcvtq_s32_f32(root_hi));

    return vcombine_s16(mag_lo, mag_hi);
}
/**
 * L1 magnitude |input1| + |input2| per lane, saturated to the int16 range.
 *
 * vqabsq_s16 (saturating) replaces vabsq_s16: the non-saturating form wraps for
 * INT16_MIN (|-32768| -> -32768), which would produce a negative magnitude.
 */
inline int16x8_t magnitude_l1(int16x8_t input1, int16x8_t input2)
{
    const int16x8_t gx_abs = vqabsq_s16(input1);
    const int16x8_t gy_abs = vqabsq_s16(input2);

    // Saturating add: |gx| + |gy| is clamped at INT16_MAX
    return vqaddq_s16(gx_abs, gy_abs);
}
/**
 * Signed phase of eight (gx, gy) int16 lanes as U8, computed in single precision.
 *
 * @param input1 Horizontal gradients.
 * @param input2 Vertical gradients.
 *
 * @return atan2_0_360 angle (degrees * SCALE_FACTOR) rounded to nearest. Narrowing uses
 *         vqmovun_s32 (clamps to [0, 65535]) followed by the non-saturating vmovn_u16,
 *         so a value of 256 (a full turn) wraps to 0.
 */
inline uint8x8_t phase_signed(int16x8_t input1, int16x8_t input2)
{
    const float32x4_t half = vdupq_n_f32(0.5f);

    // Widen each half of the int16 gradients to float32
    const float32x4_t gx_lo = vcvtq_f32_s32(vmovl_s16(vget_low_s16(input1)));
    const float32x4_t gx_hi = vcvtq_f32_s32(vmovl_s16(vget_high_s16(input1)));
    const float32x4_t gy_lo = vcvtq_f32_s32(vmovl_s16(vget_low_s16(input2)));
    const float32x4_t gy_hi = vcvtq_f32_s32(vmovl_s16(vget_high_s16(input2)));

    // +0.5 so that the truncating float->int conversion rounds to nearest
    const int32x4_t angle_lo = vcvtq_s32_f32(vaddq_f32(atan2_0_360(gx_lo, gy_lo), half));
    const int32x4_t angle_hi = vcvtq_s32_f32(vaddq_f32(atan2_0_360(gx_hi, gy_hi), half));

    return vmovn_u16(vcombine_u16(vqmovun_s32(angle_lo), vqmovun_s32(angle_hi)));
}
/**
 * Unsigned phase of eight (gx, gy) int16 lanes as U8, computed in single precision.
 *
 * @param input1 Horizontal gradients.
 * @param input2 Vertical gradients.
 *
 * @return atan2_0_180 angle (plain degrees, [0, 180]) rounded to nearest.
 */
inline uint8x8_t phase_unsigned(int16x8_t input1, int16x8_t input2)
{
    const float32x4_t half = vdupq_n_f32(0.5f);

    // Widen each half of the int16 gradients to float32
    const float32x4_t gx_lo = vcvtq_f32_s32(vmovl_s16(vget_low_s16(input1)));
    const float32x4_t gx_hi = vcvtq_f32_s32(vmovl_s16(vget_high_s16(input1)));
    const float32x4_t gy_lo = vcvtq_f32_s32(vmovl_s16(vget_low_s16(input2)));
    const float32x4_t gy_hi = vcvtq_f32_s32(vmovl_s16(vget_high_s16(input2)));

    // +0.5 so that the truncating float->int conversion rounds to nearest
    const int32x4_t angle_lo = vcvtq_s32_f32(vaddq_f32(atan2_0_180(gx_lo, gy_lo), half));
    const int32x4_t angle_hi = vcvtq_s32_f32(vaddq_f32(atan2_0_180(gx_hi, gy_hi), half));

    return vmovn_u16(vcombine_u16(vqmovun_s32(angle_lo), vqmovun_s32(angle_hi)));
}
625 template <MagnitudeType mag_type, PhaseType phase_type>
626 NEMagnitudePhaseKernel<mag_type, phase_type>::NEMagnitudePhaseKernel()
627 : _func(nullptr), _gx(nullptr), _gy(nullptr), _magnitude(nullptr), _phase(nullptr)
631 template <MagnitudeType mag_type, PhaseType phase_type>
632 void NEMagnitudePhaseKernel<mag_type, phase_type>::configure(const ITensor *gx, const ITensor *gy, ITensor *magnitude, ITensor *phase)
634 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(gx, 1, DataType::S16);
635 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(gy, 1, DataType::S16);
636 ARM_COMPUTE_ERROR_ON((nullptr == magnitude) && (nullptr == phase));
638 const bool run_mag = magnitude != nullptr;
639 const bool run_phase = phase != nullptr;
643 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(magnitude, 1, DataType::S16);
648 ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(phase, 1, DataType::U8);
653 _magnitude = magnitude;
656 if(run_mag && run_phase)
658 /* Run magnitude and phase */
659 _func = &NEMagnitudePhaseKernel<mag_type, phase_type>::magnitude_phase;
666 _func = &NEMagnitudePhaseKernel<mag_type, phase_type>::magnitude;
671 _func = &NEMagnitudePhaseKernel<mag_type, phase_type>::phase;
675 ARM_COMPUTE_ERROR("At least one output must be NOT NULL");
679 const unsigned int processed_elements = 16;
681 // Configure kernel window
682 Window win = calculate_max_window(*gx->info(), Steps(processed_elements));
683 AccessWindowHorizontal magnitude_access(magnitude == nullptr ? nullptr : magnitude->info(), 0, processed_elements);
684 AccessWindowHorizontal phase_access(phase == nullptr ? nullptr : phase->info(), 0, processed_elements);
686 update_window_and_padding(win,
687 AccessWindowHorizontal(gx->info(), 0, processed_elements),
688 AccessWindowHorizontal(gy->info(), 0, processed_elements),
692 ValidRegion valid_region = intersect_valid_regions(gx->info()->valid_region(),
693 gy->info()->valid_region());
695 magnitude_access.set_valid_region(win, valid_region);
696 phase_access.set_valid_region(win, valid_region);
698 INEKernel::configure(win);
/** Loop body writing only the S16 magnitude, 16 elements per window step. */
template <MagnitudeType mag_type, PhaseType phase_type>
void NEMagnitudePhaseKernel<mag_type, phase_type>::magnitude(const Window &window)
{
    Iterator gx(_gx, window);
    Iterator gy(_gy, window);
    Iterator magnitude(_magnitude, window);

    execute_window_loop(window, [&](const Coordinates &)
    {
        const auto gx_ptr  = reinterpret_cast<const int16_t *>(gx.ptr());
        const auto gy_ptr  = reinterpret_cast<const int16_t *>(gy.ptr());
        const auto mag_ptr = reinterpret_cast<int16_t *>(magnitude.ptr());

        // All 16 inputs are loaded before anything is stored
        const int16x8_t x_lo = vld1q_s16(gx_ptr);
        const int16x8_t x_hi = vld1q_s16(gx_ptr + 8);
        const int16x8_t y_lo = vld1q_s16(gy_ptr);
        const int16x8_t y_hi = vld1q_s16(gy_ptr + 8);

        // mag_type is a template parameter, so this selection is resolved at compile time
        const bool      use_l2 = (MagnitudeType::L2NORM == mag_type);
        const int16x8_t mag_lo = use_l2 ? magnitude_l2(x_lo, y_lo) : magnitude_l1(x_lo, y_lo);
        const int16x8_t mag_hi = use_l2 ? magnitude_l2(x_hi, y_hi) : magnitude_l1(x_hi, y_hi);

        vst1q_s16(mag_ptr, mag_lo);
        vst1q_s16(mag_ptr + 8, mag_hi);
    },
    gx, gy, magnitude);
}
/** Loop body writing only the U8 phase, 16 elements per window step. */
template <MagnitudeType mag_type, PhaseType phase_type>
void NEMagnitudePhaseKernel<mag_type, phase_type>::phase(const Window &window)
{
    Iterator gx(_gx, window);
    Iterator gy(_gy, window);
    Iterator phase(_phase, window);

    execute_window_loop(window, [&](const Coordinates &)
    {
        const auto gx_ptr = reinterpret_cast<const int16_t *>(gx.ptr());
        const auto gy_ptr = reinterpret_cast<const int16_t *>(gy.ptr());

        const int16x8_t x_lo = vld1q_s16(gx_ptr);
        const int16x8_t x_hi = vld1q_s16(gx_ptr + 8);
        const int16x8_t y_lo = vld1q_s16(gy_ptr);
        const int16x8_t y_hi = vld1q_s16(gy_ptr + 8);

        // phase_type is a template parameter, so this selection is resolved at compile time
        const bool      is_signed = (PhaseType::SIGNED == phase_type);
        const uint8x8_t angle_lo  = is_signed ? phase_signed(x_lo, y_lo) : phase_unsigned(x_lo, y_lo);
        const uint8x8_t angle_hi  = is_signed ? phase_signed(x_hi, y_hi) : phase_unsigned(x_hi, y_hi);

        vst1q_u8(phase.ptr(), vcombine_u8(angle_lo, angle_hi));
    },
    gx, gy, phase);
}
/** Loop body writing both S16 magnitude and U8 phase, 16 elements per window step. */
template <MagnitudeType mag_type, PhaseType phase_type>
void NEMagnitudePhaseKernel<mag_type, phase_type>::magnitude_phase(const Window &window)
{
    Iterator gx(_gx, window);
    Iterator gy(_gy, window);
    Iterator magnitude(_magnitude, window);
    Iterator phase(_phase, window);

    execute_window_loop(window, [&](const Coordinates &)
    {
        const auto gx_ptr  = reinterpret_cast<const int16_t *>(gx.ptr());
        const auto gy_ptr  = reinterpret_cast<const int16_t *>(gy.ptr());
        const auto mag_ptr = reinterpret_cast<int16_t *>(magnitude.ptr());

        // All 16 inputs are loaded before anything is stored
        const int16x8_t x_lo = vld1q_s16(gx_ptr);
        const int16x8_t x_hi = vld1q_s16(gx_ptr + 8);
        const int16x8_t y_lo = vld1q_s16(gy_ptr);
        const int16x8_t y_hi = vld1q_s16(gy_ptr + 8);

        // Magnitude (selection resolved at compile time from mag_type)
        const bool      use_l2 = (MagnitudeType::L2NORM == mag_type);
        const int16x8_t mag_lo = use_l2 ? magnitude_l2(x_lo, y_lo) : magnitude_l1(x_lo, y_lo);
        const int16x8_t mag_hi = use_l2 ? magnitude_l2(x_hi, y_hi) : magnitude_l1(x_hi, y_hi);

        vst1q_s16(mag_ptr, mag_lo);
        vst1q_s16(mag_ptr + 8, mag_hi);

        // Phase (selection resolved at compile time from phase_type)
        const bool      is_signed = (PhaseType::SIGNED == phase_type);
        const uint8x8_t angle_lo  = is_signed ? phase_signed(x_lo, y_lo) : phase_unsigned(x_lo, y_lo);
        const uint8x8_t angle_hi  = is_signed ? phase_signed(x_hi, y_hi) : phase_unsigned(x_hi, y_hi);

        vst1q_u8(phase.ptr(), vcombine_u8(angle_lo, angle_hi));
    },
    gx, gy, magnitude, phase);
}
856 template <MagnitudeType mag_type, PhaseType phase_type>
857 void NEMagnitudePhaseKernel<mag_type, phase_type>::run(const Window &window)
859 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
860 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
861 ARM_COMPUTE_ERROR_ON(_func == nullptr);
863 (this->*_func)(window);
866 template class arm_compute::NEMagnitudePhaseKernel<MagnitudeType::L1NORM, PhaseType::SIGNED>;
867 template class arm_compute::NEMagnitudePhaseKernel<MagnitudeType::L2NORM, PhaseType::SIGNED>;
868 template class arm_compute::NEMagnitudePhaseKernel<MagnitudeType::L1NORM, PhaseType::UNSIGNED>;
869 template class arm_compute::NEMagnitudePhaseKernel<MagnitudeType::L2NORM, PhaseType::UNSIGNED>;