Publishing 2019 R1 content
[platform/upstream/dldt.git] inference-engine/thirdparty/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.cpp
index 30f1823..bdfee81 100644
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/
+
+#include <assert.h>
 #include <float.h>
+
 #include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+#include "mkldnn_thread.hpp"
 #include "nstl.hpp"
 #include "type_helpers.hpp"
-#include "mkldnn_thread.hpp"
 #include "utils.hpp"
+
 #include "cpu_memory.hpp"
+#include "cpu_barrier.hpp"
 
 #include "jit_uni_1x1_conv_utils.hpp"
 #include "jit_avx512_common_1x1_conv_kernel.hpp"
@@ -257,14 +263,23 @@ void jit_avx512_common_1x1_conv_kernel::reduce_loop(int load_loop_blk,
         int depthwise_inj_idx = 0;
         const auto &p = attr_.post_ops_;
 
-        if (p.len_ == 0 && eltwise_injectors.size() == 1) {
-            eltwise_injectors[0]->compute_vector_range(0, ur * load_loop_blk);
-        }
-
         for (int i = 0; i < p.len_; i++) {
             auto& post_op = p.entry_[i];
             if (post_op.is_eltwise()) {
-                eltwise_injectors[eltwise_inj_idx]->compute_vector_range(0, ur * load_loop_blk);
+                if (jcp.ver == ver_4vnni) {
+                    zmm_t zmm_zero = vreg_bcast;
+                    vpxord(zmm_zero, zmm_zero, zmm_zero);
+
+                    for (int i_ur = 0; i_ur < ur; ++i_ur) {
+                        for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
+                            Zmm zmm = vreg_accum(i_load, i_ur);
+                            vpcmpd(k1, zmm, zmm_zero, _cmp_lt_os);
+                            vpmulld(zmm | k1, zmm, zmm_zero);
+                        }
+                    }
+                } else {
+                    eltwise_injectors[eltwise_inj_idx]->compute_vector_range(0, ur * load_loop_blk);
+                }
                 eltwise_inj_idx++;
             } else if (post_op.is_depthwise()) {
                 mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
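The ver_4vnni branch added above bypasses the f32 eltwise injector: vpcmpd builds a mask k1 of accumulator lanes that are negative, and the masked vpmulld multiplies exactly those lanes by zero, i.e. it applies ReLU with a zero negative slope directly on the int32 accumulators. A scalar sketch of the per-lane effect (illustration only, not part of the patch):

    #include <cstdint>

    // Mirrors vpcmpd(k1, zmm, zmm_zero, _cmp_lt_os) followed by
    // vpmulld(zmm | k1, zmm, zmm_zero), applied to one int32 lane.
    static inline int32_t relu_s32(int32_t x) {
        const bool neg = x < 0;  // vpcmpd: lane enters mask k1 when x < 0
        return neg ? x * 0 : x;  // vpmulld under k1: masked lanes become 0
    }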
@@ -502,12 +517,6 @@ void jit_avx512_common_1x1_conv_kernel::reduce_loop(int load_loop_blk,
 
 void jit_avx512_common_1x1_conv_kernel::generate()
 {
-    if (jcp.with_eltwise) {
-        eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<avx512_common>(
-                this, jcp.eltwise_alg, jcp.eltwise_alpha, 0
-        ));
-    }
-
     const auto &p = attr_.post_ops_;
     for (int i = 0; i < p.len_; i++) {
         auto &post_op = p.entry_[i];
@@ -542,6 +551,8 @@ void jit_avx512_common_1x1_conv_kernel::generate()
     mov(EVEX_compress_addr(rsp, bcast_loop_work_offt), reg_bcast_loop_work);
     mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
     mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);
+    if (one_of(jcp.prop_kind, forward_training, forward_inference))
+        mov(reg_relu_ns, reinterpret_cast<size_t>(&jcp.eltwise.alpha));
     if (jcp.prop_kind == backward_weights)
         mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]);
     mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]);
@@ -653,30 +664,20 @@ bool jit_avx512_common_1x1_conv_kernel::post_ops_ok(
     auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
 
     switch (p.len_) {
-    case 0: return true; // no post_ops
-    case 1:
-        return true // sum OR eltwise OR depthwise
-                && !jcp.with_eltwise && (is_simple(0) || is_sum(0));
-    case 2:
-        return true // sum->relu
-                && !jcp.with_eltwise && ((is_sum(0) && is_simple(1)) ||
-                                         (is_simple(0) && is_simple(1)));
-    case 3:
-        return true // sum->relu
-                && !jcp.with_eltwise && (is_sum(0) && is_simple(1) && is_simple(2));
+    case 0: return true;
+    case 1: return is_simple(0) || is_sum(0);
+    case 2: return (is_sum(0) && is_simple(1)) || (is_simple(0) && is_simple(1));
+    case 3: return is_sum(0) && is_simple(1) && is_simple(2);
     default: return false;
     }
 
     return false;
 }
 
-status_t jit_avx512_common_1x1_conv_kernel::init_conf(
-        jit_1x1_conv_conf_t &jcp, const convolution_desc_t &cd,
-        const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
-        const memory_desc_wrapper &dst_d, const primitive_attr_t &attr,
-        bool with_relu, float relu_negative_slope,
-        int nthreads, bool reduce_src)
-{
+status_t jit_avx512_common_1x1_conv_kernel::init_conf(jit_1x1_conv_conf_t &jcp,
+        const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
+        const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
+        const primitive_attr_t &attr, int nthreads, bool reduce_src) {
     if (!mayiuse(avx512_common)) return status::unimplemented;
 
     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
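Both changes above stem from the same API shift: with the with_relu / relu_negative_slope parameters removed, a fused ReLU (or any other eltwise) reaches the kernel only as a post-op on the primitive attributes, and post_ops_ok() validates the chain uniformly. A hypothetical caller-side sketch using the mkl-dnn v0.x C++ API (the attr construction below is illustrative, not part of this patch):

    mkldnn::primitive_attr attr;
    mkldnn::post_ops po;
    po.append_sum(1.f);                              // matches is_sum(0)
    po.append_eltwise(1.f, mkldnn::algorithm::eltwise_relu,
            0.f /* alpha = negative slope */, 0.f /* beta */);  // is_simple(1)
    attr.set_post_ops(po);
    // p.len_ == 2 with is_sum(0) && is_simple(1) passes post_ops_ok();
    // a sum anywhere other than position 0 is rejected.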
@@ -715,11 +716,9 @@ status_t jit_avx512_common_1x1_conv_kernel::init_conf(
     jcp.stride_w = cd.strides[ndims - 3];
 
     jcp.src_fmt = src_d.format();
-    jcp.with_bias = one_of(jcp.prop_kind, forward_training, forward_inference)
-        ? cd.bias_desc.format != memory_format::undef : false;
-    jcp.with_eltwise = with_relu;
-    jcp.eltwise_alg = mkldnn_eltwise_relu;
-    jcp.eltwise_alpha = relu_negative_slope;
+    jcp.with_bias = pick_by_prop_kind(jcp.prop_kind, cd.bias_desc.format,
+            memory_format::undef, cd.diff_bias_desc.format)
+        != memory_format::undef;
 
     jcp.os = jcp.oh * jcp.ow;
     jcp.is = jcp.ih * jcp.iw;
@@ -730,6 +729,12 @@ status_t jit_avx512_common_1x1_conv_kernel::init_conf(
 
     const auto &p = attr.post_ops_;
     jcp.with_sum = p.find(primitive_kind::sum) != -1;
+    const int eltwise_ind = p.find(primitive_kind::eltwise);
+    jcp.with_eltwise = eltwise_ind != -1;
+    if (jcp.with_eltwise) {
+        jcp.eltwise = p.entry_[eltwise_ind].eltwise;
+        if (dst_d.data_type() == data_type::s32) return status::unimplemented;
+    }
 
     bool args_ok = true
         && jcp.ngroups == 1
@@ -894,9 +899,7 @@ status_t jit_avx512_common_1x1_conv_kernel::init_conf(
         } else {
             bool is4ops = (jcp.ver == ver_4fma || jcp.ver == ver_4vnni);
 
-//            max_regs = is4ops ? 28 : 30;
-            // FIXME (ichuraev): it is a fix for densnet-121
-            max_regs = 28;
+            max_regs = is4ops ? 28 : 30;
             min_regs = 9;
             size_treshold = is4ops ? 28 : 14;
             ur_step = is4ops ? 4 : 1;
@@ -1062,6 +1065,48 @@ status_t jit_avx512_common_1x1_conv_kernel::init_conf(
             load_blocking = jcp.load_block;
         }
 
+        if (jcp.ver == ver_4fma && jcp.bcast_dim * jcp.mb < jcp.load_dim
+                && jcp.oh * jcp.ow > 64
+                && IMPLICATION(reduce_src, jcp.load_dim < 1024)) {
+            /* Look for the best load-dimension blocking, i.e. the optimal
+             * 'load_chunk' value, to get the best thread utilization
+             * together with data read/write efficiency.
+             * Example: for 72 threads and a convolution with mb=1, ih=iw=7,
+             * oc=512, the best load_chunk value is 1.
+             * TODO: remove the heuristic constants in the condition above
+             * TODO: check this blocking for other ISAs
+             */
+            float best_eff = -1.f;
+            int best_lgc = 1;
+
+            for (int load_chunk = 1; load_chunk <= nb_load; load_chunk++) {
+                int lgc = div_up(nb_load, load_chunk);
+                if (lgc > nthreads)
+                    continue;
+                int thr_per_grp = div_up(nthreads, lgc);
+                int bcast_per_thr = div_up(jcp.mb * nb_bcast, thr_per_grp)
+                        * jcp.bcast_block;
+                int load_per_thr = load_chunk * simd_w;
+                float data_norm = (bcast_per_thr + load_per_thr) / 2.f;
+                float data_eff = (bcast_per_thr * load_per_thr)
+                        / (data_norm * data_norm);
+                float thr_eff_over_grp = (float)nstl::max(1, nthreads / lgc)
+                        / div_up(nthreads, lgc);
+                float thr_eff_in_grp = ((float)jcp.mb * nb_bcast)
+                        / rnd_up(jcp.mb * nb_bcast, thr_per_grp);
+                float thr_eff = thr_eff_over_grp * thr_eff_in_grp;
+                float load_eff = (float)nb_load / rnd_up(nb_load, lgc);
+                float overall_eff = data_eff + thr_eff + load_eff;
+                if (overall_eff > best_eff) {
+                    best_eff = overall_eff;
+                    best_lgc = lgc;
+                }
+            }
+            jcp.load_grp_count = best_lgc;
+            load_blocking
+                    = div_up(nb_load, jcp.load_grp_count) * jcp.load_block;
+        }
         bcast_blocking = div_up(jcp.mb * jcp.ngroups * nb_bcast,
                                  div_up(nthreads, jcp.load_grp_count))
                 * jcp.bcast_block;
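The score in the new search combines three terms: data_eff, the squared ratio of geometric to arithmetic mean of the two per-thread data volumes (maximal when bcast and load volumes are balanced); thr_eff, penalizing idle threads across and within load groups; and load_eff, penalizing a remainder when nb_load is split across groups. A standalone sketch of the search, runnable in isolation; the shape constants are assumed for illustration and loosely echo the 72-thread example from the comment:

    #include <algorithm>
    #include <cstdio>

    static int div_up(int a, int b) { return (a + b - 1) / b; }
    static int rnd_up(int a, int b) { return div_up(a, b) * b; }

    int main() {
        // Assumed shapes: 72 threads, mb = 1, oc = 512 -> nb_load = 512 / 16.
        // nb_bcast and bcast_block are placeholders for this illustration.
        const int nthreads = 72, mb = 1, simd_w = 16;
        const int nb_load = 32, nb_bcast = 4, bcast_block = 16;

        float best_eff = -1.f;
        int best_lgc = 1;
        for (int load_chunk = 1; load_chunk <= nb_load; ++load_chunk) {
            int lgc = div_up(nb_load, load_chunk);
            if (lgc > nthreads) continue;
            int thr_per_grp = div_up(nthreads, lgc);
            // Per-thread data volumes; data_eff <= 1, equal volumes give 1.
            int bcast_per_thr = div_up(mb * nb_bcast, thr_per_grp) * bcast_block;
            int load_per_thr = load_chunk * simd_w;
            float data_norm = (bcast_per_thr + load_per_thr) / 2.f;
            float data_eff = (bcast_per_thr * load_per_thr)
                    / (data_norm * data_norm);
            // Thread utilization across groups and within a group.
            float thr_eff = ((float)std::max(1, nthreads / lgc)
                            / div_up(nthreads, lgc))
                    * ((float)(mb * nb_bcast)
                            / rnd_up(mb * nb_bcast, thr_per_grp));
            float load_eff = (float)nb_load / rnd_up(nb_load, lgc);
            float overall_eff = data_eff + thr_eff + load_eff;
            if (overall_eff > best_eff) {
                best_eff = overall_eff;
                best_lgc = lgc;
            }
        }
        printf("best load_grp_count = %d\n", best_lgc);
        return 0;
    }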
@@ -1230,6 +1275,30 @@ status_t jit_avx512_common_1x1_conv_kernel::init_conf(
     return status::success;
 }
 
+void jit_avx512_common_1x1_conv_kernel::init_scratchpad(
+        memory_tracking::registrar_t &scratchpad,
+        const jit_1x1_conv_conf_t &jcp) {
+    using namespace mkldnn::impl::memory_tracking::names;
+
+    if (jcp.prop_kind != backward_data && jcp.with_bias
+            && jcp.oc != jcp.oc_without_padding)
+        scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc);
+
+    if (jcp.prop_kind == backward_weights) {
+        const size_t wei_size = (size_t)jcp.ngroups * jcp.oc * jcp.ic;
+        scratchpad.book(key_conv_wei_reduction,
+                jcp.typesize_out * wei_size * (jcp.nthr_mb - 1));
+    }
+
+    if (jcp.transpose_src) {
+        const size_t tr_src_size =
+            (size_t)jcp.nthr_mb * jcp.ngroups * jcp.ic * jcp.tr_is;
+        scratchpad.book(key_conv_tr_src, jcp.typesize_out * tr_src_size);
+        scratchpad.book(key_conv_tr_src_bctx,
+                sizeof(simple_barrier::ctx_t) * jcp.nthr);
+    }
+}
+
 void jit_avx512_common_1x1_conv_kernel::balance(jit_1x1_conv_conf_t &jcp,
         int nthreads)
 {
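A sizing sanity check for the new init_scratchpad(), with hypothetical numbers: for backward weights with ngroups = 1, oc = ic = 256, typesize_out = 4 (f32), and nthr_mb = 4, key_conv_wei_reduction books 4 * (1 * 256 * 256) * (4 - 1) = 786,432 bytes, i.e. one weights-sized reduction buffer for each minibatch thread beyond the first.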