Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_core_fp32_wino_conv_2x3.cpp
index 1239186..82a18b6 100644 (file)
@@ -32,6 +32,7 @@ namespace impl {
 namespace cpu {
 
 using namespace mkldnn::impl::memory_format;
+using namespace mkldnn::impl::memory_tracking::names;
 using namespace mkldnn::impl::utils;
 using namespace Xbyak;
 
@@ -247,7 +248,6 @@ bool jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t::maybe_relu(int position) {
     if (position == 0) {
         /* relu before sum */
         return false
-            || jcp.with_relu
             || p.contain(eltwise, 0);
     } else if (position == 1) {
         /* relu after sum */
@@ -411,7 +411,6 @@ struct jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t: public jit_generator {
             cpu_memory_t::pd_t &src_pd, cpu_memory_t::pd_t &weights_pd,
             cpu_memory_t::pd_t &dst_pd, cpu_memory_t::pd_t &bias_pd,
             const primitive_attr_t &attr,
-            bool with_relu, float relu_negative_slope,
             memory_desc_t& expect_wei_md);
 
     Zmm vreg_out(int n, int m) {
@@ -448,26 +447,14 @@ bool jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t::post_ops_ok(
     using namespace primitive_kind;
     const auto &p = attr.post_ops_;
 
-    auto is_relu = [&](int idx) {
-        return p.entry_[idx].kind == eltwise
-            && p.entry_[idx].eltwise.scale == 1.
-            && p.entry_[idx].eltwise.alg == alg_kind::eltwise_relu
-            && p.entry_[idx].eltwise.alpha == 0.;
-    };
+    auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); };
 
-   switch (p.len_) {
+    switch (p.len_) {
     case 0: return true;
-    case 1: return true
-                && IMPLICATION(jcp.with_relu, p.contain(sum, 0))
-                && IMPLICATION(!jcp.with_relu, is_relu(0) || p.contain(sum, 0));
-    case 2: return true
-                && IMPLICATION(jcp.with_relu, p.contain(sum, 0) && is_relu(1))
-                && IMPLICATION(!jcp.with_relu, false
-                        || (p.contain(sum, 0) && is_relu(1))
-                        || (p.contain(sum, 1) && is_relu(0)));
-    case 3: return true
-                && jcp.with_relu == false
-                && (is_relu(0) && p.contain(sum, 1) && is_relu(2));
+    case 1: return is_relu(0) || p.contain(sum, 0);
+    case 2: return (p.contain(sum, 0) && is_relu(1)) ||
+                       (p.contain(sum, 1) && is_relu(0));
+    case 3: return is_relu(0) && p.contain(sum, 1) && is_relu(2);
     default: return false;
     }
 
@@ -577,12 +564,17 @@ void jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t::generate() {
     postamble();
 }
 
+namespace {
+bool is_winograd_faster_than_direct(const jit_conv_conf_2x3_wino_t &jcp) {
+    return jcp.mb >= 4;
+}
+}
+
 status_t jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t ::init_conf(
         jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd,
         cpu_memory_t::pd_t &src_pd, cpu_memory_t::pd_t &wei_pd,
         cpu_memory_t::pd_t &dst_pd, cpu_memory_t::pd_t &bias_pd,
-        const primitive_attr_t &attr, bool with_relu, float relu_negative_slope,
-        memory_desc_t &expect_wei_md) {
+        const primitive_attr_t &attr, memory_desc_t &expect_wei_md) {
     const memory_desc_wrapper src_d(&src_pd);
     const memory_desc_wrapper wei_d(&wei_pd);
     const memory_desc_wrapper dst_d(&dst_pd);
@@ -590,6 +582,8 @@ status_t jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t ::init_conf(
 
     const bool with_groups = wei_d.ndims() == src_d.ndims() + 1;
 
+    jcp.nthr = mkldnn_get_max_threads();
+
     jcp.ngroups = with_groups ? wei_d.dims()[0] : 1;
     jcp.mb = src_d.dims()[0];
     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
@@ -616,10 +610,7 @@ status_t jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t ::init_conf(
     int simdw = 16;
     jcp.src_fmt = src_d.format();
     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
-    jcp.with_relu = with_relu;
-    jcp.relu_negative_slope = relu_negative_slope;
-    if (!IMPLICATION(with_relu, relu_negative_slope == 0.))
-        return status::unimplemented;
+
     if (!post_ops_ok(jcp, attr))
         return status::unimplemented;
 
@@ -639,6 +630,10 @@ status_t jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t ::init_conf(
     if (!(mayiuse(avx512_core)))
         return status::unimplemented;
 
+    if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
+               is_winograd_faster_than_direct(jcp)))
+        return status::unimplemented;
+
     if (src_d.data_type() != data_type::f32)
         return status::unimplemented;
     if (wei_d.data_type() != data_type::f32)
@@ -673,7 +668,6 @@ status_t jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t ::init_conf(
     auto wei_sz = (float)aa * ic * oc;
     auto inp_sz = (float)mb * ih * iw * ic;
     auto sp_sz = (float)mb * ih * iw;
-    const int nthr = mkldnn_get_max_threads();
 
     /* Heuristics here. Numbers '28','196' is an observation from data. */
     if (wei_sz / inp_sz > 5)
@@ -681,10 +675,10 @@ status_t jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t ::init_conf(
     else
         jcp.small_mb = false;
 
-    if (mb > nstl::min(nthr, 28)
+    if (mb > nstl::min(jcp.nthr, 28)
         || (!jcp.small_mb
             && (wei_sz >= 0.9f * L2_cap
-                || inp_sz > L2_cap * nthr + L3_capacity))
+                || inp_sz > L2_cap * jcp.nthr + L3_capacity))
         || (jcp.small_mb && sp_sz > 196))
         return unimplemented;
 
@@ -749,7 +743,7 @@ status_t jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t ::init_conf(
 
             /* outer parallelization */
             int nblocks = mb * div_up(ih, iy) * div_up(iw, ix);
-            thr_eff = (float)nblocks / rnd_up(nblocks, nthr);
+            thr_eff = (float)nblocks / rnd_up(nblocks, jcp.nthr);
 
             mem_eff = 1.f;
             req_mem = (((float)ix + 2) * (iy + 2) + aa * M) * Z + aa * Y;
@@ -765,14 +759,15 @@ status_t jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t ::init_conf(
             /* inner parallelization */
             int bsz = iy * ix / a;
             int gemmw = aa * (nb_oc / n2_b);
-            int bsz_r = rnd_up(bsz, nthr);
-            int gemmw_r = rnd_up(gemmw, nthr);
+            int bsz_r = rnd_up(bsz, jcp.nthr);
+            int gemmw_r = rnd_up(gemmw, jcp.nthr);
             thr_eff = ((float)Z * bsz / bsz_r + Y * gemmw / gemmw_r) / (Z + Y);
 
             req_mem = (float)ix * iy * (ic + simdw * n2_b) + simdw * n2_b * ic;
             mem_eff = nstl::min(1.f, L2_cap / req_mem);
-            int M_per_thr = nstl::max(2, div_up(aa, nthr));
-            int oc_per_thr = nstl::min(oc, div_up(aa * (nb_oc / n2_b), nthr));
+            int M_per_thr = nstl::max(2, div_up(aa, jcp.nthr));
+            int oc_per_thr =
+                nstl::min(oc, div_up(aa * (nb_oc / n2_b), jcp.nthr));
             req_mem = (float)aa * oc_per_thr * ic + M_per_thr * M * Z;
             if (req_mem > L2_cap)
                 mem_eff = 0.1f;
@@ -839,63 +834,34 @@ status_t jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t ::init_conf(
 }
 ////////////////////////////////////////////////////////////////////////////////
 
-template <bool with_relu>
-status_t _jit_avx512_core_fp32_wino_conv_2x3_fwd_t<with_relu>
+status_t jit_avx512_core_fp32_wino_conv_2x3_fwd_t
     ::pd_t::jit_conf(memory_desc_t& expect_wei_md) {
     return jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t::init_conf(
-            jcp_, this->cdesc_(), this->src_pd_, this->weights_pd_,
-            this->dst_pd_,this->bias_pd_, *this->attr(),
-            with_relu, this->negative_slope(), expect_wei_md);
+            jcp_, *this->desc(), this->src_pd_, this->weights_pd_,
+            this->dst_pd_,this->bias_pd_, *this->attr(), expect_wei_md);
 }
 
-template <bool with_relu>
-_jit_avx512_core_fp32_wino_conv_2x3_fwd_t<with_relu>::
-        _jit_avx512_core_fp32_wino_conv_2x3_fwd_t(const pd_t *pd,
+jit_avx512_core_fp32_wino_conv_2x3_fwd_t::
+        jit_avx512_core_fp32_wino_conv_2x3_fwd_t(const pd_t *apd,
                 const input_vector &inputs, const output_vector &outputs)
-    : cpu_primitive_t(&conf_, inputs, outputs)
-    , conf_(*pd), padded_bias_(nullptr) {
-    const int nthreads = mkldnn_get_max_threads();
+    : cpu_primitive_t(apd, inputs, outputs)
+{
     kernel_ = new jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t(
-            conf_.jcp_, *conf_.attr());
+            pd()->jcp_, *pd()->attr());
     src_trans_ = new jit_avx512_core_fp32_wino_conv_2x3_src_trans_t(
-            conf_.jcp_, *conf_.attr());
+            pd()->jcp_, *pd()->attr());
     dst_trans_ = new jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t(
-            conf_.jcp_, *conf_.attr());
-
-    int wino_size_offset
-            = (conf_.jcp_.yb / 2) * (conf_.jcp_.xb / 2) + (conf_.jcp_.xb);
-
-    size_wino_src = (conf_.jcp_.ic * 16) * (wino_size_offset);
-    size_wino_dst = (conf_.jcp_.oc * 16) * (wino_size_offset);
-
-    wino_src_ = (float *)malloc(sizeof(float) * nthreads * size_wino_src, 4096);
-    wino_dst_ = (float *)malloc(sizeof(float) * nthreads * size_wino_dst, 4096);
-    if (conf_.want_padded_bias()) {
-        const auto &j = conf_.jcp_;
-        assert(j.ngroups == 1);
-        padded_bias_ = (float *)malloc(sizeof(float) * j.oc, 64);
-        for (int oc = j.oc_without_padding; oc < j.oc; ++oc)
-            padded_bias_[oc] = 0;
-    }
-
-
+            pd()->jcp_, *pd()->attr());
 }
 
-template <bool with_relu>
-_jit_avx512_core_fp32_wino_conv_2x3_fwd_t<with_relu>
-    ::~_jit_avx512_core_fp32_wino_conv_2x3_fwd_t() {
+jit_avx512_core_fp32_wino_conv_2x3_fwd_t
+    ::~jit_avx512_core_fp32_wino_conv_2x3_fwd_t() {
     delete kernel_;
     delete src_trans_;
     delete dst_trans_;
-
-    free(wino_src_);
-    free(wino_dst_);
-    free(padded_bias_);
 }
 
-template <bool with_relu>
-void _jit_avx512_core_fp32_wino_conv_2x3_fwd_t<
-        with_relu>::execute_forward() {
+void jit_avx512_core_fp32_wino_conv_2x3_fwd_t::execute_forward() const {
     const auto &jcp = kernel_->jcp;
 
     if (jcp.small_mb)
@@ -904,33 +870,41 @@ void _jit_avx512_core_fp32_wino_conv_2x3_fwd_t<
         execute_forward_mbN();
 }
 
-template <bool with_relu>
-void _jit_avx512_core_fp32_wino_conv_2x3_fwd_t<with_relu>
-::execute_forward_mbN() {
+void jit_avx512_core_fp32_wino_conv_2x3_fwd_t::execute_forward_mbN() const {
     auto src = reinterpret_cast<const float *>(input_memory(0));
     auto wei = reinterpret_cast<const float *>(input_memory(1));
     auto bia = reinterpret_cast<const float *>(input_memory(2));
     auto dst = reinterpret_cast<float *>(memory(0));
 
-    const auto &jcp = kernel_->jcp;
-    const auto &oscales = conf_.attr()->output_scales_;
-
-    wino_wei_ = wei;
+    auto scratchpad = this->scratchpad();
 
-    if (conf_.want_padded_bias()) {
-        for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
-            padded_bias_[oc] = bia[oc];
-        bia = padded_bias_;
+    const auto &jcp = kernel_->jcp;
+    const auto &oscales = pd()->attr()->output_scales_;
+
+    const size_t wino_size_offset =
+        (size_t)(pd()->jcp_.yb / 2) * (pd()->jcp_.xb / 2) + (pd()->jcp_.xb);
+    const size_t size_wino_src = wino_size_offset * pd()->jcp_.ic * 16;
+    const size_t size_wino_dst = wino_size_offset * pd()->jcp_.oc * 16;
+
+    if (pd()->wants_padded_bias()) {
+        auto padded_bias = scratchpad.get<float>(key_conv_padded_bias);
+        utils::array_copy(padded_bias, bia, jcp.oc_without_padding);
+        utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
+                jcp.oc - jcp.oc_without_padding);
+        bia = padded_bias;
     }
 
+    auto ptr_V = scratchpad.get<float>(key_wino_V);
+    auto ptr_M = scratchpad.get<float>(key_wino_M);
+
     parallel_nd(jcp.mb, div_up(jcp.oh,jcp.yb), div_up(jcp.ow, jcp.xb),
         [&](int mb, int tile_y_b, int tile_x_b) {
         int tile_y = tile_y_b * jcp.yb;
         int tile_x = tile_x_b * jcp.xb;
 
         int ithr = mkldnn_get_thread_num();
-        auto wino_src = wino_src_ + size_wino_src * ithr;
-        auto wino_dst = wino_dst_ + size_wino_dst * ithr;
+        auto wino_src = ptr_V + size_wino_src * ithr;
+        auto wino_dst = ptr_M + size_wino_dst * ithr;
 
         auto src_trans_p =
             jit_avx512_core_fp32_wino_conv_2x3_src_trans_t
@@ -985,7 +959,7 @@ void _jit_avx512_core_fp32_wino_conv_2x3_fwd_t<with_relu>
             int offset = (tile_ij + ithr) % 16;
             gemm_p.src = wino_src + jcp.inp_stride * offset;
             gemm_p.dst = wino_dst + jcp.out_stride * offset;
-            gemm_p.wei = wino_wei_ + jcp.wei_stride * offset;
+            gemm_p.wei = wei + jcp.wei_stride * offset;
 
             kernel_->ker_(&gemm_p);
         }
@@ -1027,25 +1001,29 @@ void _jit_avx512_core_fp32_wino_conv_2x3_fwd_t<with_relu>
     });
 }
 
-template <bool with_relu>
-void _jit_avx512_core_fp32_wino_conv_2x3_fwd_t<with_relu>
-    ::execute_forward_small_mb() {
+void jit_avx512_core_fp32_wino_conv_2x3_fwd_t::execute_forward_small_mb() const
+{
     auto src = reinterpret_cast<const float *>(input_memory(0));
     auto wei = reinterpret_cast<const float *>(input_memory(1));
     auto bia = reinterpret_cast<const float *>(input_memory(2));
     auto dst = reinterpret_cast<float *>(memory(0));
 
-    const auto &jcp = kernel_->jcp;
-    const auto &oscales = conf_.attr()->output_scales_;
-
-    wino_wei_ = wei;
+    auto scratchpad = this->scratchpad();
 
-    if (conf_.want_padded_bias()) {
-        for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
-            padded_bias_[oc] = bia[oc];
-        bia = padded_bias_;
+    const auto &jcp = kernel_->jcp;
+    const auto &oscales = pd()->attr()->output_scales_;
+
+    if (pd()->wants_padded_bias()) {
+        auto padded_bias = scratchpad.get<float>(key_conv_padded_bias);
+        utils::array_copy(padded_bias, bia, jcp.oc_without_padding);
+        utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
+                jcp.oc - jcp.oc_without_padding);
+        bia = padded_bias;
     }
 
+    auto ptr_V = scratchpad.get<float>(key_wino_V);
+    auto ptr_M = scratchpad.get<float>(key_wino_M);
+
     for (int mb = 0; mb < jcp.mb; mb++) {
     for (int tile_y = 0; tile_y < jcp.oh; tile_y += jcp.yb) {
     for (int tile_x = 0; tile_x < jcp.ow; tile_x += jcp.xb) {
@@ -1080,7 +1058,7 @@ void _jit_avx512_core_fp32_wino_conv_2x3_fwd_t<with_relu>
             auto local_s = src
                     + mb * jcp.nb_ic * jcp.ih * jcp.iw * jcp.ic_block
                     + y * jcp.iw * jcp.ic_block + x * jcp.ic_block;
-            auto local_w = wino_src_ + m * jcp.ic;
+            auto local_w = ptr_V + m * jcp.ic;
 
             src_trans_p.src = local_s;
             src_trans_p.wino_src = local_w;
@@ -1095,10 +1073,10 @@ void _jit_avx512_core_fp32_wino_conv_2x3_fwd_t<with_relu>
             auto gemm_p = jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t ::
                     call_params_t();
 
-            gemm_p.src = wino_src_ + jcp.inp_stride * tile_ij;
-            gemm_p.dst = wino_dst_ + jcp.out_stride * tile_ij
+            gemm_p.src = ptr_V + jcp.inp_stride * tile_ij;
+            gemm_p.dst = ptr_M + jcp.out_stride * tile_ij
                     + nnb * jcp.n2_block * jcp.n_block;
-            gemm_p.wei = wino_wei_ + jcp.wei_stride * tile_ij
+            gemm_p.wei = wei + jcp.wei_stride * tile_ij
                     + nnb * jcp.n2_block * jcp.n_block * jcp.K;
 
             kernel_->ker_(&gemm_p);
@@ -1128,7 +1106,7 @@ void _jit_avx512_core_fp32_wino_conv_2x3_fwd_t<with_relu>
             auto local_d = dst
                     + mb * jcp.nb_oc * jcp.oh * jcp.ow * jcp.oc_block
                     + y * jcp.ow * jcp.oc_block + x * jcp.oc_block;
-            auto local_w = wino_dst_ + m * jcp.oc;
+            auto local_w = ptr_M + m * jcp.oc;
 
             auto scales = oscales.scales_;
             dst_trans_p.dst = local_d;
@@ -1144,9 +1122,6 @@ void _jit_avx512_core_fp32_wino_conv_2x3_fwd_t<with_relu>
     }}}
 }
 
-template struct _jit_avx512_core_fp32_wino_conv_2x3_fwd_t<true>;
-template struct _jit_avx512_core_fp32_wino_conv_2x3_fwd_t<false>;
-
 } // namespace cpu
 } // namespace impl
 } // namespace mkldnn