Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_core_u8s8s32x_wino_convolution.cpp
index 45f516c..1377290 100644 (file)
@@ -17,6 +17,7 @@
 #include <assert.h>
 
 #include "c_types_map.hpp"
+#include "memory_tracking.hpp"
 #include "cpu_convolution_pd.hpp"
 #include "cpu_engine.hpp"
 #include "mkldnn_thread.hpp"
@@ -33,6 +34,7 @@ namespace impl {
 namespace cpu {
 
 using namespace mkldnn::impl::memory_format;
+using namespace mkldnn::impl::memory_tracking::names;
 using namespace mkldnn::impl::utils;
 using namespace Xbyak;
 
@@ -100,7 +102,6 @@ struct jit_avx512_core_u8s8s32x_wino_conv_src_trans_t: public jit_generator {
         return Opmask(3 + id);
     }
 
-    Reg64 reg_ptr_offset = r15;
     Reg64 reg_ptr_src = r14;
     Reg64 reg_ptr_dst = r13;
 
@@ -117,12 +118,49 @@ struct jit_avx512_core_u8s8s32x_wino_conv_src_trans_t: public jit_generator {
     Reg64 reg_scratch_src_alpha = rdx;
     Xmm xmm_src_alpha = Xmm(0);
     Zmm zmm_src_alpha = Zmm(0);
+
+    Reg64 reg_shift = rax;
+    Xmm xmm_shift = Xmm(1);
+    Xmm xmm_zero = Xmm(0);
+
+    Reg64 reg_maskx = rbx;
+    Reg64 reg_masky = rsi;
+    Reg64 reg_nomask = reg_maskx;
 };
 
 void jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::generate() {
     Label ic_block_label;
+    Label end_label;
+    Label mask_label;
+    Label nomask_label;
+
+    auto load_src = [=](bool mask) {
+        for (int y = 0; y < jcp.alpha; y++) {
+            if (mask)
+                kmovw(y_mask, ptr[reg_ptr_v_y_masks + sizeof(uint16_t) * y]);
+            for (int x = 0; x < jcp.alpha; x++) {
+                Zmm zmm_i = zmm_inp(y * jcp.alpha + x);
+                Xmm vreg_i = vreg_inp(y * jcp.alpha + x);
+                int inp_offset = sizeof(uint8_t)
+                        * ((-jcp.t_pad + y) * jcp.iw * jcp.ic
+                                + (-jcp.l_pad + x) * jcp.ic);
+                if (mask) {
+                    kandw(r_mask, y_mask, x_mask(x));
+                    vmovdqu8(vreg_i | r_mask | T_z,
+                            EVEX_compress_addr(reg_aux_ptr_src, inp_offset));
+                } else {
+                    vmovdqu8(vreg_i,
+                            EVEX_compress_addr(reg_aux_ptr_src, inp_offset));
+                }
+                vpmovzxbd(zmm_i, vreg_i); // to int32
+                vcvtdq2ps(zmm_i, zmm_i); // to fp32
+                vmulps(zmm_i, zmm_i, zmm_src_alpha); // *alpha
+                vcvtps2dq(zmm_i | T_rn_sae, zmm_i); // to int32
+                vpmovusdb(vreg_i, zmm_i); // to u8
+            }
+        }
+    };
 
-    int out_offset = 0, inp_offset = 0;
     preamble();
 
 #   define READ_PARAM(reg, field) \
@@ -133,14 +171,24 @@ void jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::generate() {
     READ_PARAM(reg_ptr_v_x_masks, v_x_masks);
 #   undef READ_PARAM
 
-    xor_(eax, eax);
-    mov(ax, (int8_t)-128);
+    mov(reg_maskx, ptr[reg_ptr_v_x_masks]);
+    mov(reg_masky, ptr[reg_ptr_v_y_masks]);
+    test(reg_maskx, reg_maskx);
+    jz(end_label, T_NEAR); // skip kernel if x mask is all 0's
+    test(reg_masky, reg_masky);
+    jz(end_label, T_NEAR); // skip kernel if y mask is all 0's
+    and_(reg_maskx, reg_masky);
+    mov(reg_nomask, reg_maskx);
+    not_(reg_nomask); // zero if x and y masks are all 1's
+
+    xor_(reg_shift, reg_shift);
+    mov(reg_shift.cvt8(), (int8_t)-128);
 
     mov(reg_aux_ptr_src, reg_ptr_src);
     mov(reg_aux_ptr_dst, reg_ptr_dst);
 
     for (int i = 0; i < jcp.alpha; i++) {
-        kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(int16_t) * i]);
+        kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(uint16_t) * i]);
     }
 
     mov(reg_scratch_src_alpha, float2int(adj_src_scale));
@@ -151,24 +199,14 @@ void jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::generate() {
         vmovq(xmm_src_alpha, reg_scratch_src_alpha);
         vbroadcastss(zmm_src_alpha, xmm_src_alpha);
 
-        for(int y = 0; y < jcp.alpha; y++) {
-            kmovw(y_mask, ptr[reg_ptr_v_y_masks + sizeof(int16_t) * y]);
-            for(int x = 0; x < jcp.alpha; x++) {
-                Zmm zmm_i = zmm_inp(y*jcp.alpha + x);
-                Xmm vreg_i = vreg_inp(y*jcp.alpha + x);
-                vpxord(vreg_i, vreg_i, vreg_i);
-                kandw(r_mask, y_mask, x_mask(x));
-                inp_offset = sizeof(uint8_t) *
-                   ((-jcp.t_pad + y) * jcp.iw * jcp.ic
-                        + (-jcp.l_pad + x) * jcp.ic);
-                vmovdqu8(vreg_i | r_mask, EVEX_compress_addr(reg_aux_ptr_src, inp_offset));
-                vpmovzxbd(zmm_i, vreg_i); // to int32
-                vcvtdq2ps(zmm_i, zmm_i); // to fp32
-                vmulps(zmm_i, zmm_i, zmm_src_alpha); // *alpha
-                vcvtps2dq(zmm_i | T_rn_sae, zmm_i); // to int32
-                vpmovusdb(vreg_i, zmm_i); // to u8
-            }
-        }
+        test(reg_nomask, reg_nomask);
+        jz(nomask_label, T_NEAR);
+        load_src(true);
+        jmp(mask_label, T_NEAR);
+        L(nomask_label);
+        load_src(false);
+        L(mask_label);
+
         for(int y = 0; y < 4; y++) {
             vpsubb(vreg_tmp(y*4+0), vreg_inp(y*4+0), vreg_inp(y*4+2));
             vpaddb(vreg_tmp(y*4+1), vreg_inp(y*4+1), vreg_inp(y*4+2));
@@ -182,12 +220,12 @@ void jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::generate() {
             vpsubb(vreg_out(x+3*4), vreg_tmp(x+4*1), vreg_tmp(x+4*3));
         }
 
-        movd(Xmm(1), eax);
-        pxor(Xmm(0), Xmm(0));
-        pshufb(Xmm(1), Xmm(0));
+        vmovd(xmm_shift, reg_shift.cvt32());
+        vpxor(xmm_zero, xmm_zero, xmm_zero);
+        vpshufb(xmm_shift, xmm_shift, xmm_zero);
 
         for (int i = 0; i < 16; i++) {
-            out_offset = sizeof(uint8_t) * (jcp.inp_stride * i);
+            int out_offset = sizeof(uint8_t) * (jcp.inp_stride * i);
             if (i != unsign_val_in_wino_domain)
                 vpsubb(vreg_out(i), vreg_out(i), Xmm(1));
             vmovups(EVEX_compress_addr(reg_aux_ptr_dst, out_offset), vreg_out(i));
@@ -199,6 +237,7 @@ void jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::generate() {
     dec(reg_ic_block);
     jnz(ic_block_label, T_NEAR);
 
+    L(end_label);
     postamble();
 }
 
@@ -294,7 +333,6 @@ bool jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::maybe_relu(int position) {
     if (position == 0) {
         /* relu before sum */
         return false
-            || jcp.with_relu
             || p.contain(eltwise, 0)
             || (jcp.dst_dt == data_type::u8 && !p.contain(sum, 0));
     } else if (position == 1) {
@@ -362,7 +400,7 @@ void jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::generate() {
             vmulps(vreg_bias, vreg_bias, zmm_bias_alpha); // *alpha
         }
         for(int y = 0; y < jcp.m; y++) {
-            kmovw(y_mask, ptr[ reg_ptr_v_y_masks + sizeof(int16_t) * y ]);
+            kmovw(y_mask, ptr[ reg_ptr_v_y_masks + sizeof(uint16_t) * y ]);
             for(int x = 0; x < jcp.m; x++) {
                 kandw(r_mask, y_mask, x_mask(x));
 
@@ -442,11 +480,9 @@ void jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::generate() {
     mov(reg_aux_ptr_dst, reg_ptr_dst);
 
     vpxord(vreg_zero, vreg_zero, vreg_zero);
-    for (int i = 0; i < jcp.alpha * jcp.alpha; i++)
-        vpxord(vreg_inp(i), vreg_inp(i), vreg_inp(i));
 
-    for (int i = 0; i < jcp.alpha; i++)
-        kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(int16_t) * i]);
+    for (int i = 0; i < jcp.m; i++)
+        kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(uint16_t) * i]);
 
     int oc_blocks = jcp.oc / load_block;
     mov(reg_oc_block, oc_blocks);
@@ -461,9 +497,6 @@ void jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::generate() {
     dec(reg_oc_block);
     jnz(oc_block_label, T_NEAR);
 
-    sub(reg_ptr_scales, jcp.is_oc_scale *  sizeof(float) * load_block);
-    sub(reg_ptr_bias, oc_blocks * sizeof(jcp.typesize_bia) * load_block);
-
     postamble();
 
 }
@@ -498,8 +531,7 @@ struct jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t: public jit_generator {
             jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd,
             cpu_memory_t::pd_t &src_pd, cpu_memory_t::pd_t &weights_pd,
             cpu_memory_t::pd_t &dst_pd, cpu_memory_t::pd_t &bias_pd,
-            const primitive_attr_t &attr,
-            bool with_relu, float relu_negative_slope);
+            const primitive_attr_t &attr);
 
     Zmm vreg_out(int n, int m) {
         const int id_reg_out = n * jcp.m_block + m;
@@ -536,26 +568,14 @@ bool jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::post_ops_ok(
     using namespace primitive_kind;
     const auto &p = attr.post_ops_;
 
-    auto is_relu = [&](int idx) {
-        return p.entry_[idx].kind == eltwise
-            && p.entry_[idx].eltwise.scale == 1.
-            && p.entry_[idx].eltwise.alg == alg_kind::eltwise_relu
-            && p.entry_[idx].eltwise.alpha == 0.;
-    };
+    auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); };
 
-   switch (p.len_) {
+    switch (p.len_) {
     case 0: return true;
-    case 1: return true
-                && IMPLICATION(jcp.with_relu, p.contain(sum, 0))
-                && IMPLICATION(!jcp.with_relu, is_relu(0) || p.contain(sum, 0));
-    case 2: return true
-                && IMPLICATION(jcp.with_relu, p.contain(sum, 0) && is_relu(1))
-                && IMPLICATION(!jcp.with_relu, false
-                        || (p.contain(sum, 0) && is_relu(1))
-                        || (p.contain(sum, 1) && is_relu(0)));
-    case 3: return true
-                && jcp.with_relu == false
-                && (is_relu(0) && p.contain(sum, 1) && is_relu(2));
+    case 1: return is_relu(0) || p.contain(sum, 0);
+    case 2: return (p.contain(sum, 0) && is_relu(1)) ||
+                       (p.contain(sum, 1) && is_relu(0));
+    case 3: return is_relu(0) && p.contain(sum, 1) && is_relu(2);
     default: return false;
     }
 
@@ -657,13 +677,24 @@ void jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::generate() {
 
     postamble();
 }
+namespace {
+bool is_winograd_faster_than_direct(const jit_conv_conf_2x3_wino_t &jcp) {
+    if (jcp.ver == ver_vnni) {
+        return (jcp.mb <= mkldnn_get_max_threads()
+            && (jcp.mb > 4
+                && jcp.ic > 64
+                && !(jcp.oc > 128 && jcp.ih < 14)))
+            || jcp.mb > mkldnn_get_max_threads();
+    }
+    return true;
+}
+}
 
 status_t jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t
 ::init_conf(jit_conv_conf_2x3_wino_t &jcp,
             const convolution_desc_t &cd, cpu_memory_t::pd_t &src_pd,
             cpu_memory_t::pd_t &wei_pd, cpu_memory_t::pd_t &dst_pd,
-            cpu_memory_t::pd_t &bias_pd, const primitive_attr_t &attr,
-            bool with_relu, float relu_negative_slope) {
+            cpu_memory_t::pd_t &bias_pd, const primitive_attr_t &attr) {
     const memory_desc_wrapper src_d(&src_pd);
     const memory_desc_wrapper wei_d(&wei_pd);
     const memory_desc_wrapper dst_d(&dst_pd);
@@ -671,6 +702,8 @@ status_t jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t
 
     const bool with_groups = wei_d.ndims() == src_d.ndims() + 1;
 
+    jcp.nthr = mkldnn_get_max_threads();
+
     jcp.ngroups = with_groups ? wei_d.dims()[0] : 1;
     jcp.mb = src_d.dims()[0];
     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
@@ -700,6 +733,10 @@ status_t jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t
     if (mayiuse(avx512_core_vnni))
         jcp.ver = ver_vnni;
 
+    if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
+               is_winograd_faster_than_direct(jcp)))
+        return status::unimplemented;
+
     // block sizes needed for GEMM kernel
     jcp.ic_block = 4;
     jcp.oc_block = 16;
@@ -718,10 +755,7 @@ status_t jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t
 
     jcp.src_fmt = src_d.format();
     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
-    jcp.with_relu = with_relu;
-    jcp.relu_negative_slope = relu_negative_slope;
-    if (!IMPLICATION(with_relu, relu_negative_slope == 0.))
-        return status::unimplemented;
+
     if (!post_ops_ok(jcp, attr))
         return status::unimplemented;
 
@@ -743,7 +777,6 @@ status_t jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t
     jcp.alpha = jcp.m + jcp.r - 1;
 
     int aa = jcp.alpha * jcp.alpha;
-    int nthr = mkldnn_get_max_threads();
     int L1_cap = get_cache_size(1, true);
     int L2_cap = get_cache_size(2, true);
     // need 1 extra reg for bcast, and 2 tmp regs for non-vnni
@@ -755,12 +788,12 @@ status_t jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t
         float Y = (float)jcp.ic * jcp.oc;
         if (small_mb == 0) { // outer par
             int nblocks = jcp.mb * div_up(jcp.oh, iy) * div_up(jcp.ow, ix);
-            thr_eff = (float)nblocks / rnd_up(nblocks, nthr);
+            thr_eff = (float)nblocks / rnd_up(nblocks, jcp.nthr);
         } else { // inner par
             int tranw = iy * ix / jcp.alpha;
             int gemmw = aa * (jcp.nb_oc / n2_b);
-            int tranw_r = rnd_up(tranw, nthr);
-            int gemmw_r = rnd_up(gemmw, nthr);
+            int tranw_r = rnd_up(tranw, jcp.nthr);
+            int gemmw_r = rnd_up(gemmw, jcp.nthr);
             thr_eff = (Z * tranw / tranw_r + Y * gemmw / gemmw_r) / (Z + Y);
         }
         return thr_eff;
@@ -779,7 +812,7 @@ status_t jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t
             req_mem = (float)jcp.ic * (M + N) + jcp.typesize_acc * M * N;
             mem_eff = nstl::min(1.f, L2_cap / req_mem);
             // memory used during wino transforms
-            int M_per_thr = div_up(M, nthr);
+            int M_per_thr = div_up(M, jcp.nthr);
             req_mem = (float)aa * M_per_thr
                     * (jcp.ic + jcp.typesize_acc * jcp.oc);
             if (req_mem > L2_cap)
@@ -868,15 +901,34 @@ status_t jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t
     assert((jcp.m_block + 1) * jcp.n2_block <= free_regs);
     assert(jcp.xb % 2 == 0 && jcp.yb % 2 == 0);
 
-    jcp.inp_stride = jcp.yb * jcp.xb / 4 * jcp.ic;
-    jcp.out_stride = jcp.yb * jcp.xb / 4 * jcp.oc;
-    jcp.wei_stride = jcp.ic * jcp.oc;
-    jcp.bia_stride = jcp.oc;
+    jcp.mb_block = 1;
+    if (jcp.small_mb) {
+        // For small mb harness, set mb_block as large as possible subject to
+        // the constraint that winograd activations fit into available L3 cache
+        int L3_cap = get_cache_size(3, true);
+        int M = jcp.xb * jcp.yb / 4;
+        int wino_src_size = 16 * M * jcp.ic * jcp.typesize_in;
+        int wino_dst_size = 16 * M * jcp.oc * jcp.typesize_acc;
+        int max_mb_block = nstl::min(
+                jcp.mb, jcp.nthr * L3_cap / (wino_src_size + wino_dst_size));
+        for (int i = max_mb_block; i > 1; i--) {
+            if (jcp.mb % i == 0) {
+                jcp.mb_block = i;
+                break;
+            }
+        }
+    }
+    jcp.nb_mb = jcp.mb / jcp.mb_block;
 
-    jcp.M = jcp.xb * jcp.yb / 4;
+    jcp.M = jcp.mb_block * jcp.xb * jcp.yb / 4;
     jcp.N = jcp.oc;
     jcp.K = jcp.ic;
 
+    jcp.inp_stride = jcp.M * jcp.ic;
+    jcp.out_stride = jcp.M * jcp.oc;
+    jcp.wei_stride = jcp.ic * jcp.oc;
+    jcp.bia_stride = jcp.oc;
+
     jcp.n_block = jcp.oc_block;
     jcp.k_block = jcp.ic_block;
 
@@ -922,69 +974,82 @@ status_t jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t
     if (!wei_pd.is_equal(&new_weights_pd))
         return status::unimplemented;
 
+    const int tilesize = jcp.alpha * jcp.alpha;
+    const int numtiles = jcp.M;
+    const int alltiles = numtiles * tilesize;
+
+    jcp.size_wino_src
+        = utils::rnd_up(jcp.typesize_in * alltiles * jcp.ic, PAGE_4K)
+        / jcp.typesize_in;
+    jcp.size_wino_wei = tilesize * jcp.oc * jcp.ic;
+    jcp.size_wino_dst = alltiles * jcp.oc;
+
     return status::success;
 }
 ////////////////////////////////////////////////////////////////////////////////
 
-template <bool with_relu, data_type_t dst_data_type>
-status_t _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu,
-        dst_data_type>::pd_t::jit_conf() {
+template <data_type_t dst_data_type>
+status_t jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
+        pd_t::jit_conf() {
     return jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::init_conf(
-            jcp_, this->cdesc_(), this->src_pd_, this->weights_pd_,
-            this->dst_pd_,this->bias_pd_, *this->attr(),
-            with_relu, this->negative_slope());
+            jcp_, *this->desc(), this->src_pd_, this->weights_pd_,
+            this->dst_pd_,this->bias_pd_, *this->attr());
 }
 
-template <bool with_relu, data_type_t dst_data_type>
-_jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu, dst_data_type>::
-        _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t(const pd_t *pd,
+template <data_type_t dst_data_type>
+void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::pd_t::
+init_scratchpad() {
+    auto scratchpad = this->scratchpad_registry().registrar();
+
+    int nthr_multiplier = jcp_.small_mb ? 1 : jcp_.nthr;
+    scratchpad.book(key_wino_V,
+            sizeof(src_data_t) * jcp_.size_wino_src * nthr_multiplier, PAGE_4K);
+    scratchpad.book(key_wino_M,
+            sizeof(acc_data_t) * jcp_.size_wino_dst * nthr_multiplier, PAGE_4K);
+
+    scratchpad.book(key_conv_adjusted_scales,
+            sizeof(float) * nstl::max(attr()->output_scales_.count_, 16));
+}
+
+template <data_type_t dst_data_type>
+jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
+        jit_avx512_core_u8s8s32x_wino_convolution_fwd_t(const pd_t *apd,
                 const input_vector &inputs, const output_vector &outputs)
-    : cpu_primitive_t(&conf_, inputs, outputs)
-    , conf_(*pd)
-    , scratchpad_(nullptr) {
-    const int nthreads = mkldnn_get_max_threads();
+    : cpu_primitive_t(apd, inputs, outputs, true)
+{
     kernel_ = new jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t(
-            conf_.jcp_, *conf_.attr());
+            pd()->jcp_, *pd()->attr());
     src_trans_ = new jit_avx512_core_u8s8s32x_wino_conv_src_trans_t(
-            conf_.jcp_, *conf_.attr());
+            pd()->jcp_, *pd()->attr());
     dst_trans_ = new jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t(
-            conf_.jcp_, *conf_.attr());
-
-    const int tilesize = conf_.jcp_.alpha * conf_.jcp_.alpha;
-    const int numtiles = (conf_.jcp_.yb / 2) * (conf_.jcp_.xb / 2);
-    const int alltiles = tilesize * numtiles;
-    size_wino_wei_ = tilesize * conf_.jcp_.oc * conf_.jcp_.ic;
-    size_wino_src_ = sizeof(src_data_t) * alltiles * conf_.jcp_.ic;
-    size_wino_src_ = rnd_up(size_wino_src_, PAGE_4K);
-    size_wino_src_ /= sizeof(src_data_t);
-    size_wino_dst_ = alltiles * conf_.jcp_.oc;
-
-    size_t workspace_size = (conf_.jcp_.small_mb ? 1 : nthreads)
-            * (sizeof(src_data_t) * size_wino_src_
-                                    + sizeof(acc_data_t) * size_wino_dst_);
-
-    scratchpad_ = create_scratchpad(workspace_size);
-    assert(scratchpad_); // TODO: add proper check and raise exception?
-
-    wino_shift_ = (conf_.jcp_.small_mb ? 1 : nthreads) * sizeof(src_data_t)
-            * size_wino_src_;
-
-    updated_output_scales_ = conf_.attr()->output_scales_;
-    updated_output_scales_.scale(1.f / (adj_src_scale * adj_wei_scale));
+            pd()->jcp_, *pd()->attr());
 }
 
-template <bool with_relu, data_type_t dst_data_type>
-_jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu,
-        dst_data_type>::~_jit_avx512_core_u8s8s32x_wino_convolution_fwd_t() {
+template <data_type_t dst_data_type>
+jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
+        ~jit_avx512_core_u8s8s32x_wino_convolution_fwd_t() {
     delete kernel_;
     delete src_trans_;
     delete dst_trans_;
-    delete scratchpad_;
 }
 
-template <bool with_relu, data_type_t dst_data_type>
-void _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu,
-        dst_data_type>::execute_forward() {
+template <data_type_t dst_data_type>
+const float *jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
+adjust_oscales(const memory_tracking::grantor_t &scratchpad) const {
+    const float *oscales = pd()->attr()->output_scales_.scales_;
+    auto loc_scales = scratchpad.template get<float>(key_conv_adjusted_scales);
+    size_t count = pd()->attr()->output_scales_.count_;
+    float factor = 1.f / (adj_src_scale * adj_wei_scale);
+    if (count == 1)
+        utils::array_set(loc_scales, oscales[0] * factor, 16);
+    else
+        for (size_t c = 0; c < count; c++) loc_scales[c] = oscales[c] * factor;
+    return loc_scales;
+}
+
+template <data_type_t dst_data_type>
+void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
+execute_forward() const {
     const auto &jcp = kernel_->jcp;
     if (jcp.small_mb)
         execute_forward_small_mb();
@@ -992,21 +1057,22 @@ void _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu,
         execute_forward_mbN();
 }
 
-template <bool with_relu, data_type_t dst_data_type>
-void _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu,
-        dst_data_type>::execute_forward_mbN() {
+template <data_type_t dst_data_type>
+void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
+execute_forward_mbN() const {
     auto src = reinterpret_cast<const src_data_t *>(input_memory(0));
     auto wei = reinterpret_cast<const wei_data_t *>(input_memory(1));
     auto bia = reinterpret_cast<const char *>(input_memory(2));
     auto dst = reinterpret_cast<dst_data_t *>(memory(0));
 
+    auto scratchpad = this->scratchpad();
+
     const auto &jcp = kernel_->jcp;
-    const auto &oscales = updated_output_scales_;
+    const float *oscales = adjust_oscales(scratchpad);
 
-    auto wino_wei = wei;
-    auto dst_bias = (const acc_data_t *)(wei + size_wino_wei_);
-    auto wino_src_base = (src_data_t *)scratchpad_->get();
-    auto wino_dst_base = (acc_data_t *)(scratchpad_->get() + wino_shift_);
+    auto dst_bias = (const acc_data_t *)(wei + jcp.size_wino_wei);
+    auto wino_src_base = scratchpad.template get<src_data_t>(key_wino_V);
+    auto wino_dst_base = scratchpad.template get<acc_data_t>(key_wino_M);
 
     parallel_nd(jcp.mb, div_up(jcp.oh, jcp.yb), div_up(jcp.ow, jcp.xb),
             [&](int mb, int tile_y_b, int tile_x_b) {
@@ -1015,8 +1081,8 @@ void _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu,
         int tile_x = tile_x_b * jcp.xb;
 
         int ithr = mkldnn_get_thread_num();
-        auto wino_src = wino_src_base + size_wino_src_ * ithr;
-        auto wino_dst = wino_dst_base + size_wino_dst_ * ithr;
+        auto wino_src = wino_src_base + jcp.size_wino_src * ithr;
+        auto wino_dst = wino_dst_base + jcp.size_wino_dst * ithr;
 
         auto src_trans_p =
             jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::call_params_t();
@@ -1028,7 +1094,7 @@ void _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu,
         /* transformation of input tensor to winograd domain */
         for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) {
             for (int x_in_block = 0; x_in_block < jcp.xb; x_in_block += 2) {
-                unsigned short v_y_masks[4], v_x_masks[4];
+                uint16_t v_y_masks[4], v_x_masks[4];
 
                 int y = y_in_block + tile_y;
                 int x = x_in_block + tile_x;
@@ -1044,8 +1110,8 @@ void _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu,
 
 #pragma unroll(4)
                 for (int i = 0; i < jcp.alpha; i++) {
-                    v_y_masks[i] = (i < v_ys || i >= v_ye) ? 0 : 0xffff;
-                    v_x_masks[i] = (i < v_xs || i >= v_xe) ? 0 : 0xffff;
+                    v_y_masks[i] = uint16_t(i < v_ys || i >= v_ye ? 0 : 0xffff);
+                    v_x_masks[i] = uint16_t(i < v_xs || i >= v_xe ? 0 : 0xffff);
                 }
                 auto local_s = src
                         + mb * jcp.ih * jcp.iw * jcp.ic
@@ -1066,7 +1132,7 @@ void _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu,
             int offset = (tile_ij + ithr) % 16;
             gemm_p.src = wino_src + jcp.inp_stride * offset;
             gemm_p.dst = wino_dst + jcp.out_stride * offset;
-            gemm_p.wei = wino_wei + jcp.wei_stride * offset;
+            gemm_p.wei = wei + jcp.wei_stride * offset;
             gemm_p.dst_b = dst_bias + jcp.bia_stride * offset;
 
             kernel_->ker_(&gemm_p);
@@ -1075,7 +1141,7 @@ void _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu,
         /* transformation from winograd domain to output tensor */
         for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) {
             for (int x_in_block = 0; x_in_block < jcp.xb; x_in_block += 2) {
-                unsigned short v_y_masks[2], v_x_masks[2];
+                uint16_t v_y_masks[2], v_x_masks[2];
 
                 int y = y_in_block + tile_y;
                 int x = x_in_block + tile_x;
@@ -1083,15 +1149,15 @@ void _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu,
 
 #pragma unroll(2)
                 for (int i = 0; i < jcp.m; i++) {
-                    v_x_masks[i] = (x + i < jcp.ow) ? 0xffff : 0;
-                    v_y_masks[i] = (y + i < jcp.oh) ? 0xffff : 0;
+                    v_x_masks[i] = uint16_t(x + i < jcp.ow ? 0xffff : 0);
+                    v_y_masks[i] = uint16_t(y + i < jcp.oh ? 0xffff : 0);
                 }
                 auto local_d = dst
                         + mb * jcp.oh * jcp.ow * jcp.oc
                         + y * jcp.ow * jcp.oc + x * jcp.oc;
                 auto local_w = wino_dst + m * jcp.oc;
 
-                auto scales = oscales.scales_;
+                auto scales = oscales;
                 dst_trans_p.dst = local_d;
                 dst_trans_p.wino_dst = local_w;
                 dst_trans_p.v_y_masks = v_y_masks;
@@ -1106,39 +1172,41 @@ void _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu,
     });
 }
 
-template <bool with_relu, data_type_t dst_data_type>
-void _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu,
-        dst_data_type>::execute_forward_small_mb() {
+template <data_type_t dst_data_type>
+void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
+execute_forward_small_mb() const {
     auto src = reinterpret_cast<const src_data_t *>(input_memory(0));
     auto wei = reinterpret_cast<const wei_data_t *>(input_memory(1));
     auto bia = reinterpret_cast<const char *>(input_memory(2));
     auto dst = reinterpret_cast<dst_data_t *>(memory(0));
 
+    auto scratchpad = this->scratchpad();
+
     const auto &jcp = kernel_->jcp;
-    const auto &oscales = updated_output_scales_;
+    const float *oscales = adjust_oscales(scratchpad);
 
-    auto wino_wei = wei;
-    auto dst_bias = (const acc_data_t *)(wei + size_wino_wei_);
-    auto wino_src = (src_data_t *)scratchpad_->get();
-    auto wino_dst = (acc_data_t *)(scratchpad_->get() + wino_shift_);
+    auto dst_bias = (const acc_data_t *)(wei + jcp.size_wino_wei);
+    auto wino_src = scratchpad.template get<src_data_t>(key_wino_V);
+    auto wino_dst = scratchpad.template get<acc_data_t>(key_wino_M);
 
-    for (int mb = 0; mb < jcp.mb; mb++) {
+    for (int mbb = 0; mbb < jcp.nb_mb; mbb++) {
     for (int tile_y = 0; tile_y < jcp.oh; tile_y += jcp.yb) {
     for (int tile_x = 0; tile_x < jcp.ow; tile_x += jcp.xb) {
         /* transformation of input tensor to winograd domain */
-        parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2),
-            [&](int y_in_block_b, int x_in_block_b) {
+        parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), jcp.mb_block,
+            [&](int y_in_block_b, int x_in_block_b, int mb) {
             int y_in_block = y_in_block_b * 2;
             int x_in_block = x_in_block_b * 2;
 
             auto src_trans_p =
                 jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::call_params_t();
 
-            unsigned short v_y_masks[4], v_x_masks[4];
+            uint16_t v_y_masks[4], v_x_masks[4];
 
             int y = y_in_block + tile_y;
             int x = x_in_block + tile_x;
-            int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2);
+            int m = (mb * (jcp.yb / 2) + (y_in_block / 2)) * (jcp.xb / 2)
+                    + (x_in_block / 2);
 
             int v_ys = nstl::max(0, jcp.t_pad - y);
             int v_ye = nstl::min(
@@ -1150,11 +1218,11 @@ void _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu,
 
 #pragma unroll(4)
             for (int i = 0; i < jcp.alpha; i++) {
-                v_y_masks[i] = (i < v_ys || i >= v_ye) ? 0 : 0xffff;
-                v_x_masks[i] = (i < v_xs || i >= v_xe) ? 0 : 0xffff;
+                v_y_masks[i] = uint16_t(i < v_ys || i >= v_ye ? 0 : 0xffff);
+                v_x_masks[i] = uint16_t(i < v_xs || i >= v_xe ? 0 : 0xffff);
             }
             auto local_s = src
-                    + mb * jcp.ih * jcp.iw * jcp.ic
+                    + (mbb * jcp.mb_block + mb) * jcp.ih * jcp.iw * jcp.ic
                     + y * jcp.iw * jcp.ic + x * jcp.ic;
             auto local_w = wino_src + m * jcp.ic;
 
@@ -1174,7 +1242,7 @@ void _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu,
             gemm_p.src = wino_src + jcp.inp_stride * tile_ij;
             gemm_p.dst = wino_dst + jcp.out_stride * tile_ij
                     + nnb * jcp.n2_block * jcp.n_block;
-            gemm_p.wei = wino_wei + jcp.wei_stride * tile_ij
+            gemm_p.wei = wei + jcp.wei_stride * tile_ij
                     + nnb * jcp.n2_block * jcp.n_block * jcp.K;
             gemm_p.dst_b = dst_bias + jcp.bia_stride * tile_ij
                     + nnb * jcp.n2_block * jcp.n_block;
@@ -1183,31 +1251,32 @@ void _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu,
         });
 
         /* transformation from winograd domain to output tensor */
-        parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2),
-            [&](int y_in_block_b, int x_in_block_b) {
+        parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), jcp.mb_block,
+            [&](int y_in_block_b, int x_in_block_b, int mb) {
             int y_in_block = y_in_block_b * 2;
             int x_in_block = x_in_block_b * 2;
 
             auto dst_trans_p =
                 jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::call_params_t();
 
-            unsigned short v_y_masks[2], v_x_masks[2];
+            uint16_t v_y_masks[2], v_x_masks[2];
 
             int y = y_in_block + tile_y;
             int x = x_in_block + tile_x;
-            int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2);
+            int m = (mb * (jcp.yb / 2) + (y_in_block / 2)) * (jcp.xb / 2)
+                    + (x_in_block / 2);
 
 #pragma unroll(2)
             for (int i = 0; i < jcp.m; i++) {
-                v_x_masks[i] = (x + i < jcp.ow) ? 0xffff : 0;
-                v_y_masks[i] = (y + i < jcp.oh) ? 0xffff : 0;
+                v_x_masks[i] = uint16_t(x + i < jcp.ow ? 0xffff : 0);
+                v_y_masks[i] = uint16_t(y + i < jcp.oh ? 0xffff : 0);
             }
             auto local_d = dst
-                    + mb * jcp.oh * jcp.ow * jcp.oc
+                    + (mbb * jcp.mb_block + mb) * jcp.oh * jcp.ow * jcp.oc
                     + y * jcp.ow * jcp.oc + x * jcp.oc;
             auto local_w = wino_dst + m * jcp.oc;
 
-            auto scales = oscales.scales_;
+            auto scales = oscales;
             dst_trans_p.dst = local_d;
             dst_trans_p.wino_dst = local_w;
             dst_trans_p.v_y_masks = v_y_masks;
@@ -1221,22 +1290,10 @@ void _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu,
     }}}
 }
 
-template struct _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<true,
-        data_type::s8>;
-template struct _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<false,
-        data_type::s8>;
-template struct _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<true,
-        data_type::u8>;
-template struct _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<false,
-        data_type::u8>;
-template struct _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<true,
-        data_type::s32>;
-template struct _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<false,
-        data_type::s32>;
-template struct _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<true,
-        data_type::f32>;
-template struct _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<false,
-        data_type::f32>;
+template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<data_type::s8>;
+template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<data_type::u8>;
+template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<data_type::s32>;
+template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<data_type::f32>;
 
 } // namespace cpu
 } // namespace impl