Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / ref_pooling.cpp
index 4ee010d..d7ae208 100644 (file)
@@ -30,43 +30,39 @@ namespace impl {
 namespace cpu {
 
 template <data_type_t data_type, data_type_t acc_type>
-void ref_pooling_fwd_t<data_type, acc_type>::execute_forward() {
+void ref_pooling_fwd_t<data_type, acc_type>::execute_forward() const {
     using namespace alg_kind;
     using namespace prop_kind;
 
-    auto alg = conf_.desc()->alg_kind;
+    auto alg = pd()->desc()->alg_kind;
 
     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
     auto dst = reinterpret_cast<data_t *>(this->memory(0));
-    auto ws = alg == pooling_max && conf_.desc()->prop_kind == forward_training
+    auto ws = alg == pooling_max && pd()->desc()->prop_kind == forward_training
         ? reinterpret_cast<unsigned char *>(this->memory(1)) : nullptr;
 
-    const memory_desc_wrapper src_d(conf_.src_pd());
-    const memory_desc_wrapper dst_d(conf_.dst_pd());
-    const memory_desc_wrapper ws_d(conf_.workspace_pd());
+    const memory_desc_wrapper src_d(pd()->src_pd());
+    const memory_desc_wrapper dst_d(pd()->dst_pd());
+    const memory_desc_wrapper ws_d(pd()->workspace_pd());
     const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;
 
-    const int ID = conf_.ID();
-    const int IH = conf_.IH();
-    const int IW = conf_.IW();
-    const int KD = conf_.KD();
-    const int KH = conf_.KH();
-    const int KW = conf_.KW();
-    const int SD = conf_.KSD();
-    const int SH = conf_.KSH();
-    const int SW = conf_.KSW();
-    const int padF = conf_.padFront();
-    const int padT = conf_.padT();
-    const int padL = conf_.padL();
-    const int padBack = conf_.padBack();
-    const int padB = conf_.padB();
-    const int padR = conf_.padR();
-
-    const bool is_3d = conf_.desc()->src_desc.ndims == 5;
-
-//    auto apply_offset = [=](int index, int offset) {
-//        return (index > offset) ? index - offset : 0;
-//    };
+    const int ID = pd()->ID();
+    const int IH = pd()->IH();
+    const int IW = pd()->IW();
+    const int KD = pd()->KD();
+    const int KH = pd()->KH();
+    const int KW = pd()->KW();
+    const int SD = pd()->KSD();
+    const int SH = pd()->KSH();
+    const int SW = pd()->KSW();
+    const int padF = pd()->padFront();
+    const int padT = pd()->padT();
+    const int padL = pd()->padL();
+    const int padBack = pd()->padBack();
+    const int padB = pd()->padB();
+    const int padR = pd()->padR();
+
+    const bool is_3d = pd()->desc()->src_desc.ndims == 5;
 
     auto set_ws = [=](int mb, int oc, int od, int oh, int ow, int value) {
         if (ws) {
@@ -195,11 +191,11 @@ void ref_pooling_fwd_t<data_type, acc_type>::execute_forward() {
         d[0] = math::out_round<data_t>((float)dst / num_summands);
     };
 
-    const int MB = conf_.MB();
-    const int OC = conf_.C();
-    const int OD = conf_.OD();
-    const int OH = conf_.OH();
-    const int OW = conf_.OW();
+    const int MB = pd()->MB();
+    const int OC = pd()->C();
+    const int OD = pd()->OD();
+    const int OH = pd()->OH();
+    const int OW = pd()->OW();
 
     if (alg == pooling_max) {
         parallel_nd(MB, OC, OD, OH, OW,
@@ -226,34 +222,34 @@ void ref_pooling_fwd_t<data_type, acc_type>::execute_forward() {
 }
 
 template <data_type_t data_type, data_type_t acc_type>
-void ref_pooling_bwd_t<data_type, acc_type>::execute_backward() {
+void ref_pooling_bwd_t<data_type, acc_type>::execute_backward() const {
     using namespace alg_kind;
 
     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
-    auto ws = conf_.desc()->alg_kind != alg_kind::pooling_max ? nullptr
+    auto ws = pd()->desc()->alg_kind != alg_kind::pooling_max ? nullptr
         : reinterpret_cast<const unsigned char *>(this->input_memory(1));
     auto diff_src = reinterpret_cast<data_t *>(this->memory(0));
 
-    const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
-    const memory_desc_wrapper ws_d(conf_.workspace_pd());
-    const memory_desc_wrapper diff_src_d(conf_.diff_src_pd());
+    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
+    const memory_desc_wrapper ws_d(pd()->workspace_pd());
+    const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
 
-    const int ID = conf_.ID();
-    const int IH = conf_.IH();
-    const int IW = conf_.IW();
-    const int KD = conf_.KD();
-    const int KH = conf_.KH();
-    const int KW = conf_.KW();
-    const int SD = conf_.KSD();
-    const int SH = conf_.KSH();
-    const int SW = conf_.KSW();
-    const int padF = conf_.padFront();
-    const int padT = conf_.padT();
-    const int padL = conf_.padL();
+    const int ID = pd()->ID();
+    const int IH = pd()->IH();
+    const int IW = pd()->IW();
+    const int KD = pd()->KD();
+    const int KH = pd()->KH();
+    const int KW = pd()->KW();
+    const int SD = pd()->KSD();
+    const int SH = pd()->KSH();
+    const int SW = pd()->KSW();
+    const int padF = pd()->padFront();
+    const int padT = pd()->padT();
+    const int padL = pd()->padL();
 
-    const bool is_3d = conf_.desc()->diff_src_desc.ndims == 5;
+    const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;
 
-    auto alg = conf_.desc()->alg_kind;
+    auto alg = pd()->desc()->alg_kind;
 
     auto apply_offset = [=](int index, int offset) {
         return (index > offset) ? index - offset : 0;
@@ -360,13 +356,13 @@ void ref_pooling_bwd_t<data_type, acc_type>::execute_backward() {
         }
     };
 
-    const int MB = conf_.MB();
-    const int OC = conf_.C();
-    const int OD = conf_.OD();
-    const int OH = conf_.OH();
-    const int OW = conf_.OW();
+    const int MB = pd()->MB();
+    const int OC = pd()->C();
+    const int OD = pd()->OD();
+    const int OH = pd()->OH();
+    const int OW = pd()->OW();
 
-    if (conf_.desc()->alg_kind == alg_kind::pooling_max) {
+    if (pd()->desc()->alg_kind == alg_kind::pooling_max) {
         parallel_nd(MB, OC, [&](int mb, int oc) {
             if (is_3d) ker_zero_3d(mb, oc);
             else ker_zero(mb, oc);