Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / common / math_utils.hpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef MATH_UTILS_HPP
18 #define MATH_UTILS_HPP
19
20 #include <stdint.h>
21 #include <math.h>
22
23 #include "utils.hpp"
24 #include "nstl.hpp"
25 #include "mkldnn_traits.hpp"
26
27 namespace mkldnn {
28 namespace impl {
29 namespace math {
30
31 template <typename data_t, typename acc_t>
32 inline typename utils::enable_if<!nstl::is_integral<data_t>::value,
33        typename utils::remove_reference<data_t>::type>::type
34 saturate(const acc_t &x) {
35     return (typename utils::remove_reference<data_t>::type)x;
36 }
37
38 template <typename data_t, typename acc_t>
39 inline typename utils::enable_if<nstl::is_integral<data_t>::value,
40        typename utils::remove_reference<data_t>::type>::type
41 saturate(const acc_t &x) {
42     acc_t v = x;
43     if (v < (acc_t)nstl::numeric_limits<data_t>::lowest())
44         v = (acc_t)nstl::numeric_limits<data_t>::lowest();
45     if (v > (acc_t)nstl::numeric_limits<data_t>::max())
46         v = (acc_t)nstl::numeric_limits<data_t>::max();
47     return (typename utils::remove_reference<data_t>::type)v;
48 }
49
50 template <typename data_t>
51 double saturate(const double &x) {
52     double v = x;
53     if (v < (double)nstl::numeric_limits<data_t>::lowest())
54         v = (double)nstl::numeric_limits<data_t>::lowest();
55     if (v > (double)nstl::numeric_limits<data_t>::max())
56         v = (double)nstl::numeric_limits<data_t>::max();
57     return v;
58 }
59
60 template <> inline int8_t saturate<int8_t, uint8_t>(const uint8_t &x) {
61     return x <= 127u ? x : 127;
62 }
63
64 template <> inline uint8_t saturate<uint8_t, int8_t>(const int8_t &x) {
65     return x >= 0 ? x : 0;
66 }
67
68 template <typename out_t>
69 inline typename utils::enable_if<nstl::is_integral<out_t>::value, out_t>::type
70 out_round(float v, round_mode_t rmode = round_mode::nearest)
71 { return (out_t)(rmode == round_mode::down ? floorf(v) : nearbyintf(v)); }
72
73 template <typename out_t>
74 inline typename utils::enable_if<nstl::is_integral<out_t>::value, out_t>::type
75 out_round(double v, round_mode_t rmode = round_mode::nearest)
76 { return (out_t)(rmode == round_mode::down ? floor(v) : nearbyint(v)); }
77
78 template <typename out_t>
79 inline typename utils::enable_if<!nstl::is_integral<out_t>::value, out_t>::type
80 out_round(float v, round_mode_t rmode = round_mode::nearest)
81 { UNUSED(rmode); return v; }
82
83 inline int gcd(int a, int b) {
84     a = impl::nstl::abs(a);
85     b = impl::nstl::abs(b);
86     if (a < b) { int x = a; a = b; b = x; }
87
88     if (b == 0) return a;
89
90     int r;
91     while ((r = a % b) != 0) { a = b; b = r; }
92
93     return b;
94 }
95
96 template <typename T>
97 inline bool is_pow2(const T& v) { return (v & (v - 1)) == 0; }
98
99 /** returns floor(log2(v)), aka the position of the leftmost non-0 bit */
100 inline int ilog2q(size_t v) {
101     if (v == 0)
102         return -1;
103
104     int p = 0;
105 #   define CP(pw) do { if (v >= (1ull << pw)) { v >>= pw; p += pw; } } while(0)
106     CP(32); CP(16); CP(8); CP(4); CP(2); CP(1);
107 #   undef CP
108     return p;
109 }
110
111 template <typename T, typename U = typename utils::remove_reference<T>::type>
112 inline U one_m_square(T x) {
113     return (U)(1 - x) * (1 + x);
114 }
115
116 template <typename T, typename U = typename utils::remove_reference<T>::type>
117 inline U x_m_square(T x) {
118     return (U)(1 - x) * x;
119 }
120
121 /* activation */
122 template <typename T, typename A,
123          typename U = typename utils::remove_reference<T>::type>
124 inline U relu_fwd(T s, A alpha) {
125     return s > 0 ? s : (U)(s * alpha);
126 }
127 template <typename T, typename A,
128          typename U = typename utils::remove_reference<T>::type>
129 inline U relu_bwd(T dd, T s, A alpha) {
130     return s > 0 ? dd : (U)(dd * alpha);
131 }
132
133 template <typename T, typename U = typename utils::remove_reference<T>::type>
134 inline U tanh_fwd(T s) {
135     const float e = tanhf((float) s);
136     return (U)e;
137 }
138
139 template <typename T, typename U = typename utils::remove_reference<T>::type>
140 inline U tanh_bwd(T dd, T s) {
141     const float e = tanh_fwd<float>((float) s);
142     return (U)(dd * (1 - e) * (1 + e));
143 }
144
145 template <typename T, typename A,
146          typename U = typename utils::remove_reference<T>::type>
147 inline U elu_fwd(T s, A alpha) {
148     return s > 0 ? s : (U)(alpha * (::expm1f((float)s)));
149 }
150 template <typename T, typename A,
151          typename U = typename utils::remove_reference<T>::type>
152  inline U elu_bwd(T dd, T s, A alpha) {
153     return (U)(dd * (s > 0 ? 1 : alpha * ::expf((float)s)));
154 }
155
156 template <typename T, typename U = typename utils::remove_reference<T>::type>
157 inline U square_fwd(T s) {
158     return s * s;
159 }
160
161 template <typename T, typename U = typename utils::remove_reference<T>::type>
162 inline U square_bwd(T dd, T s) {
163     return dd * 2 * s;
164 }
165
166 template <typename T, typename U = typename utils::remove_reference<T>::type>
167 inline U abs_fwd(T s) {
168     return s > 0 ? s : -s;
169 }
170
171 template <typename T, typename U = typename utils::remove_reference<T>::type>
172 inline U abs_bwd(T dd, T s) {
173     return s > 0 ? dd : s < 0 ? -dd : 0;
174 }
175
176 template <typename T, typename U = typename utils::remove_reference<T>::type>
177 inline U sqrt_fwd(T s) {
178     return s > 0 ? (U)(::sqrtf((float)(s))) : 0;
179 }
180
181 template <typename T, typename U = typename utils::remove_reference<T>::type>
182 inline U sqrt_bwd(T dd, T s) {
183     return s > 0
184         ? (U)(dd / (2 * ::sqrtf((float)(s))))
185         : 0;
186 }
187
188 template <typename T, typename A,
189          typename U = typename utils::remove_reference<T>::type>
190 inline U linear_fwd(T s, A alpha, A beta) {
191     return (U)(alpha * s + beta);
192 }
193
194 template <typename T, typename A,
195          typename U = typename utils::remove_reference<T>::type>
196 inline U linear_bwd(T dd, T s, A alpha, A beta) {
197     (void) s;
198     (void) beta;
199     return (U)(dd * alpha);
200 }
201
202 template <typename T, typename A,
203          typename U = typename utils::remove_reference<T>::type>
204 inline U bounded_relu_fwd(T s, A alpha) {
205     s = s > 0 ? s : 0;
206     return s > alpha ? (U)(alpha) : s;
207 }
208
209 template <typename T, typename A,
210          typename U = typename utils::remove_reference<T>::type>
211 inline U bounded_relu_bwd(T dd, T s, A alpha) {
212     return dd * (0 < s && s < alpha ? 1 : 0);
213 }
214
215 template <typename T, typename U = typename utils::remove_reference<T>::type>
216 inline U soft_relu_fwd(T s) {
217     float max_logf = 8.872284e+01; //::logf(FLT_MAX)
218     return s < max_logf ? (U)(::log1pf(::expf((float)s))) : s;
219 }
220
221 template <typename T, typename U = typename utils::remove_reference<T>::type>
222 inline U soft_relu_bwd(T dd, T s) {
223     return (U)(dd / (1 + ::expf((float)(-s))));
224 }
225
226 template <typename T, typename U = typename utils::remove_reference<T>::type>
227 inline U logistic_fwd(T s) {
228     U v = (U)(::expf((float) -s));
229     return 1 / (1 + v);
230 }
231
232 template <typename T, typename U = typename utils::remove_reference<T>::type>
233 inline U logistic_bwd(T dd, T s) {
234     U v = logistic_fwd<T, U>(s);
235     return dd * v * (1 - v);
236 }
237
238 template <typename T, typename A,
239          typename U = typename utils::remove_reference<T>::type>
240 inline U clamp_fwd(T s, A alpha, A beta) {
241     return (U)(s > alpha ? alpha : s < beta ? beta : s);
242 }
243
244 template <typename T, typename A,
245          typename U = typename utils::remove_reference<T>::type>
246 inline U clamp_bwd(T dd, T s, A alpha, A beta) {
247     return dd * (beta < s && s < alpha ? 1 : 0);
248 }
249
250 template <typename T,
251          typename U = typename utils::remove_reference<T>::type>
252 inline U exp_fwd(T s) {
253     return (U)(::expf((float)s));
254 }
255
256 template <typename T,
257          typename U = typename utils::remove_reference<T>::type>
258  inline U exp_bwd(T dd, T s) {
259     return (U)(::expf((float)s));
260 }
261
262 template <typename T,
263         typename U = typename utils::remove_reference<T>::type>
264 inline U not_fwd(T s) {
265     return (U)(!s);
266 }
267
268 template <typename T, typename A,
269          typename U = typename utils::remove_reference<T>::type>
270 inline U scale_shift_fwd(T s_val, A w_val, A b_val) {
271     return (U)(s_val*w_val + b_val);
272 }
273
274 template <typename T, typename A,
275          typename U = typename utils::remove_reference<T>::type>
276 inline U prelu_fwd(T s_val, A w_val) {
277     return (U)(s_val >= 0 ? s_val : w_val*s_val);
278 }
279
280 inline bool eltwise_fwd_preserves_zero(alg_kind_t alg, bool jit_impl = false) {
281     using namespace alg_kind;
282     using namespace utils;
283     const bool preserves_zero = true
284         && !one_of(alg, eltwise_linear, eltwise_soft_relu, eltwise_logistic, eltwise_clamp, eltwise_exp, eltwise_not)
285         && IMPLICATION(jit_impl, !one_of(alg, eltwise_elu, eltwise_tanh, eltwise_clamp, eltwise_exp, eltwise_not));
286     return preserves_zero;
287 }
288
289 inline float get_bias(const char *bias, size_t offset, data_type_t data_type)
290 {
291     if (!bias)
292         return 0.0f;
293
294 #define CASE(dt) \
295     case dt: return (float)((const prec_traits<dt>::type *)bias)[offset]
296
297     switch (data_type) {
298     CASE(data_type::s8);
299     CASE(data_type::u8);
300     CASE(data_type::s32);
301     CASE(data_type::f32);
302     default: assert(!"unimplemented");
303     }
304     return 0; // never happens (should probably be a NaN)
305 #undef CASE
306 }
307
308 }
309 }
310 }
311
312 #endif