1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef MATH_UTILS_HPP
18 #define MATH_UTILS_HPP
25 #include "mkldnn_traits.hpp"
31 template <typename data_t, typename acc_t>
32 inline typename utils::enable_if<!nstl::is_integral<data_t>::value,
33 typename utils::remove_reference<data_t>::type>::type
34 saturate(const acc_t &x) {
35 return (typename utils::remove_reference<data_t>::type)x;
38 template <typename data_t, typename acc_t>
39 inline typename utils::enable_if<nstl::is_integral<data_t>::value,
40 typename utils::remove_reference<data_t>::type>::type
41 saturate(const acc_t &x) {
43 if (v < (acc_t)nstl::numeric_limits<data_t>::lowest())
44 v = (acc_t)nstl::numeric_limits<data_t>::lowest();
45 if (v > (acc_t)nstl::numeric_limits<data_t>::max())
46 v = (acc_t)nstl::numeric_limits<data_t>::max();
47 return (typename utils::remove_reference<data_t>::type)v;
50 template <typename data_t>
51 double saturate(const double &x) {
53 if (v < (double)nstl::numeric_limits<data_t>::lowest())
54 v = (double)nstl::numeric_limits<data_t>::lowest();
55 if (v > (double)nstl::numeric_limits<data_t>::max())
56 v = (double)nstl::numeric_limits<data_t>::max();
60 template <> inline int8_t saturate<int8_t, uint8_t>(const uint8_t &x) {
61 return x <= 127u ? x : 127;
64 template <> inline uint8_t saturate<uint8_t, int8_t>(const int8_t &x) {
65 return x >= 0 ? x : 0;
68 template <typename out_t>
69 inline typename utils::enable_if<nstl::is_integral<out_t>::value, out_t>::type
70 out_round(float v, round_mode_t rmode = round_mode::nearest)
71 { return (out_t)(rmode == round_mode::down ? floorf(v) : nearbyintf(v)); }
73 template <typename out_t>
74 inline typename utils::enable_if<nstl::is_integral<out_t>::value, out_t>::type
75 out_round(double v, round_mode_t rmode = round_mode::nearest)
76 { return (out_t)(rmode == round_mode::down ? floor(v) : nearbyint(v)); }
78 template <typename out_t>
79 inline typename utils::enable_if<!nstl::is_integral<out_t>::value, out_t>::type
80 out_round(float v, round_mode_t rmode = round_mode::nearest)
81 { UNUSED(rmode); return v; }
83 inline int gcd(int a, int b) {
84 a = impl::nstl::abs(a);
85 b = impl::nstl::abs(b);
86 if (a < b) { int x = a; a = b; b = x; }
91 while ((r = a % b) != 0) { a = b; b = r; }
/** Returns true when clearing the lowest set bit of @p v leaves zero, i.e.
 * when v is a power of two. Note that is_pow2(0) is also true. */
template <typename T>
inline bool is_pow2(const T& v) {
    const bool single_bit_or_zero = !(v & (v - 1));
    return single_bit_or_zero;
}
/** Returns floor(log2(v)), i.e. the position of the leftmost non-zero bit of
 * @p v, or -1 when v == 0.
 *
 * Binary search on the highest set bit: try shifting by 32, 16, ..., 1 and
 * accumulate each shift that keeps v non-zero. On targets with a 32-bit
 * size_t the 32-bit step simply never fires.
 */
inline int ilog2q(size_t v) {
    if (v == 0)
        return -1;

    int p = 0;
    for (int pw = 32; pw > 0; pw >>= 1) {
        if (v >= (1ull << pw)) {
            v >>= pw;
            p += pw;
        }
    }
    return p;
}
111 template <typename T, typename U = typename utils::remove_reference<T>::type>
112 inline U one_m_square(T x) {
113 return (U)(1 - x) * (1 + x);
116 template <typename T, typename U = typename utils::remove_reference<T>::type>
117 inline U x_m_square(T x) {
118 return (U)(1 - x) * x;
122 template <typename T, typename A,
123 typename U = typename utils::remove_reference<T>::type>
124 inline U relu_fwd(T s, A alpha) {
125 return s > 0 ? s : (U)(s * alpha);
127 template <typename T, typename A,
128 typename U = typename utils::remove_reference<T>::type>
129 inline U relu_bwd(T dd, T s, A alpha) {
130 return s > 0 ? dd : (U)(dd * alpha);
133 template <typename T, typename U = typename utils::remove_reference<T>::type>
134 inline U tanh_fwd(T s) {
135 const float e = tanhf((float) s);
139 template <typename T, typename U = typename utils::remove_reference<T>::type>
140 inline U tanh_bwd(T dd, T s) {
141 const float e = tanh_fwd<float>((float) s);
142 return (U)(dd * (1 - e) * (1 + e));
145 template <typename T, typename A,
146 typename U = typename utils::remove_reference<T>::type>
147 inline U elu_fwd(T s, A alpha) {
148 return s > 0 ? s : (U)(alpha * (::expm1f((float)s)));
150 template <typename T, typename A,
151 typename U = typename utils::remove_reference<T>::type>
152 inline U elu_bwd(T dd, T s, A alpha) {
153 return (U)(dd * (s > 0 ? 1 : alpha * ::expf((float)s)));
156 template <typename T, typename U = typename utils::remove_reference<T>::type>
157 inline U square_fwd(T s) {
161 template <typename T, typename U = typename utils::remove_reference<T>::type>
162 inline U square_bwd(T dd, T s) {
166 template <typename T, typename U = typename utils::remove_reference<T>::type>
167 inline U abs_fwd(T s) {
168 return s > 0 ? s : -s;
171 template <typename T, typename U = typename utils::remove_reference<T>::type>
172 inline U abs_bwd(T dd, T s) {
173 return s > 0 ? dd : s < 0 ? -dd : 0;
176 template <typename T, typename U = typename utils::remove_reference<T>::type>
177 inline U sqrt_fwd(T s) {
178 return s > 0 ? (U)(::sqrtf((float)(s))) : 0;
181 template <typename T, typename U = typename utils::remove_reference<T>::type>
182 inline U sqrt_bwd(T dd, T s) {
184 ? (U)(dd / (2 * ::sqrtf((float)(s))))
188 template <typename T, typename A,
189 typename U = typename utils::remove_reference<T>::type>
190 inline U linear_fwd(T s, A alpha, A beta) {
191 return (U)(alpha * s + beta);
194 template <typename T, typename A,
195 typename U = typename utils::remove_reference<T>::type>
196 inline U linear_bwd(T dd, T s, A alpha, A beta) {
199 return (U)(dd * alpha);
202 template <typename T, typename A,
203 typename U = typename utils::remove_reference<T>::type>
204 inline U bounded_relu_fwd(T s, A alpha) {
206 return s > alpha ? (U)(alpha) : s;
209 template <typename T, typename A,
210 typename U = typename utils::remove_reference<T>::type>
211 inline U bounded_relu_bwd(T dd, T s, A alpha) {
212 return dd * (0 < s && s < alpha ? 1 : 0);
215 template <typename T, typename U = typename utils::remove_reference<T>::type>
216 inline U soft_relu_fwd(T s) {
217 float max_logf = 8.872284e+01; //::logf(FLT_MAX)
218 return s < max_logf ? (U)(::log1pf(::expf((float)s))) : s;
221 template <typename T, typename U = typename utils::remove_reference<T>::type>
222 inline U soft_relu_bwd(T dd, T s) {
223 return (U)(dd / (1 + ::expf((float)(-s))));
226 template <typename T, typename U = typename utils::remove_reference<T>::type>
227 inline U logistic_fwd(T s) {
228 U v = (U)(::expf((float) -s));
232 template <typename T, typename U = typename utils::remove_reference<T>::type>
233 inline U logistic_bwd(T dd, T s) {
234 U v = logistic_fwd<T, U>(s);
235 return dd * v * (1 - v);
238 template <typename T, typename A,
239 typename U = typename utils::remove_reference<T>::type>
240 inline U clamp_fwd(T s, A alpha, A beta) {
241 return (U)(s > alpha ? alpha : s < beta ? beta : s);
244 template <typename T, typename A,
245 typename U = typename utils::remove_reference<T>::type>
246 inline U clamp_bwd(T dd, T s, A alpha, A beta) {
247 return dd * (beta < s && s < alpha ? 1 : 0);
250 template <typename T,
251 typename U = typename utils::remove_reference<T>::type>
252 inline U exp_fwd(T s) {
253 return (U)(::expf((float)s));
256 template <typename T,
257 typename U = typename utils::remove_reference<T>::type>
258 inline U exp_bwd(T dd, T s) {
259 return (U)(::expf((float)s));
262 template <typename T,
263 typename U = typename utils::remove_reference<T>::type>
264 inline U not_fwd(T s) {
268 template <typename T, typename A,
269 typename U = typename utils::remove_reference<T>::type>
270 inline U scale_shift_fwd(T s_val, A w_val, A b_val) {
271 return (U)(s_val*w_val + b_val);
274 template <typename T, typename A,
275 typename U = typename utils::remove_reference<T>::type>
276 inline U prelu_fwd(T s_val, A w_val) {
277 return (U)(s_val >= 0 ? s_val : w_val*s_val);
280 inline bool eltwise_fwd_preserves_zero(alg_kind_t alg, bool jit_impl = false) {
281 using namespace alg_kind;
282 using namespace utils;
283 const bool preserves_zero = true
284 && !one_of(alg, eltwise_linear, eltwise_soft_relu, eltwise_logistic, eltwise_clamp, eltwise_exp, eltwise_not)
285 && IMPLICATION(jit_impl, !one_of(alg, eltwise_elu, eltwise_tanh, eltwise_clamp, eltwise_exp, eltwise_not));
286 return preserves_zero;
289 inline float get_bias(const char *bias, size_t offset, data_type_t data_type)
295 case dt: return (float)((const prec_traits<dt>::type *)bias)[offset]
300 CASE(data_type::s32);
301 CASE(data_type::f32);
302 default: assert(!"unimplemented");
304 return 0; // never happens (should probably be a NaN)