Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_core_x8s8s32x_1x1_convolution.cpp
index a71f285..1bab22e 100644 (file)
 * limitations under the License.
 *******************************************************************************/
 
-#include "mkldnn_types.h"
-
 #include "c_types_map.hpp"
-#include "utils.hpp"
 #include "mkldnn_thread.hpp"
 #include "type_helpers.hpp"
+#include "utils.hpp"
+
 #include "jit_generator.hpp"
 
 #include "jit_avx512_core_x8s8s32x_1x1_convolution.hpp"
@@ -30,6 +29,7 @@ namespace cpu {
 
 using namespace mkldnn::impl::status;
 using namespace mkldnn::impl::memory_format;
+using namespace mkldnn::impl::memory_tracking::names;
 using namespace mkldnn::impl::utils;
 
 namespace {
@@ -56,41 +56,61 @@ void balance2D(U nthr, U ithr, T ny, T &ny_start, T &ny_end,
 }
 
 /* convolution forward */
-template <bool with_relu, data_type_t src_type, data_type_t dst_type>
-void _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t
-                              <with_relu, src_type, dst_type>::execute_forward()
+template <data_type_t src_type, data_type_t dst_type>
+void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t
+                              <src_type, dst_type>::execute_forward() const
 {
     auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
     auto weights =
         reinterpret_cast<const wei_data_t *>(this->input_memory(1));
     auto bias = reinterpret_cast<const char *>(this->input_memory(2));
     auto dst = reinterpret_cast<dst_data_t *>(this->memory());
+
+    auto scratchpad = this->scratchpad();
+
+    if (pd()->jcp_.signed_input && pd()->jcp_.ver != ver_vnni) {
+        auto local_scales = scratchpad.template get<float>(
+                key_conv_adjusted_scales);
+        auto scales = pd()->attr()->output_scales_.scales_;
+        size_t count = pd()->attr()->output_scales_.count_;
+        float factor = 1.f / pd()->jcp_.wei_adj_scale;
+        if (count == 1) {
+            utils::array_set(local_scales, scales[0] * factor, 16);
+        } else {
+            for (size_t c = 0; c < count; c++)
+                local_scales[c] = scales[c] * factor;
+        }
+    }
+
     parallel(kernel_->jcp.nthr, [&](const int ithr, const int nthr) {
-        execute_forward_thr(ithr, nthr, src, weights, bias, dst);
+        execute_forward_thr(ithr, nthr, src, weights, bias, dst, scratchpad);
     });
 }
 
-template <bool with_relu, data_type_t src_type, data_type_t dst_type>
-void _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<with_relu, src_type, dst_type>
+template <data_type_t src_type, data_type_t dst_type>
+void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<src_type, dst_type>
 ::execute_forward_thr(const int ithr, const int nthr, const src_data_t *src,
-        const wei_data_t *weights, const char *bias, dst_data_t *dst) {
-    const memory_desc_wrapper src_d(conf_.src_pd());
-    const memory_desc_wrapper dst_d(conf_.dst_pd());
-    const memory_desc_wrapper weights_d(conf_.weights_pd(0));
+        const wei_data_t *weights, const char *bias, dst_data_t *dst,
+        const memory_tracking::grantor_t &scratchpad) const {
+    const memory_desc_wrapper src_d(pd()->src_pd());
+    const memory_desc_wrapper dst_d(pd()->dst_pd());
+    const memory_desc_wrapper weights_d(pd()->weights_pd(0));
 
-    const size_t bia_dt_size = conf_.with_bias()
-        ? types::data_type_size(conf_.cdesc()->bias_desc.data_type) : 0;
+    const size_t bia_dt_size = pd()->with_bias()
+        ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0;
 
     const auto &jcp = kernel_->jcp;
+    auto rtus_space = scratchpad.get<src_data_t>(key_conv_rtus_space);
+    auto local_scales = scratchpad.get<float>(key_conv_adjusted_scales);
 
     const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;
 
-    const int stride_h = conf_.cdesc()->strides[0];
-    const int stride_w = conf_.cdesc()->strides[1];
-    const int pad_t = conf_.cdesc()->padding[0][0];
-    const int pad_l = conf_.cdesc()->padding[0][1];
+    const int stride_h = pd()->desc()->strides[0];
+    const int stride_w = pd()->desc()->strides[1];
+    const int pad_t = pd()->desc()->padding[0][0];
+    const int pad_l = pd()->desc()->padding[0][1];
 
-    const auto &oscales = conf_.attr()->output_scales_;
+    const auto &oscales = pd()->attr()->output_scales_;
 
     int offset = jcp.ngroups * (jcp.oc / jcp.oc_block) * (jcp.ic / jcp.ic_block)
         * jcp.oc_block * jcp.ic_block;
@@ -167,17 +187,17 @@ void _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<with_relu, src_type, dst_ty
         const size_t dst_off = dst_d.blk_off(n, _ocb * jcp.oc_block, oh, ow);
 
         p.output_data = &dst[dst_off];
-        p.load_data = &weights[conf_.with_groups()
+        p.load_data = &weights[pd()->with_groups()
             ? weights_d.blk_off(g, ocb, icb)
             : weights_d.blk_off(ocb, icb)];
         p.bias_data = &bias[_ocb * jcp.oc_block * bia_dt_size];
         p.compensation = (jcp.signed_input)
             ? &compensation[_ocb * jcp.oc_block] : 0;
         p.scales = (jcp.signed_input && jcp.ver != ver_vnni)
-            ? &local_scales_[jcp.is_oc_scale * _ocb * jcp.oc_block]
+            ? &local_scales[jcp.is_oc_scale * _ocb * jcp.oc_block]
             : &oscales.scales_[jcp.is_oc_scale * _ocb * jcp.oc_block];
-        if (conf_.rtus_.reduce_src_) {
-            rp.ws = scratch_ + ithr * ws_per_thread_
+        if (pd()->rtus_.reduce_src_) {
+            rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_
                 + _icb * jcp.is * jcp.ic_block;
             if (ocb == ocb_start) {
                 rp.src = src + src_d.blk_off(n, _icb * jcp.ic_block, ih, iw);
@@ -187,6 +207,8 @@ void _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<with_relu, src_type, dst_ty
         } else
             p.bcast_data = src + src_d.blk_off(n, _icb * jcp.ic_block, ih, iw);
 
+        p.oc_off = _ocb * jcp.oc_block * sizeof(float);
+
         kernel_->jit_ker(&p);
     };
 
@@ -255,38 +277,16 @@ void _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<with_relu, src_type, dst_ty
     }
 }
 
-template struct _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<false,
-                                                  data_type::u8, data_type::u8>;
-template struct _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<true,
-                                                  data_type::u8, data_type::u8>;
-template struct _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<false,
-                                                  data_type::s8, data_type::u8>;
-template struct _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<true,
-                                                  data_type::s8, data_type::u8>;
-template struct _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<false,
-                                                  data_type::u8, data_type::s8>;
-template struct _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<true,
-                                                  data_type::u8, data_type::s8>;
-template struct _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<false,
-                                                  data_type::s8, data_type::s8>;
-template struct _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<true,
-                                                  data_type::s8, data_type::s8>;
-template struct _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<false,
-                                                 data_type::u8, data_type::s32>;
-template struct _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<true,
-                                                 data_type::u8, data_type::s32>;
-template struct _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<false,
-                                                 data_type::s8, data_type::s32>;
-template struct _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<true,
-                                                 data_type::s8, data_type::s32>;
-template struct _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<false,
-                                                 data_type::u8, data_type::f32>;
-template struct _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<true,
-                                                 data_type::u8, data_type::f32>;
-template struct _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<false,
-                                                 data_type::s8, data_type::f32>;
-template struct _jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<true,
-                                                 data_type::s8, data_type::f32>;
+using namespace data_type;
+template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, u8>;
+template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, u8>;
+template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, s8>;
+template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, s8>;
+template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, s32>;
+template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, s32>;
+template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, f32>;
+template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, f32>;
+
 }
 }
 }