Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / ref_convolution.cpp
index 33b5fe0..d3e6483 100644 (file)
@@ -27,56 +27,59 @@ namespace impl {
 namespace cpu {
 
 using math::saturate;
+using math::get_bias;
 
-template <bool with_relu, data_type_t src_type, data_type_t wei_type,
+template <data_type_t src_type, data_type_t wei_type,
          data_type_t dst_type, data_type_t acc_type>
-void _ref_convolution_fwd_t<with_relu, src_type, wei_type, dst_type, acc_type>
-        ::execute_forward() {
+void ref_convolution_fwd_t<src_type, wei_type, dst_type, acc_type>
+        ::execute_forward() const {
     auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
     auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
     auto bias = reinterpret_cast<const char *>(this->input_memory(2));
     auto dst = reinterpret_cast<dst_data_t *>(this->memory());
 
-    const memory_desc_wrapper src_d(conf_.src_pd());
-    const memory_desc_wrapper dst_d(conf_.dst_pd());
-    const memory_desc_wrapper weights_d(conf_.weights_pd(0));
-    const memory_desc_wrapper bias_d(conf_.weights_pd(1));
+    const memory_desc_wrapper src_d(pd()->src_pd());
+    const memory_desc_wrapper dst_d(pd()->dst_pd());
+    const memory_desc_wrapper weights_d(pd()->weights_pd(0));
+    const memory_desc_wrapper bias_d(pd()->weights_pd(1));
 
-    const bool with_groups = conf_.with_groups();
+    const bool with_groups = pd()->with_groups();
 
-    const int G = conf_.G();
-    const int MB = conf_.MB();
-    const int OD = conf_.OD();
-    const int OH = conf_.OH();
-    const int OW = conf_.OW();
-    const int ID = conf_.ID();
-    const int IH = conf_.IH();
-    const int IW = conf_.IW();
+    const int G = pd()->G();
+    const int MB = pd()->MB();
+    const int OD = pd()->OD();
+    const int OH = pd()->OH();
+    const int OW = pd()->OW();
+    const int ID = pd()->ID();
+    const int IH = pd()->IH();
+    const int IW = pd()->IW();
 
-    const int OC = conf_.OC() / G;
-    const int IC = conf_.IC() / G;
-    const int KD = conf_.KD();
-    const int KH = conf_.KH();
-    const int KW = conf_.KW();
+    const int OC = pd()->OC() / G;
+    const int IC = pd()->IC() / G;
+    const int KD = pd()->KD();
+    const int KH = pd()->KH();
+    const int KW = pd()->KW();
 
-    const int KSD = conf_.KSD();
-    const int KSH = conf_.KSH();
-    const int KSW = conf_.KSW();
+    const int KSD = pd()->KSD();
+    const int KSH = pd()->KSH();
+    const int KSW = pd()->KSW();
 
-    const int KDD = conf_.KDD();
-    const int KDH = conf_.KDH();
-    const int KDW = conf_.KDW();
+    const int KDD = pd()->KDD();
+    const int KDH = pd()->KDH();
+    const int KDW = pd()->KDW();
 
-    const int padFront = conf_.padFront();
-    const int padT = conf_.padT();
-    const int padL = conf_.padL();
+    const int padFront = pd()->padFront();
+    const int padT = pd()->padT();
+    const int padL = pd()->padL();
 
-    const float nslope = conf_.negative_slope();
+    const bool with_relu = 0; // TODO: change if support post_ops
+    const float nslope = 0.f;
 
-    const int ndims = conf_.cdesc()->src_desc.ndims;
+    const int ndims = pd()->desc()->src_desc.ndims;
 
-    auto ker = [=](acc_data_t &d, int g, int mb, int oc, int od, int oh,
+    auto ker = [=](int g, int mb, int oc, int od, int oh,
             int ow) {
+        acc_data_t d = 0;
         for (int ic = 0; ic < IC; ++ic)
         for (int kd = 0; kd < KD; ++kd)
         for (int kh = 0; kh < KH; ++kh)
@@ -107,36 +110,23 @@ void _ref_convolution_fwd_t<with_relu, src_type, wei_type, dst_type, acc_type>
            else
                assert(false);
 
-       }
-    };
-    auto get_bias = [=, &bias](size_t off) -> float {
-#       define CASE(dt) case dt: \
-            return (float)(*((const prec_traits<dt>::type *)bias + off))
-        switch (conf_.cdesc()->bias_desc.data_type) {
-        CASE(data_type::s8);
-        CASE(data_type::u8);
-        CASE(data_type::s32);
-        CASE(data_type::f32);
-        default: assert(!"unimplemented");
         }
-#       undef CASE
-        return 0;
+        return d;
     };
+
     parallel_nd(G, MB, OC, OD, OH, OW,
         [&](int g, int mb, int oc, int od, int oh, int ow) {
-        acc_data_t a = 0;
-        ker(a, g, mb, oc, od, oh, ow);
-
-        float a_fp = (float)a;
+        float a_fp = ker(g, mb, oc, od, oh, ow);
 
         if (bias)
-            a_fp += get_bias(bias_d.off(g*OC + oc));
+            a_fp += get_bias(bias, bias_d.off(g * OC + oc),
+                             pd()->desc()->bias_desc.data_type);
 
         if (with_relu && a_fp < 0)
             a_fp *= nslope;
 
         if (data_traits<dst_data_t>::data_type != data_type::f32) {
-            switch (conf_.attr()->round_mode_) {
+            switch (pd()->attr()->round_mode_) {
                 case round_mode::down:    a_fp = floorf(a_fp); break;
                 case round_mode::nearest: a_fp = nearbyintf(a_fp); break;
             }
@@ -156,51 +146,52 @@ void _ref_convolution_fwd_t<with_relu, src_type, wei_type, dst_type, acc_type>
 template <data_type_t diff_src_type, data_type_t wei_type,
          data_type_t diff_dst_type, data_type_t acc_type>
 void ref_convolution_bwd_data_t<diff_src_type, wei_type, diff_dst_type,
-     acc_type>::execute_backward_data() {
+     acc_type>::execute_backward_data() const {
     auto diff_dst = reinterpret_cast<const diff_dst_data_t*>(
             this->input_memory(0));
     auto weights = reinterpret_cast<const wei_data_t*>(this->input_memory(1));
     auto bias = reinterpret_cast<const char *>(this->input_memory(2));
     auto diff_src = reinterpret_cast<diff_src_data_t*>(this->memory());
 
-    const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
-    const memory_desc_wrapper diff_src_d(conf_.diff_src_pd());
-    const memory_desc_wrapper weights_d(conf_.weights_pd(0));
-    const memory_desc_wrapper bias_d(conf_.weights_pd(1));
+    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
+    const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
+    const memory_desc_wrapper weights_d(pd()->weights_pd(0));
+    const memory_desc_wrapper bias_d(pd()->weights_pd(1));
 
-    const bool with_groups = conf_.with_groups();
+    const bool with_groups = pd()->with_groups();
 
-    const int G = conf_.G();
-    const int MB = conf_.MB();
-    const int OD = conf_.OD();
-    const int OH = conf_.OH();
-    const int OW = conf_.OW();
-    const int ID = conf_.ID();
-    const int IH = conf_.IH();
-    const int IW = conf_.IW();
+    const int G = pd()->G();
+    const int MB = pd()->MB();
+    const int OD = pd()->OD();
+    const int OH = pd()->OH();
+    const int OW = pd()->OW();
+    const int ID = pd()->ID();
+    const int IH = pd()->IH();
+    const int IW = pd()->IW();
 
-    const int OC = conf_.OC() / G;
-    const int IC = conf_.IC() / G;
-    const int KD = conf_.KD();
-    const int KH = conf_.KH();
-    const int KW = conf_.KW();
+    const int OC = pd()->OC() / G;
+    const int IC = pd()->IC() / G;
+    const int KD = pd()->KD();
+    const int KH = pd()->KH();
+    const int KW = pd()->KW();
 
-    const int KSD = conf_.KSD();
-    const int KSH = conf_.KSH();
-    const int KSW = conf_.KSW();
+    const int KSD = pd()->KSD();
+    const int KSH = pd()->KSH();
+    const int KSW = pd()->KSW();
 
-    const int KDD = conf_.KDD();
-    const int KDH = conf_.KDH();
-    const int KDW = conf_.KDW();
+    const int KDD = pd()->KDD();
+    const int KDH = pd()->KDH();
+    const int KDW = pd()->KDW();
 
-    const int padFront = conf_.padFront();
-    const int padT = conf_.padT();
-    const int padL = conf_.padL();
+    const int padFront = pd()->padFront();
+    const int padT = pd()->padT();
+    const int padL = pd()->padL();
 
-    const int ndims = conf_.cdesc()->diff_src_desc.ndims;
+    const int ndims = pd()->desc()->diff_src_desc.ndims;
 
-    auto ker = [=](acc_data_t &d, int g, int mb, int ic, int id, int ih,
+    auto ker = [=](int g, int mb, int ic, int id, int ih,
             int iw) {
+        acc_data_t d = 0;
         for (int oc = 0; oc < OC; ++oc)
         for (int kd = 0; kd < KD; ++kd)
         for (int kh = 0; kh < KH; ++kh)
@@ -239,20 +230,9 @@ void ref_convolution_bwd_data_t<diff_src_type, wei_type, diff_dst_type,
                     assert(false);
             }
         }
+        return d;
     };
-    auto get_bias = [=, &bias](size_t off) -> acc_data_t {
-#       define CASE(dt) case dt: \
-            return (acc_data_t)(*((const prec_traits<dt>::type *)bias + off))
-        switch (conf_.desc()->bias_desc.data_type) {
-        CASE(data_type::s8);
-        CASE(data_type::u8);
-        CASE(data_type::s32);
-        CASE(data_type::f32);
-        default: assert(!"unimplemented");
-        }
-#       undef CASE
-        return 0;
-    };
+
     parallel_nd(G, MB, IC, ID, IH, IW,
         [&](int g, int mb, int ic, int id, int ih, int iw) {
         auto ds_idx = (ndims == 5)
@@ -260,10 +240,11 @@ void ref_convolution_bwd_data_t<diff_src_type, wei_type, diff_dst_type,
             : (ndims == 4)
             ? diff_src_d.off(mb, g*IC + ic, ih, iw)
             : diff_src_d.off(mb, g*IC + ic, iw);
-        acc_data_t a = bias
-            ? get_bias(bias_d.off(g*IC + ic))
-            : (acc_data_t)0;
-        ker(a, g, mb, ic, id, ih, iw);
+        float a = bias
+            ? get_bias(bias, bias_d.off(g * IC + ic),
+                    pd()->desc()->bias_desc.data_type)
+            : 0;
+        a += ker(g, mb, ic, id, ih, iw);
         diff_src[ds_idx] = saturate<diff_src_data_t>(a);
     });
 }
@@ -271,48 +252,48 @@ void ref_convolution_bwd_data_t<diff_src_type, wei_type, diff_dst_type,
 template <data_type_t src_type, data_type_t diff_wei_type,
          data_type_t diff_dst_type, data_type_t acc_type>
 void ref_convolution_bwd_weights_t<src_type, diff_wei_type, diff_dst_type,
-     acc_type>::execute_backward_weights() {
+     acc_type>::execute_backward_weights() const {
     auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
     auto diff_dst = reinterpret_cast<const diff_dst_data_t *>(
             this->input_memory(1));
     auto diff_weights = reinterpret_cast<diff_wei_data_t*>(this->memory(0));
     auto diff_bias = reinterpret_cast<diff_wei_data_t *>(this->memory(1));
 
-    const memory_desc_wrapper src_d(conf_.src_pd());
-    const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
-    const memory_desc_wrapper diff_weights_d(conf_.diff_weights_pd(0));
-    const memory_desc_wrapper diff_bias_d(conf_.diff_weights_pd(1));
+    const memory_desc_wrapper src_d(pd()->src_pd());
+    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
+    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_pd(0));
+    const memory_desc_wrapper diff_bias_d(pd()->diff_weights_pd(1));
 
-    const bool with_groups = conf_.with_groups();
+    const bool with_groups = pd()->with_groups();
 
-    const int G = conf_.G();
-    const int MB = conf_.MB();
-    const int OD = conf_.OD();
-    const int OH = conf_.OH();
-    const int OW = conf_.OW();
-    const int ID = conf_.ID();
-    const int IH = conf_.IH();
-    const int IW = conf_.IW();
+    const int G = pd()->G();
+    const int MB = pd()->MB();
+    const int OD = pd()->OD();
+    const int OH = pd()->OH();
+    const int OW = pd()->OW();
+    const int ID = pd()->ID();
+    const int IH = pd()->IH();
+    const int IW = pd()->IW();
 
-    const int OC = conf_.OC() / G;
-    const int IC = conf_.IC() / G;
-    const int KD = conf_.KD();
-    const int KH = conf_.KH();
-    const int KW = conf_.KW();
+    const int OC = pd()->OC() / G;
+    const int IC = pd()->IC() / G;
+    const int KD = pd()->KD();
+    const int KH = pd()->KH();
+    const int KW = pd()->KW();
 
-    const int KSD = conf_.KSD();
-    const int KSH = conf_.KSH();
-    const int KSW = conf_.KSW();
+    const int KSD = pd()->KSD();
+    const int KSH = pd()->KSH();
+    const int KSW = pd()->KSW();
 
-    const int KDD = conf_.KDD();
-    const int KDH = conf_.KDH();
-    const int KDW = conf_.KDW();
+    const int KDD = pd()->KDD();
+    const int KDH = pd()->KDH();
+    const int KDW = pd()->KDW();
 
-    const int padFront = conf_.padFront();
-    const int padT = conf_.padT();
-    const int padL = conf_.padL();
+    const int padFront = pd()->padFront();
+    const int padT = pd()->padT();
+    const int padL = pd()->padL();
 
-    const int ndims = conf_.cdesc()->src_desc.ndims;
+    const int ndims = pd()->desc()->src_desc.ndims;
 
 auto ker = [=](acc_data_t &d, int g, int oc, int ic, int kd, int kh, int kw) {
         for (int mb = 0; mb < MB; ++mb)
@@ -364,6 +345,7 @@ auto ker = [=](acc_data_t &d, int g, int oc, int ic, int kd, int kh, int kw) {
 
     parallel_nd(G, OC, [&](int g, int oc) {
         if (diff_bias) {
+            // XXX: loss of precision when bias is a float...
             acc_data_t db = 0;
             ker_bias(db, g, oc);
             diff_bias[diff_bias_d.off(g*OC+oc)]
@@ -401,19 +383,13 @@ auto ker = [=](acc_data_t &d, int g, int oc, int ic, int kd, int kh, int kw) {
 
 using namespace data_type;
 
-template struct _ref_convolution_fwd_t<false, f32>;
-template struct _ref_convolution_fwd_t<true, f32>;
-template struct _ref_convolution_fwd_t<false, s16, s16, s32, s32>;
-template struct _ref_convolution_fwd_t<true, s16, s16, s32, s32>;
-
-template struct _ref_convolution_fwd_t<false, u8, s8, f32, s32>;
-template struct _ref_convolution_fwd_t<true, u8, s8, f32, s32>;
-template struct _ref_convolution_fwd_t<false, u8, s8, s32, s32>;
-template struct _ref_convolution_fwd_t<true, u8, s8, s32, s32>;
-template struct _ref_convolution_fwd_t<false, u8, s8, s8, s32>;
-template struct _ref_convolution_fwd_t<true, u8, s8, s8, s32>;
-template struct _ref_convolution_fwd_t<false, u8, s8, u8, s32>;
-template struct _ref_convolution_fwd_t<true, u8, s8, u8, s32>;
+template struct ref_convolution_fwd_t<f32>;
+template struct ref_convolution_fwd_t<s16, s16, s32, s32>;
+
+template struct ref_convolution_fwd_t<u8, s8, f32, s32>;
+template struct ref_convolution_fwd_t<u8, s8, s32, s32>;
+template struct ref_convolution_fwd_t<u8, s8, s8, s32>;
+template struct ref_convolution_fwd_t<u8, s8, u8, s32>;
 
 template struct ref_convolution_bwd_data_t<f32, f32, f32, f32>;
 template struct ref_convolution_bwd_data_t<s32, s16, s16, s32>;