Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm_inner_product.cpp
index d9a8fe5..7f62c6b 100644 (file)
@@ -31,20 +31,20 @@ using namespace mkldnn::impl::memory_format;
 using namespace mkldnn::impl::primitive_kind;
 
 template <impl::data_type_t data_type>
-void gemm_inner_product_fwd_t<data_type>::execute_forward() {
+void gemm_inner_product_fwd_t<data_type>::execute_forward() const {
     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
     auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
     auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
     auto dst = reinterpret_cast<data_t*>(this->memory());
 
-    const int MB = conf_.MB();
-    const int OC = conf_.OC();
-    const int IC = conf_.IC_total_padded();
+    const int MB = pd()->MB();
+    const int OC = pd()->OC();
+    const int IC = pd()->IC_total_padded();
 
-    bool wei_tr = !utils::one_of(conf_.weights_pd()->desc()->format,
+    bool wei_tr = !utils::one_of(pd()->weights_pd()->desc()->format,
              hwio, dhwio, io);
 
-    const auto &post_ops = conf_.attr()->post_ops_;
+    const auto &post_ops = pd()->attr()->post_ops_;
     const bool do_relu = post_ops.len_ == 1;
 
     float alpha = 1.0, beta = 0.0;
@@ -62,16 +62,16 @@ void gemm_inner_product_fwd_t<data_type>::execute_forward() {
 }
 
 template <impl::data_type_t data_type>
-void gemm_inner_product_bwd_data_t<data_type>::execute_backward_data() {
+void gemm_inner_product_bwd_data_t<data_type>::execute_backward_data() const {
     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
     auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
     auto diff_src = reinterpret_cast<data_t*>(this->memory());
 
-    const int MB = conf_.MB();
-    const int OC = conf_.OC();
-    const int IC = conf_.IC_total_padded();
+    const int MB = pd()->MB();
+    const int OC = pd()->OC();
+    const int IC = pd()->IC_total_padded();
 
-    bool wei_tr = utils::one_of(conf_.weights_pd()->desc()->format,
+    bool wei_tr = utils::one_of(pd()->weights_pd()->desc()->format,
              hwio, dhwio, io);
 
     float alpha = 1.0, beta = 0.0;
@@ -80,22 +80,22 @@ void gemm_inner_product_bwd_data_t<data_type>::execute_backward_data() {
 }
 
 template <impl::data_type_t data_type>
-void gemm_inner_product_bwd_weights_t<data_type>::execute_backward_weights() {
+void gemm_inner_product_bwd_weights_t<data_type>::execute_backward_weights() const {
     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
     auto diff_weights = reinterpret_cast<data_t *>(this->memory(0));
     auto diff_bias = reinterpret_cast<data_t *>(this->memory(1));
 
-    const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
-    const memory_desc_wrapper diff_bias_d(conf_.diff_weights_pd(1));
+    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
+    const memory_desc_wrapper diff_bias_d(pd()->diff_weights_pd(1));
 
     diff_dst += diff_dst_d.blocking_desc().offset_padding;
 
-    const int MB = conf_.MB();
-    const int OC = conf_.OC();
-    const int IC = conf_.IC_total_padded();
+    const int MB = pd()->MB();
+    const int OC = pd()->OC();
+    const int IC = pd()->IC_total_padded();
 
-    bool wei_tr = utils::one_of(conf_.diff_weights_pd()->desc()->format,
+    bool wei_tr = utils::one_of(pd()->diff_weights_pd()->desc()->format,
              hwio, dhwio, io);
 
     float alpha = 1.0, beta = 0.0;