Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_core_x8s8s32x_convolution.cpp
index 8d1297f..e5cdcb1 100644 (file)
@@ -14,7 +14,6 @@
 * limitations under the License.
 *******************************************************************************/
 
-#include "mkldnn_types.h"
 #include "c_types_map.hpp"
 #include "mkldnn_thread.hpp"
 #include "type_helpers.hpp"
@@ -28,6 +27,7 @@ namespace cpu {
 
 using namespace mkldnn::impl::status;
 using namespace mkldnn::impl::memory_format;
+using namespace mkldnn::impl::memory_tracking::names;
 using namespace mkldnn::impl::utils;
 
 using namespace nstl;
@@ -35,37 +35,52 @@ using namespace nstl;
 using jit_conv_ker_t = void (*)(jit_conv_call_s *);
 
 #define wht_blk_off(d, g, ...) \
-        (conf_.with_groups() \
+        (pd()->with_groups() \
          ? (d).blk_off((g), __VA_ARGS__) \
          : (d).blk_off(__VA_ARGS__))
 
-template <bool with_relu, data_type_t src_type, data_type_t dst_type>
-void _jit_avx512_core_x8s8s32x_convolution_fwd_t<with_relu, src_type, dst_type>::
-execute_forward()
+template <data_type_t src_type, data_type_t dst_type>
+void jit_avx512_core_x8s8s32x_convolution_fwd_t<src_type, dst_type>::
+execute_forward() const
 {
     auto src = reinterpret_cast<const src_data_t *>(this->input_memory(0));
     auto weights = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
     auto bias = reinterpret_cast<const char *>(this->input_memory(2));
     auto dst = reinterpret_cast<dst_data_t *>(this->memory());
 
-    const memory_desc_wrapper src_d(conf_.src_pd());
-    const memory_desc_wrapper dst_d(conf_.dst_pd());
-    const memory_desc_wrapper weights_d(conf_.weights_pd(0));
-    const memory_desc_wrapper bias_d(conf_.weights_pd(1));
+    const memory_desc_wrapper src_d(pd()->src_pd());
+    const memory_desc_wrapper dst_d(pd()->dst_pd());
+    const memory_desc_wrapper weights_d(pd()->weights_pd(0));
+    const memory_desc_wrapper bias_d(pd()->weights_pd(1));
 
-    const size_t bia_dt_size = conf_.with_bias()
-        ? types::data_type_size(conf_.cdesc()->bias_desc.data_type) : 0;
+    const size_t bia_dt_size = pd()->with_bias()
+        ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0;
 
-    const auto &jcp = kernel_->jcp;
+    const auto &jcp = pd()->jcp_;
     assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
+    assert(jcp.nb_ch % jcp.nb_ch_blocking == 0);
+
+    const float *oscales = pd()->attr()->output_scales_.scales_;
+    if (jcp.signed_input && jcp.ver != ver_vnni) {
+        auto local_scales = scratchpad().template get<float>(
+                key_conv_adjusted_scales);
+        size_t count = pd()->attr()->output_scales_.count_;
+        float factor = 1.f / pd()->jcp_.wei_adj_scale;
+        if (count == 1) {
+            utils::array_set(local_scales, oscales[0] * factor, 16);
+        } else {
+            for (size_t c = 0; c < count; c++)
+                local_scales[c] = oscales[c] * factor;
+        }
+        oscales = local_scales;
+    }
 
-    size_t offset = (size_t)jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw;
+    size_t offset = weights_d.size() - weights_d.additional_buffer_size();
     auto w = const_cast<wei_data_t *>(weights);
     int32_t* compensation = (jcp.signed_input)
                                 ? reinterpret_cast<int32_t *>(&w[offset]) : 0;
-    const auto &oscales = conf_.attr()->output_scales_;
     int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
-    int nb_groups = jcp.nb_ch;
+    int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking;
     int group_block = jcp.ch_block;
     int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh * jcp.nb_ow;
 
@@ -80,20 +95,24 @@ execute_forward()
         size_t dst_h_stride = dst_d.blk_off(0, 0, 1);
         size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
 
-        int n{ 0 }, gb{ 0 }, occ{ 0 }, oh_s{ 0 }, owb{ 0 };
+        int n{ 0 }, gg{ 0 }, occ{ 0 }, oh_s{ 0 }, owb{ 0 };
         if (jcp.loop_order == loop_cwgn)
-            nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gb,
+            nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gg,
                     nb_groups, n, jcp.mb, oh_s, jcp.oh);
         else if (jcp.loop_order == loop_gncw)
-            nd_iterator_init(start, gb, nb_groups, n, jcp.mb, occ, oc_chunks,
+            nd_iterator_init(start, gg, nb_groups, n, jcp.mb, occ, oc_chunks,
                     owb, jcp.nb_ow, oh_s, jcp.oh);
         else if (jcp.loop_order == loop_ngcw)
-            nd_iterator_init(start, n, jcp.mb, gb, nb_groups, occ, oc_chunks,
+            nd_iterator_init(start, n, jcp.mb, gg, nb_groups, occ, oc_chunks,
                     owb, jcp.nb_ow, oh_s, jcp.oh);
+        else if (jcp.loop_order == loop_nhwcg)
+            nd_iterator_init(start, n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow,
+                    occ, oc_chunks, gg, nb_groups);
         else
             assert(!"unsupported loop order");
         while (start < end) {
             int ocb = occ * jcp.nb_oc_blocking;
+            int gb = gg * jcp.nb_ch_blocking;
             int g = gb * group_block;
             int g_oc = (g * jcp.nb_oc + ocb) * jcp.oc_block;
 
@@ -102,6 +121,7 @@ execute_forward()
             int work_rem = end - start;
             int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
             int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
+            if (jcp.loop_order == loop_nhwcg) oh_e = oh_s + 1; // step instead
             int ow_s = owb * jcp.ow_block;
             int iw_s = ow_s * jcp.stride_w;
 
@@ -115,9 +135,7 @@ execute_forward()
             auto src_w = src + src_d.blk_off(n, g_ic, ih_s, iw_s);
             auto wht_w = weights + wht_blk_off(weights_d, gb, ocb, 0);
 
-            auto scales = (jcp.signed_input && jcp.ver != ver_vnni)
-                ? &local_scales_[jcp.is_oc_scale * g_oc]
-                : &oscales.scales_[jcp.is_oc_scale * g_oc];
+            auto scales = &oscales[jcp.is_oc_scale * g_oc];
 
             for (int oj = oh_s, ij = ih_s; oj < oh_e;
                 ++oj, ij += jcp.stride_h) {
@@ -144,57 +162,48 @@ execute_forward()
                 p.b_overflow = i_b_overflow;
                 p.owb = owb;
 
+                p.oc_off = g_oc * sizeof(float);
+
                 kernel_->jit_ker(&p);
 
                 src_w += src_h_stride * jcp.stride_h;
                 dst_w += dst_h_stride;
             }
             if (jcp.loop_order == loop_cwgn)
-                nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow, gb,
+                nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow, gg,
                         nb_groups, n, jcp.mb, oh_s, jcp.oh);
             else if (jcp.loop_order == loop_gncw)
-                nd_iterator_jump(start, end, gb, nb_groups, n, jcp.mb, occ,
+                nd_iterator_jump(start, end, gg, nb_groups, n, jcp.mb, occ,
                         oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh);
             else if (jcp.loop_order == loop_ngcw)
-                nd_iterator_jump(start, end, n, jcp.mb, gb, nb_groups, occ,
+                nd_iterator_jump(start, end, n, jcp.mb, gg, nb_groups, occ,
                         oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh);
+            else if (jcp.loop_order == loop_nhwcg) {
+                ++start;
+                nd_iterator_step(n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow, occ,
+                        oc_chunks, gg, nb_groups);
+            }
             else
                 assert(!"unsupported loop order");
         }
     });
 }
 
-template struct _jit_avx512_core_x8s8s32x_convolution_fwd_t<false,
+template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
                                                 data_type::s8, data_type::u8>;
-template struct _jit_avx512_core_x8s8s32x_convolution_fwd_t<true,
-                                                data_type::s8, data_type::u8>;
-template struct _jit_avx512_core_x8s8s32x_convolution_fwd_t<false,
-                                                data_type::u8, data_type::u8>;
-template struct _jit_avx512_core_x8s8s32x_convolution_fwd_t<true,
+template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
                                                 data_type::u8, data_type::u8>;
-template struct _jit_avx512_core_x8s8s32x_convolution_fwd_t<false,
-                                                data_type::s8, data_type::s8>;
-template struct _jit_avx512_core_x8s8s32x_convolution_fwd_t<true,
+template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
                                                 data_type::s8, data_type::s8>;
-template struct _jit_avx512_core_x8s8s32x_convolution_fwd_t<false,
+template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
                                                 data_type::u8, data_type::s8>;
-template struct _jit_avx512_core_x8s8s32x_convolution_fwd_t<true,
-                                                data_type::u8, data_type::s8>;
-template struct _jit_avx512_core_x8s8s32x_convolution_fwd_t<false,
-                                                data_type::s8, data_type::s32>;
-template struct _jit_avx512_core_x8s8s32x_convolution_fwd_t<true,
+template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
                                                 data_type::s8, data_type::s32>;
-template struct _jit_avx512_core_x8s8s32x_convolution_fwd_t<false,
-                                                data_type::u8, data_type::s32>;
-template struct _jit_avx512_core_x8s8s32x_convolution_fwd_t<true,
+template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
                                                 data_type::u8, data_type::s32>;
-template struct _jit_avx512_core_x8s8s32x_convolution_fwd_t<false,
+template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
                                                 data_type::s8, data_type::f32>;
-template struct _jit_avx512_core_x8s8s32x_convolution_fwd_t<true,
-                                                data_type::s8, data_type::f32>;
-template struct _jit_avx512_core_x8s8s32x_convolution_fwd_t<false,
-                                                data_type::u8, data_type::f32>;
-template struct _jit_avx512_core_x8s8s32x_convolution_fwd_t<true,
+template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
                                                 data_type::u8, data_type::f32>;
 }
 }