Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / common / math_utils.hpp
index 0ae7093..6e2e285 100644 (file)
@@ -22,6 +22,7 @@
 
 #include "utils.hpp"
 #include "nstl.hpp"
+#include "mkldnn_traits.hpp"
 
 namespace mkldnn {
 namespace impl {
@@ -107,118 +108,203 @@ inline int ilog2q(size_t v) {
     return p;
 }
 
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U one_m_square(T x) {
+    return (U)(1 - x) * (1 + x);
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U x_m_square(T x) {
+    return (U)(1 - x) * x;
+}
+
 /* activation */
-template <typename T, typename A> inline T relu_fwd(T s, A alpha) {
-    return s > 0 ? s : (T)(s * alpha);
+template <typename T, typename A,
+         typename U = typename utils::remove_reference<T>::type>
+inline U relu_fwd(T s, A alpha) {
+    return s > 0 ? s : (U)(s * alpha);
 }
-template <typename T, typename A> inline T relu_bwd(T dd, T s, A alpha) {
-    return s > 0 ? dd : (T)(dd * alpha);
+template <typename T, typename A,
+         typename U = typename utils::remove_reference<T>::type>
+inline U relu_bwd(T dd, T s, A alpha) {
+    return s > 0 ? dd : (U)(dd * alpha);
 }
 
-template <typename T> inline T tanh_fwd(T s) {
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U tanh_fwd(T s) {
     const float e = tanhf((float) s);
-    return (T) e;
+    return (U)e;
 }
-template <typename T> inline T tanh_bwd(T dd, T s) {
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U tanh_bwd(T dd, T s) {
     const float e = tanh_fwd<float>((float) s);
-    return (T)(dd * (1 - e) * (1 + e));
+    return (U)(dd * (1 - e) * (1 + e));
 }
 
-template <typename T, typename A> inline T elu_fwd(T s, A alpha) {
-    return s > 0 ? s : (T)(alpha * (::expm1f((float)s)));
+template <typename T, typename A,
+         typename U = typename utils::remove_reference<T>::type>
+inline U elu_fwd(T s, A alpha) {
+    return s > 0 ? s : (U)(alpha * (::expm1f((float)s)));
 }
-template <typename T, typename A> inline T elu_bwd(T dd, T s, A alpha) {
-    return (T)(dd * (s > 0 ? 1 : alpha * ::expf((float)s)));
+template <typename T, typename A,
+         typename U = typename utils::remove_reference<T>::type>
+ inline U elu_bwd(T dd, T s, A alpha) {
+    return (U)(dd * (s > 0 ? 1 : alpha * ::expf((float)s)));
 }
 
-template <typename T>
-inline T square_fwd(T s) {
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U square_fwd(T s) {
     return s * s;
 }
 
-template <typename T>
-inline T square_bwd(T dd, T s) {
-    return dd * 2*s;
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U square_bwd(T dd, T s) {
+    return dd * 2 * s;
 }
 
-template <typename T>
-inline T abs_fwd(T s) {
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U abs_fwd(T s) {
     return s > 0 ? s : -s;
 }
 
-template <typename T>
-inline T abs_bwd(T dd, T s) {
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U abs_bwd(T dd, T s) {
     return s > 0 ? dd : s < 0 ? -dd : 0;
 }
 
-template <typename T>
-inline T sqrt_fwd(T s) {
-    return s > 0 ? (T)(::sqrtf((float)(s))) : 0;
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U sqrt_fwd(T s) {
+    return s > 0 ? (U)(::sqrtf((float)(s))) : 0;
 }
 
-template <typename T>
-inline T sqrt_bwd(T dd, T s) {
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U sqrt_bwd(T dd, T s) {
     return s > 0
-        ? (T)(dd / (2 * ::sqrtf((float)(s))))
+        ? (U)(dd / (2 * ::sqrtf((float)(s))))
         : 0;
 }
 
-template <typename T, typename A>
-inline T linear_fwd(T s, A alpha, A beta) {
-    return (T)(alpha * s + beta);
+template <typename T, typename A,
+         typename U = typename utils::remove_reference<T>::type>
+inline U linear_fwd(T s, A alpha, A beta) {
+    return (U)(alpha * s + beta);
 }
 
-template <typename T, typename A>
-inline T linear_bwd(T dd, T s, A alpha, A beta) {
+template <typename T, typename A,
+         typename U = typename utils::remove_reference<T>::type>
+inline U linear_bwd(T dd, T s, A alpha, A beta) {
     (void) s;
     (void) beta;
-    return (T)(dd * alpha);
+    return (U)(dd * alpha);
 }
 
-template <typename T, typename A>
-inline T bounded_relu_fwd(T s, A alpha) {
+template <typename T, typename A,
+         typename U = typename utils::remove_reference<T>::type>
+inline U bounded_relu_fwd(T s, A alpha) {
     s = s > 0 ? s : 0;
-    return s > alpha ? (T)(alpha) : s;
+    return s > alpha ? (U)(alpha) : s;
 }
 
-template <typename T, typename A>
-inline T bounded_relu_bwd(T dd, T s, A alpha) {
+template <typename T, typename A,
+         typename U = typename utils::remove_reference<T>::type>
+inline U bounded_relu_bwd(T dd, T s, A alpha) {
     return dd * (0 < s && s < alpha ? 1 : 0);
 }
 
-template <typename T>
-inline T soft_relu_fwd(T s) {
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U soft_relu_fwd(T s) {
     float max_logf = 8.872284e+01; //::logf(FLT_MAX)
-    return s < max_logf ? (T)(::log1pf(::expf((float)s))) : s;
+    return s < max_logf ? (U)(::log1pf(::expf((float)s))) : s;
 }
 
-template <typename T>
-inline T soft_relu_bwd(T dd, T s) {
-    return (T)(dd / (1 + ::expf((float)(-s))));
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U soft_relu_bwd(T dd, T s) {
+    return (U)(dd / (1 + ::expf((float)(-s))));
 }
 
-template <typename T>
-inline T logistic_fwd(T s) {
-    T v = (T)(::expf((float) -s));
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U logistic_fwd(T s) {
+    U v = (U)(::expf((float) -s));
     return 1 / (1 + v);
 }
 
-template <typename T>
-inline T logistic_bwd(T dd, T s) {
-    T v = logistic_fwd<T>(s);
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U logistic_bwd(T dd, T s) {
+    U v = logistic_fwd<T, U>(s);
     return dd * v * (1 - v);
 }
 
-template <typename T, typename A>
-T clamp_fwd(T s, A alpha, A beta) {
-    return s > alpha ? (T)(alpha) : s < beta ? (T)(beta) : s;
+template <typename T, typename A,
+         typename U = typename utils::remove_reference<T>::type>
+inline U clamp_fwd(T s, A alpha, A beta) {
+    return (U)(s > alpha ? alpha : s < beta ? beta : s);
 }
 
-template <typename T, typename A>
-T clamp_bwd(T dd, T s, A alpha, A beta) {
+template <typename T, typename A,
+         typename U = typename utils::remove_reference<T>::type>
+inline U clamp_bwd(T dd, T s, A alpha, A beta) {
     return dd * (beta < s && s < alpha ? 1 : 0);
 }
 
+template <typename T,
+         typename U = typename utils::remove_reference<T>::type>
+inline U exp_fwd(T s) {
+    return (U)(::expf((float)s));
+}
+
+template <typename T,
+         typename U = typename utils::remove_reference<T>::type>
+ inline U exp_bwd(T dd, T s) {
+    return (U)(::expf((float)s));
+}
+
+template <typename T,
+        typename U = typename utils::remove_reference<T>::type>
+inline U not_fwd(T s) {
+    return (U)(!s);
+}
+
+template <typename T, typename A,
+         typename U = typename utils::remove_reference<T>::type>
+inline U scale_shift_fwd(T s_val, A w_val, A b_val) {
+    return (U)(s_val*w_val + b_val);
+}
+
+template <typename T, typename A,
+         typename U = typename utils::remove_reference<T>::type>
+inline U prelu_fwd(T s_val, A w_val) {
+    return (U)(s_val >= 0 ? s_val : w_val*s_val);
+}
+
+inline bool eltwise_fwd_preserves_zero(alg_kind_t alg, bool jit_impl = false) {
+    using namespace alg_kind;
+    using namespace utils;
+    const bool preserves_zero = true
+        && !one_of(alg, eltwise_linear, eltwise_soft_relu, eltwise_logistic, eltwise_clamp, eltwise_exp, eltwise_not)
+        && IMPLICATION(jit_impl, !one_of(alg, eltwise_elu, eltwise_tanh, eltwise_clamp, eltwise_exp, eltwise_not));
+    return preserves_zero;
+}
+
+inline float get_bias(const char *bias, size_t offset, data_type_t data_type)
+{
+    if (!bias)
+        return 0.0f;
+
+#define CASE(dt) \
+    case dt: return (float)((const prec_traits<dt>::type *)bias)[offset]
+
+    switch (data_type) {
+    CASE(data_type::s8);
+    CASE(data_type::u8);
+    CASE(data_type::s32);
+    CASE(data_type::f32);
+    default: assert(!"unimplemented");
+    }
+    return 0; // never happens (should probably be a NaN)
+#undef CASE
+}
+
 }
 }
 }