Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / gtests / test_convolution_eltwise_forward_common.hpp
index 5337807..c0b6e21 100644 (file)
 
 #include "mkldnn_test_common.hpp"
 #include "gtest/gtest.h"
-
+#include "math_utils.hpp"
 #include "mkldnn.hpp"
 
-namespace mkldnn {
-
-
-template <typename T, typename A> inline T relu_fwd(T s, A alpha) {
-    return s > 0 ? s : static_cast<T>(s * alpha);
-}
-
-template <typename T> T tanh_fwd(T s) {
-    const float e = ::expf(2*s); /* maybe replace with -2*s? */
-    return static_cast<T>((e - 1.0) / (e + 1.0));
-}
-
-template <typename T, typename A> T elu_fwd(T s, A alpha) {
-    return s > 0 ? s : static_cast<T>(alpha * (::expf(s) - 1));
-}
-
-template <typename T>
-T square_fwd(T s) {
-    return s * s;
-}
-
-template <typename T>
-T abs_fwd(T s) {
-    return s > 0 ? s : -s;;
-}
-
-template <typename T>
-T sqrt_fwd(T s) {
-    return s > 0 ? ::sqrtf(s) : 0;
-}
-
-template <typename T, typename A>
-T linear_fwd(T s, A alpha, A beta) {
-    return alpha * s + beta;
-}
-
-template <typename T, typename A>
-T bounded_relu_fwd(T s, A alpha) {
-    s = s > 0 ? s : 0;
-    return s > alpha ? alpha : s;
-}
-
-template <typename T>
-T soft_relu_fwd(T s) {
-    return logf(1 + ::expf(s));
-}
+using namespace mkldnn::impl::math;
 
-template <typename T>
-T logistic_fwd(T s) {
-    T v = ::expf(s);
-    return v / (v + 1);
-}
-
-template <typename T, typename A>
-T clamp_fwd(T s, A alpha, A beta) {
-    return s > alpha ? (T)(alpha) : s < beta ? (T)(beta) : s;
-}
+namespace mkldnn {
 
 template <typename data_t_src, typename data_t_wei,
           typename data_t_acc, typename data_t_dst>
@@ -94,76 +40,60 @@ void compute_ref_conv_eltwise_fwd(const test_convolution_sizes_t &c,
     const memory::desc weights_d = weights.get_primitive_desc().desc();
     const memory::desc dst_d = dst.get_primitive_desc().desc();
 
+    size_t padded_ic = src_d.data.layout_desc.blocking.padding_dims[1];
+    size_t padded_oc = dst_d.data.layout_desc.blocking.padding_dims[1];
+
+    size_t padded_ic_w = weights_d.data.format == mkldnn_OhIw8o4i ? weights_d.data.layout_desc.blocking.padding_dims[1] :
+                                                                    src_d.data.layout_desc.blocking.padding_dims[1];
+    size_t padded_oc_w = weights_d.data.format == mkldnn_OhIw8o4i ? weights_d.data.layout_desc.blocking.padding_dims[0] :
+                                                                    dst_d.data.layout_desc.blocking.padding_dims[1];
+
     mkldnn::impl::parallel_nd(c.mb, c.ng, c.oc / c.ng, c.oh, c.ow,
         [&](int n, int g, int oc, int oh, int ow) {
-            int oidx = n * c.oc * c.oh * c.ow
-                       + g * c.oc / c.ng * c.oh * c.ow
-                       + oc * c.oh * c.ow + oh * c.ow + ow;
-
-            int didx = map_index(dst_d, oidx);
-            dst_data[didx] = bias_data ?
-                             bias_data[map_index(
-                                     bias.get_primitive_desc().desc(),
-                                     g * c.oc / c.ng + oc)] :
-                             data_t_dst{0};
-            for (int ic = 0; ic < c.ic / c.ng; ic++) {
-                for (int kh = 0; kh < c.kh; kh++) {
-                    for (int kw = 0; kw < c.kw; kw++) {
-                        int iw = ow * c.strw
-                                 - c.padw + kw * (1 + c.dilw);
-                        int ih = oh * c.strh
-                                 - c.padh + kh * (1 + c.dilh);
-                        if (iw < 0 || iw >= c.iw) continue;
-                        if (ih < 0 || ih >= c.ih) continue;
-                        int iidx = n * c.ic * c.ih * c.iw
-                                   + g * c.ic / c.ng * c.ih * c.iw
-                                   + ic * c.ih * c.iw + ih * c.iw + iw;
-                        int widx = g * c.oc / c.ng * c.ic
-                                   / c.ng * c.kh * c.kw
-                                   + oc * c.ic / c.ng * c.kh * c.kw
-                                   + ic * c.kh * c.kw + kh * c.kw + kw;
-
-                        dst_data[didx]
-                                += src_data[map_index(src_d, iidx)]
-                                   * weights_data[map_index(
-                                weights_d, widx)];
-                    }
-                }
+            size_t oidx = n * padded_oc * c.oh * c.ow
+                    + g * padded_oc / c.ng * c.oh * c.ow
+                    + oc * c.oh * c.ow + oh * c.ow + ow;
+
+            size_t didx = map_index(dst_d, oidx);
+            dst_data[didx] = bias_data
+                    ? bias_data[g * c.oc / c.ng + oc] : data_t_dst{0};
+
+            for (int ic = 0; ic < c.ic / c.ng; ic++)
+            for (int kh = 0; kh < c.kh; kh++)
+            for (int kw = 0; kw < c.kw; kw++)
+            {
+                int ih = oh * c.strh - c.padh + kh * (1 + c.dilh);
+                if (ih < 0 || ih >= c.ih) continue;
+                int iw = ow * c.strw - c.padw + kw * (1 + c.dilw);
+                if (iw < 0 || iw >= c.iw) continue;
+
+                size_t iidx = n * padded_ic * c.ih * c.iw
+                    + g * padded_ic / c.ng * c.ih * c.iw
+                    + ic * c.ih * c.iw + ih * c.iw + iw;
+                size_t widx = g * padded_oc_w / c.ng * padded_ic_w
+                    / c.ng * c.kh * c.kw
+                    + oc * padded_ic_w / c.ng * c.kh * c.kw
+                    + ic * c.kh * c.kw + kh * c.kw + kw;
+
+                dst_data[didx] += src_data[map_index(src_d, iidx)]
+                        * weights_data[map_index(weights_d, widx)];
             }
 
+            auto &d = dst_data[didx];
             switch (elt_alg) {
-                case eltwise_relu:
-                    dst_data[didx] = relu_fwd(dst_data[didx], elt_alpha);
-                    break;
-                case eltwise_tanh:
-                    dst_data[didx] = tanh_fwd(dst_data[didx]);
-                    break;
-                case eltwise_elu:
-                    dst_data[didx] = elu_fwd(dst_data[didx], elt_alpha);
-                    break;
-                case eltwise_square:
-                    dst_data[didx] = square_fwd(dst_data[didx]);
-                    break;
-                case eltwise_abs:
-                    dst_data[didx] = abs_fwd(dst_data[didx]);
-                    break;
-                case eltwise_sqrt:
-                    dst_data[didx] = sqrt_fwd(dst_data[didx]);
-                    break;
-                case eltwise_linear:
-                    dst_data[didx] = linear_fwd(dst_data[didx], elt_alpha, elt_beta);
-                    break;
-                case eltwise_bounded_relu:
-                    dst_data[didx] = bounded_relu_fwd(dst_data[didx], elt_alpha);
-                    break;
-                case eltwise_soft_relu:
-                    dst_data[didx] = soft_relu_fwd(dst_data[didx]);
-                    break;
-                case eltwise_logistic:
-                    dst_data[didx] = logistic_fwd(dst_data[didx]);
-                    break;
-                default:
-                    assert(!"unknown alg_kind");
+            case eltwise_relu: d = relu_fwd(d, elt_alpha); break;
+            case eltwise_tanh: d = tanh_fwd(d); break;
+            case eltwise_elu: d = elu_fwd(d, elt_alpha); break;
+            case eltwise_square: d = square_fwd(d); break;
+            case eltwise_abs: d = abs_fwd(d); break;
+            case eltwise_sqrt: d = sqrt_fwd(d); break;
+            case eltwise_linear: d = linear_fwd(d, elt_alpha, elt_beta); break;
+            case eltwise_bounded_relu: d = bounded_relu_fwd(d, elt_alpha); break;
+            case eltwise_soft_relu: d = soft_relu_fwd(d); break;
+            case eltwise_logistic: d = logistic_fwd(d); break;
+            case eltwise_clamp: d = clamp_fwd(d, elt_alpha, elt_beta); break;
+            case eltwise_exp: d = exp_fwd(d); break;
+            default: assert(!"unknown alg_kind");
             }
         }
     );
@@ -174,8 +104,7 @@ template <typename data_t_src, typename data_t_wei,
 class convolution_eltwise_test
     : public ::testing::TestWithParam<test_convolution_eltwise_params_t> {
 protected:
-    virtual void SetUp()
-    {
+    virtual void SetUp() {
         test_convolution_eltwise_params_t p
                 = ::testing::TestWithParam<
                 test_convolution_eltwise_params_t>::GetParam();
@@ -209,11 +138,15 @@ protected:
         auto dst_ref = memory({c_dst_desc, eng});
 
         fill_data<data_t_src>(c_src.get_primitive_desc().get_size()
-                / sizeof(data_t_src), (data_t_src *)c_src.get_data_handle(), data_t_src(0), data_t_src(1));
+                / sizeof(data_t_src), (data_t_src *)c_src.get_data_handle(),
+                data_t_src(0), data_t_src(1));
+        check_zero_tail<data_t_src>(1, c_src);
 
         fill_data<data_t_wei>(
                 c_weights.get_primitive_desc().get_size()
-                / sizeof(data_t_wei),(data_t_wei *)c_weights.get_data_handle(), data_t_wei(0), data_t_wei(1));
+                / sizeof(data_t_wei),(data_t_wei *)c_weights.get_data_handle(),
+                data_t_wei(0), data_t_wei(1));
+        check_zero_tail<data_t_wei>(1, c_weights);
 
         bool with_bias = p.formats.bias_format != memory::format::format_undef;
         auto c_bias_desc = with_bias ?
@@ -226,7 +159,7 @@ protected:
                     (data_t_dst *)c_bias.get_data_handle(), 1., true);
         }
 
-        std::vector<int> padR = { cd.padh, cd.padw };
+        std::vector<ptrdiff_t> padR = { cd.padh, cd.padw };
         for (int i = 0; i < 2; ++i) {
             if ((cd.ih - ((cd.kh - 1) * (cd.dilh + 1) + 1) + cd.padh + padR[0])
                 / cd.strh + 1 != cd.oh)
@@ -273,7 +206,10 @@ protected:
         compute_ref_conv_eltwise_fwd<data_t_src, data_t_wei, data_t_wei,
             data_t_dst>(cd, c_src, c_weights, c_bias, dst_ref, with_bias,
                         p.alg, eltwise_alpha, eltwise_beta);
-        compare_data<data_t_dst>(dst_ref, c_dst);
+        check_zero_tail<data_t_dst>(1, dst_ref);
+
+        compare_data<data_t_dst>(dst_ref, c_dst, 1e-2);
+        check_zero_tail<data_t_dst>(0, c_dst);
     }
 };