Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / benchdnn / conv / ref_conv.cpp
index a471d21..60b7912 100644 (file)
@@ -15,6 +15,7 @@
 *******************************************************************************/
 
 #include "src/common/mkldnn_thread.hpp"
+#include "src/common/math_utils.hpp"
 
 #include "conv/conv_common.hpp"
 
@@ -85,17 +86,29 @@ void compute_ref_direct_fwd(const prb_t *p, dnn_mem_t &src_m,
     };
 
     auto maybe_post_ops = [&](float &conv_res, float dst) {
+        using namespace mkldnn::impl::math;
+
         const auto &ops = p->attr.post_ops;
         for (int idx = 0; idx < ops.len; ++idx) {
             using pk = attr_t::post_ops_t::kind_t;
             const auto &e = ops.entry[idx];
+
+            const auto &s = e.eltwise.scale;
+            const auto &a = e.eltwise.alpha;
+            const auto &b = e.eltwise.beta;
+
             switch (e.kind) {
-            case pk::SUM:
-                conv_res += e.sum.scale * dst;
-                break;
-            case pk::RELU:
-                conv_res = e.eltwise.scale * (conv_res < 0 ? 0 : conv_res);
-                break;
+            case pk::SUM: conv_res += e.sum.scale * dst; break;
+            case pk::RELU: conv_res = s*relu_fwd(conv_res, a); break;
+            case pk::TANH: conv_res = s*tanh_fwd(conv_res); break;
+            case pk::ELU: conv_res = s*elu_fwd(conv_res, a); break;
+            case pk::SQUARE: conv_res = s*square_fwd(conv_res); break;
+            case pk::ABS: conv_res = s*abs_fwd(conv_res); break;
+            case pk::SQRT: conv_res = s*sqrt_fwd(conv_res); break;
+            case pk::LINEAR: conv_res = s*linear_fwd(conv_res, a, b); break;
+            case pk::BRELU: conv_res = s*bounded_relu_fwd(conv_res, a); break;
+            case pk::SRELU: conv_res = s*soft_relu_fwd(conv_res); break;
+            case pk::LOGISTIC: conv_res = s*logistic_fwd(conv_res); break;
             default:
                 assert(!"unknown attr::post_ops::kind");
             }
@@ -115,9 +128,6 @@ void compute_ref_direct_fwd(const prb_t *p, dnn_mem_t &src_m,
                 conv_res += ((float*)bia_m)[bia_off];
             }
 
-            if (p->merge == RELU && conv_res < 0)
-                conv_res = 0;
-
             maybe_scale(conv_res, g * p->oc / p->g + oc);
             maybe_post_ops(conv_res, dst);
 
@@ -211,21 +221,55 @@ void compute_ref_direct_bwd_d(const prb_t *p, dnn_mem_t &diff_src_m,
         }
     };
 
+    /* Used for Deconv FWD */
+    auto maybe_post_ops = [&](float &conv_res, float dst) {
+        using namespace mkldnn::impl::math;
+
+        const auto &ops = p->attr.post_ops;
+        for (int idx = 0; idx < ops.len; ++idx) {
+            using pk = attr_t::post_ops_t::kind_t;
+            const auto &e = ops.entry[idx];
+
+            const auto &s = e.eltwise.scale;
+            const auto &a = e.eltwise.alpha;
+            const auto &b = e.eltwise.beta;
+
+            switch (e.kind) {
+            case pk::SUM: conv_res += e.sum.scale * dst; break;
+            case pk::RELU: conv_res = s*relu_fwd(conv_res, a); break;
+            case pk::TANH: conv_res = s*tanh_fwd(conv_res); break;
+            case pk::ELU: conv_res = s*elu_fwd(conv_res, a); break;
+            case pk::SQUARE: conv_res = s*square_fwd(conv_res); break;
+            case pk::ABS: conv_res = s*abs_fwd(conv_res); break;
+            case pk::SQRT: conv_res = s*sqrt_fwd(conv_res); break;
+            case pk::LINEAR: conv_res = s*linear_fwd(conv_res, a, b); break;
+            case pk::BRELU: conv_res = s*bounded_relu_fwd(conv_res, a); break;
+            case pk::SRELU: conv_res = s*soft_relu_fwd(conv_res); break;
+            case pk::LOGISTIC: conv_res = s*logistic_fwd(conv_res); break;
+            default:
+                assert(!"unknown attr::post_ops::kind");
+            }
+        }
+    };
+
     mkldnn::impl::parallel_nd(p->g, p->mb, p->ic / p->g, p->id, p->ih, p->iw,
         [&](int g, int mb, int ic, int id, int ih, int iw) {
             size_t src_off = src_off_f(p, mb, g, ic, id, ih, iw);
             float &ds = ((float*)diff_src_m)[src_off];
-            ds = 0;
+            float conv_res = 0;
             if (fast)
-                ker_fast(ds, g, mb, ic, id, ih, iw);
+                ker_fast(conv_res, g, mb, ic, id, ih, iw);
             else
-                ker(ds, g, mb, ic, id, ih, iw);
+                ker(conv_res, g, mb, ic, id, ih, iw);
 
             if (p->dir & FLAG_BIA) {
                 const size_t bia_off = (size_t)g * p->ic / p->g + ic;
-                ds += ((float*)bia_m)[bia_off];
+                conv_res += ((float*)bia_m)[bia_off];
             }
-            maybe_scale(ds, g * p->ic / p->g + ic);
+            maybe_scale(conv_res, g * p->ic / p->g + ic);
+            maybe_post_ops(conv_res, ds);
+
+            ds = conv_res;
         }
     );
 }