83d09c3a25560b6224faa06d18e2a40766c8c17e
[platform/upstream/dldt.git] / inference-engine / src / extension / common / fast_exp.h
1 // Copyright (C) 2018 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 //
5
6 #pragma once
7
8 #include "defs.h"
9
10 #define FAST_EXP_HI   87.3365402f
11 #define FAST_EXP_LO  -87.3365402f
12
13 #define LOG2EF 1.44269504088896341f
14 #define LOG2  0.693147181f
15
16 #define FAST_EXP_C1  12582912.0f
17 #define FAST_EXP_C2  0.00829171948f
18
19 #define FAST_EXP_P0 1.42860677e-06f
20 #define FAST_EXP_P1 0.0418735221f
21 #define FAST_EXP_P2 0.166674316f
22 #define FAST_EXP_P3 0.49999392f
23 #define FAST_EXP_P4 0.999999881f
24 #define FAST_EXP_P5 1.0f
25
26 #if defined(HAVE_AVX2)
27 static inline __m256 _avx_fast_exp_ps(__m256 vsrc) {
28     __m256 vc_exp_c1 = _mm256_set1_ps(FAST_EXP_C1);
29     __m256 vc_exp_c2 = _mm256_set1_ps(FAST_EXP_C2);
30     __m256 vc_log2e  = _mm256_set1_ps(LOG2EF);
31     __m256 vc_log2   = _mm256_set1_ps(LOG2);
32
33     __m256 vc_exp_p0 = _mm256_set1_ps(FAST_EXP_P0);
34     __m256 vc_exp_p1 = _mm256_set1_ps(FAST_EXP_P1);
35     __m256 vc_exp_p2 = _mm256_set1_ps(FAST_EXP_P2);
36     __m256 vc_exp_p3 = _mm256_set1_ps(FAST_EXP_P3);
37     __m256 vc_exp_p4 = _mm256_set1_ps(FAST_EXP_P4);
38     __m256 cv_exp_p5 = _mm256_set1_ps(FAST_EXP_P5);
39
40     __m256 vc_exp_hi = _mm256_set1_ps(FAST_EXP_HI);
41     __m256 vc_exp_lo = _mm256_set1_ps(FAST_EXP_LO);
42
43     vsrc = _mm256_max_ps(_mm256_min_ps(vsrc, vc_exp_hi), vc_exp_lo);
44 #if defined(HAVE_FMA)
45     __m256 fx = _mm256_fmadd_ps(vsrc, vc_log2e, vc_exp_c1);
46 #else
47     __m256 fx = _mm256_add_ps(_mm256_mul_ps(vsrc, vc_log2e), vc_exp_c1);
48 #endif
49     __m256 fx_ = _mm256_sub_ps(fx, vc_exp_c1);
50     __m256i msk = _mm256_slli_epi32(_mm256_castps_si256(fx), 23);
51
52 #if defined(HAVE_FMA)
53     __m256 q = _mm256_fnmadd_ps(fx_, vc_log2, vsrc);
54     __m256 y = _mm256_fnmadd_ps(fx_, vc_exp_p0, q);
55            q = _mm256_fmadd_ps(vc_exp_c2, y, vc_exp_p1);
56            q = _mm256_fmadd_ps(y, q, vc_exp_p2);
57            q = _mm256_fmadd_ps(y, q, vc_exp_p3);
58            q = _mm256_fmadd_ps(y, q, vc_exp_p4);
59            q = _mm256_fmadd_ps(y, q, cv_exp_p5);
60 #else
61     __m256 q = _mm256_sub_ps(vsrc, _mm256_mul_ps(fx_, vc_log2));
62     __m256 y = _mm256_sub_ps(q, _mm256_mul_ps(fx_, vc_exp_p0));
63            q = _mm256_add_ps(_mm256_mul_ps(vc_exp_c2, y), vc_exp_p1);
64            q = _mm256_add_ps(_mm256_mul_ps(y, q), vc_exp_p2);
65            q = _mm256_add_ps(_mm256_mul_ps(y, q), vc_exp_p3);
66            q = _mm256_add_ps(_mm256_mul_ps(y, q), vc_exp_p4);
67            q = _mm256_add_ps(_mm256_mul_ps(y, q), cv_exp_p5);
68 #endif
69
70     __m256 vexp = _mm256_castsi256_ps(_mm256_add_epi32(_mm256_castps_si256(q), msk));
71     return vexp;
72 }
73 #endif
74
75 #if defined(HAVE_SSE)
76 static inline __m128 _sse_fast_exp_ps(__m128 vsrc) {
77     __m128 vc_exp_c1 = _mm_set1_ps(FAST_EXP_C1);
78     __m128 vc_exp_c2 = _mm_set1_ps(FAST_EXP_C2);
79     __m128 vc_log2e  = _mm_set1_ps(LOG2EF);
80     __m128 vc_log2   = _mm_set1_ps(LOG2);
81
82     __m128 vc_exp_p0 = _mm_set1_ps(FAST_EXP_P0);
83     __m128 vc_exp_p1 = _mm_set1_ps(FAST_EXP_P1);
84     __m128 vc_exp_p2 = _mm_set1_ps(FAST_EXP_P2);
85     __m128 vc_exp_p3 = _mm_set1_ps(FAST_EXP_P3);
86     __m128 vc_exp_p4 = _mm_set1_ps(FAST_EXP_P4);
87     __m128 cv_exp_p5 = _mm_set1_ps(FAST_EXP_P5);
88
89     __m128 vc_exp_hi = _mm_set1_ps(FAST_EXP_HI);
90     __m128 vc_exp_lo = _mm_set1_ps(FAST_EXP_LO);
91
92     vsrc = _mm_max_ps(_mm_min_ps(vsrc, vc_exp_hi), vc_exp_lo);
93
94     __m128 fx = _mm_fmadd_ps(vsrc, vc_log2e, vc_exp_c1);
95     __m128 fx_ = _mm_sub_ps(fx, vc_exp_c1);
96     __m128i msk = _mm_slli_epi32(_mm_castps_si128(fx), 23);
97
98     __m128 q = _mm_fnmadd_ps(fx_, vc_log2, vsrc);
99     __m128 y = _mm_fnmadd_ps(fx_, vc_exp_p0, q);
100            q = _mm_fmadd_ps(vc_exp_c2, y, vc_exp_p1);
101            q = _mm_fmadd_ps(y, q, vc_exp_p2);
102            q = _mm_fmadd_ps(y, q, vc_exp_p3);
103            q = _mm_fmadd_ps(y, q, vc_exp_p4);
104            q = _mm_fmadd_ps(y, q, cv_exp_p5);
105
106     __m128 vexp = _mm_castsi128_ps(_mm_add_epi32(_mm_castps_si128(q), msk));
107
108     return vexp;
109 }
110 #endif