Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / gemm_x8s8s32x_convolution.cpp
index 5512626..d9b8205 100644 (file)
@@ -32,99 +32,547 @@ namespace cpu {
 
 using namespace mkldnn::impl::utils;
 using namespace mkldnn::impl::math;
+using namespace mkldnn::impl::memory_tracking::names;
 
-template <bool with_relu, data_type_t src_type, data_type_t dst_type>
-void _gemm_x8s8s32x_convolution_fwd_t<with_relu, src_type,
-        dst_type>::execute_forward() {
+template <data_type_t src_type, data_type_t dst_type>
+void _gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>::
+execute_forward() const {
     auto src_base = reinterpret_cast<const src_data_t *>(this->input_memory(0));
     auto wei_base = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
     auto bia_base = reinterpret_cast<const char *>(this->input_memory(2));
     auto dst_base = reinterpret_cast<dst_data_t *>(this->memory());
 
-    jit_gemm_conv_conf_t &jcp = this->conf_.jcp_;
+    auto scratchpad = this->scratchpad();
 
-    char *scratchpad = (char *)this->scratchpad_->get();
-    uint8_t *col = (uint8_t *)scratchpad;
+    const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_;
+
+    auto col = scratchpad.template get<uint8_t>(key_conv_gemm_col);
     parallel_nd(jcp.im2col_sz * jcp.nthr, [&](ptrdiff_t i) {
         col[i] = jcp.signed_input ? (uint8_t)128 : (uint8_t)0;
     });
 
     parallel(jcp.nthr, [&](const int ithr, const int nthr) {
-        execute_forward_thr(ithr, nthr, src_base, wei_base, bia_base,
-                dst_base, scratchpad);
+        execute_forward_thr(ithr, nthr, src_base, wei_base, bia_base, dst_base,
+                scratchpad);
     });
 }
 
-template <bool with_relu, data_type_t src_type, data_type_t dst_type>
-void _gemm_x8s8s32x_convolution_fwd_t<with_relu, src_type,
-        dst_type>::execute_forward_thr(const int ithr, const int nthr,
-        const src_data_t *src_base, const wei_data_t *wei_base,
-        const char *bia_base, dst_data_t *dst_base, char *scratchpad) {
-#if USE_MKL_IGEMM
-    jit_gemm_conv_conf_t &jcp = this->conf_.jcp_;
+template <data_type_t src_type, data_type_t dst_type>
+_gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>::pp_ker_t::pp_ker_t(
+    const pd_t *pd)
+    : ker_(nullptr)
+    , jcp_(pd->jcp_)
+    , OC_(pd->jcp_.oc)
+    , OS_(pd->jcp_.os)
+    , bias_data_type_(data_type::undef)
+    , bias_data_type_size_(0)
+    , scale_idx_mult_(0)
+    , rmode_(round_mode::nearest)
+    , do_bias_(false)
+    , do_relu_(false)
+    , do_sum_(false)
+{
+    using namespace types;
 
-    const auto src_md = memory_desc_wrapper(conf_.src_pd());
-    const size_t src_mb_stride = src_md.blk_off(1);
-    const size_t src_g_stride = src_md.blk_off(0, 1) * jcp.ic;
+    const auto dst_md = memory_desc_wrapper(pd->dst_pd());
+    dst_os_stride_ = dst_md.blk_off(0, 0, 0, 1);
 
-    const auto wei_md = memory_desc_wrapper(conf_.weights_pd(0));
-    const size_t wei_g_stride = conf_.with_groups() ? wei_md.blk_off(1) : 0;
+    scale_idx_mult_ = (pd->attr()->output_scales_.mask_ == (1 << 1));
+    rmode_ = pd->attr()->round_mode_;
 
-    const auto dst_md = memory_desc_wrapper(conf_.dst_pd());
-    const size_t dst_mb_stride = dst_md.blk_off(1);
-    const size_t dst_g_stride = dst_md.blk_off(0, 1) * jcp.oc;
-    const size_t dst_os_stride = dst_md.blk_off(0, 0, 0, 1);
-
-    auto get_bias = [=, &bia_base](size_t off) -> acc_data_t {
-#       define CASE(dt) case dt: return (acc_data_t)\
-        (*((const prec_traits<dt>::type *)bia_base + off))
-        switch (conf_.cdesc()->bias_desc.data_type) {
-        CASE(data_type::s8);
-        CASE(data_type::u8);
-        CASE(data_type::s32);
-        CASE(data_type::f32);
+    auto &post_ops = pd->attr()->post_ops_;
+
+    int entry_idx = -1;
+    for (int idx = 0; idx < post_ops.len_; ++idx) {
+        const auto &e = post_ops.entry_[idx];
+        if (e.is_relu(true, false)) {
+            entry_idx = idx;
+            break;
+        }
+    }
+    do_relu_ = entry_idx >= 0;
+
+    do_signed_scaling_ = jcp_.signed_input;
+
+    do_sum_ = post_ops.contain(primitive_kind::sum, 0);
+    do_bias_ = pd->with_bias();
+    bias_data_type_ = pd->desc()->bias_desc.data_type;
+    if (do_bias_) {
+        assert(bias_data_type_ != data_type::undef);
+        bias_data_type_size_ = data_type_size(bias_data_type_);
+    }
+    const size_t vlen_start
+            = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
+
+    for (size_t i = vlen_start; i > 0; i--) {
+        if (OC_ % i == 0) {
+            vlen_ = i;
+            break;
+        }
+    }
+
+    if (!mayiuse(avx512_core))
+        // use fallback code for older CPUs
+        return;
+    else
+        generate();
+}
+
+template <data_type_t src_type, data_type_t dst_type>
+void _gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>::pp_ker_t::generate()
+{
+    using namespace Xbyak;
+    using namespace utils;
+    using namespace round_mode;
+
+    // TODO: clean-up
+    Reg64 reg_param = abi_param1;
+    Reg64 reg_dst = rdx;
+    Reg64 reg_acc = rax;
+    Reg64 reg_bias = rbx;
+    Reg64 reg_scales = rsi;
+
+    Reg64 reg_len = r8;
+    Reg64 reg_tmp = rcx; // intentional for shifting purposes
+    Reg64 reg_oc_offset = r9;
+    Reg64 reg_rem_mask_short = r10;
+    Reg64 reg_rem_mask_vlen = r11;
+    Opmask kreg_rem_mask_short = k1;
+    Opmask kreg_rem_mask_vlen = k3;
+    Opmask kreg_relu_cmp = k2;
+
+    const size_t vlen = 4;
+
+    Zmm vreg_zero = Zmm(0);
+    Zmm vreg_scale = Zmm(1);
+    Zmm vreg_nslope = Zmm(2);
+    Zmm vreg_sum_scale = Zmm(3);
+    Zmm vreg_signed_scale = Zmm(4);
+
+    size_t def_unroll = 4;
+    size_t max_unroll = 12;
+    size_t zmm_step = 2;
+    if (do_sum_) {
+        max_unroll = 8;
+        zmm_step = 3;
+    }
+
+    auto vreg_dst = [&](int idx) {
+        return Zmm(5 + idx * zmm_step + 0);
+    };
+    auto vreg_bias = [&](int idx) {
+        return Zmm(5 + idx * zmm_step + 1);
+    };
+    auto vreg_prev_dst = [&](int idx) {
+        return Zmm(5 + idx * zmm_step + 2);
+    };
+
+    preamble();
+
+#define PARAM_OFF(x) offsetof(ker_args, x)
+    mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]);
+    mov(reg_acc, ptr[reg_param + PARAM_OFF(acc)]);
+    mov(reg_bias, ptr[reg_param + PARAM_OFF(bias)]);
+    mov(reg_scales, ptr[reg_param + PARAM_OFF(scales)]);
+    mov(reg_len, ptr[reg_param + PARAM_OFF(len)]);
+    mov(reg_oc_offset, ptr[reg_param + PARAM_OFF(oc_offset)]);
+    vbroadcastss(vreg_nslope, ptr[reg_param + PARAM_OFF(nslope)]);
+    vbroadcastss(vreg_sum_scale, ptr[reg_param + PARAM_OFF(sum_scale)]);
+    vbroadcastss(vreg_signed_scale, ptr[reg_param + PARAM_OFF(signed_scale)]);
+    if (scale_idx_mult_ == 0)
+        vbroadcastss(vreg_scale, dword[reg_scales]);
+
+#undef PARAM_OFF
+
+    mov(reg_rem_mask_vlen, 1);
+    shl(reg_rem_mask_vlen, vlen);
+    sub(reg_rem_mask_vlen, 1);
+    kmovq(kreg_rem_mask_vlen, reg_rem_mask_vlen);
+
+    if (do_relu_ || dst_type == data_type::u8)
+        vxorps(vreg_zero, vreg_zero, vreg_zero);
+
+    // Load accumulated value, convert to float, apply sum (if any),
+    // bias (if any), scaling, and relu (if any);
+    // then convert to destination type and store
+    auto compute = [&](size_t offset, int idx, bool apply_mask) {
+        auto acc_addr = ptr[reg_acc + offset * sizeof(acc_data_t)];
+
+        if (scale_idx_mult_ > 0) {
+            assert(scale_idx_mult_ == 1);
+            auto scale_addr = ptr[reg_scales + offset * sizeof(float)];
+            auto vreg_scale_ = vreg_scale;
+            if (apply_mask)
+                vreg_scale_ = vreg_scale_ | kreg_rem_mask_short;
+            else
+                vreg_scale_ = vreg_scale_ | kreg_rem_mask_vlen;
+            vmovups(vreg_scale_, scale_addr);
+        }
+
+        auto vreg_dst_ = vreg_dst(idx);
+        if (apply_mask)
+            vreg_dst_ = vreg_dst_ | kreg_rem_mask_short;
+        else
+            vreg_dst_ = vreg_dst_ | kreg_rem_mask_vlen;
+        vcvtdq2ps(vreg_dst_, acc_addr);
+
+        if (do_signed_scaling_)
+            vmulps(vreg_dst(idx), vreg_dst(idx), vreg_signed_scale);
+
+        if (do_bias_) {
+            auto bias_addr = ptr[reg_bias + offset * bias_data_type_size_];
+            auto vreg_bias_ = vreg_bias(idx);
+            if (apply_mask)
+                vreg_bias_ = vreg_bias_ | kreg_rem_mask_short;
+            else
+                vreg_bias_ = vreg_bias_ | kreg_rem_mask_vlen;
+
+            switch (bias_data_type_) {
+            case data_type::s8:
+                vpmovsxbd(vreg_bias_, bias_addr);
+                break;
+            case data_type::u8:
+                vpmovzxbd(vreg_bias_, bias_addr);
+                break;
+            case data_type::s32:
+                vcvtdq2ps(vreg_bias_, bias_addr);
+                break;
+            case data_type::f32:
+                vmovups(vreg_bias_, bias_addr);
+                break;
+            default: assert(!"unimplemented");
+            }
+            vaddps(vreg_dst(idx), vreg_dst(idx), vreg_bias(idx));
+        }
+
+        vmulps(vreg_dst(idx), vreg_dst(idx), vreg_scale);
+
+        auto dst_addr = ptr[reg_dst + offset * sizeof(dst_data_t)];
+
+        if (do_sum_)
+        {
+            auto vreg_prev_dst_ = vreg_prev_dst(idx);
+            if (apply_mask)
+                vreg_prev_dst_ = vreg_prev_dst_ | kreg_rem_mask_short;
+            else
+                vreg_prev_dst_ = vreg_prev_dst_ | kreg_rem_mask_vlen;
+
+            switch (dst_type) {
+            case data_type::f32:
+            case data_type::s32: vmovups(vreg_prev_dst_, dst_addr); break;
+            case data_type::s8: vpmovsxbd(vreg_prev_dst_, dst_addr); break;
+            case data_type::u8: vpmovzxbd(vreg_prev_dst_, dst_addr); break;
+            default: assert(!"unsupported data type");
+            }
+            if (dst_type != data_type::f32)
+                vcvtdq2ps(vreg_prev_dst(idx), vreg_prev_dst(idx));
+
+            vfmadd231ps(vreg_dst(idx), vreg_prev_dst(idx), vreg_sum_scale);
+        }
+
+        if (do_relu_) {
+            vcmpps(kreg_relu_cmp, vreg_dst(idx), vreg_zero, _cmp_lt_os);
+            vmulps(vreg_dst(idx) | kreg_relu_cmp, vreg_dst(idx), vreg_nslope);
+        }
+
+        if (dst_type != data_type::f32) {
+            auto rmode_control = (rmode_ == nearest ? T_rn_sae : T_rd_sae);
+            vcvtps2dq(vreg_dst(idx) | rmode_control, vreg_dst(idx));
+        }
+
+        if (dst_type == data_type::u8)
+            vpmaxsd(vreg_dst(idx), vreg_dst(idx), vreg_zero);
+
+        switch (dst_type) {
+        case data_type::s8:
+            vpmovsdb(dst_addr, vreg_dst_);
+            break;
+        case data_type::u8:
+            vpmovusdb(dst_addr, vreg_dst_);
+            break;
+        case data_type::f32:
+        case data_type::s32:
+            vmovups(dst_addr, vreg_dst_);
+            break;
         default: assert(!"unimplemented");
         }
-#       undef CASE
-        return 0;
     };
 
-    /* scale_idx_mult = 1 for per_oc scales and 0, otherwise */
-    const int scale_idx_mult = conf_.attr()->output_scales_.mask_ == (1 << 1);
-    const float *scales = conf_.attr()->output_scales_.scales_;
+    // Advance all pointers by an immediate
+    auto advance_ptrs_imm = [&](size_t offset) {
+        add(reg_dst, offset * sizeof(dst_data_t));
+        add(reg_acc, offset * sizeof(acc_data_t));
+        if (scale_idx_mult_) {
+            assert(scale_idx_mult_ == 1);
+            add(reg_scales, offset * sizeof(float));
+        }
+        if (do_bias_)
+            add(reg_bias, offset * bias_data_type_size_);
+    };
+
+    // Advance all pointers by a value stored in a register
+    auto advance_ptrs_reg = [&](Reg64 offset) {
+        lea(reg_dst, ptr[reg_dst + offset * sizeof(dst_data_t)]);
+        lea(reg_acc, ptr[reg_acc + offset * sizeof(acc_data_t)]);
+        if (scale_idx_mult_) {
+            assert(scale_idx_mult_ == 1);
+            lea(reg_scales, ptr[reg_scales + offset * sizeof(float)]);
+        }
+        if (do_bias_)
+            lea(reg_bias, ptr[reg_bias + offset * bias_data_type_size_]);
+    };
+
+    // Rewind pointers that point to data that is indexed by output channel
+    // (bias or per-oc scaling factors)
+    auto rewind_ptrs = [&]() {
+        if (do_bias_)
+            sub(reg_bias, OC_ * bias_data_type_size_);
+        if (scale_idx_mult_) {
+            assert(scale_idx_mult_ == 1);
+            sub(reg_scales, OC_ * sizeof(float));
+        }
+        add(reg_dst, (dst_os_stride_ - OC_) * sizeof(dst_data_t));
+    };
+
+    //                    <--------- OC --------------->
+    //
+    // ^  ................+..............+-------------+.......................
+    // |  .               : not accessed |Prologue loop|                      .
+    // |  .               +--------------+-------------+                      .
+    //    .               |                            |                      .
+    // O  .               |  Main loop (unrolled)      |                      .
+    // S  .               |                            |                      .
+    //    .               +--------------+-------------+                      .
+    // |  .               | Epilogue loop|not accessed :                      .
+    // v  ................+--------------+.............+.......................
+
+    Label prologue_end;
+    cmp(reg_oc_offset, 0);
+    je(prologue_end, T_NEAR);
+
+    // Prologue loop
+    {
+        mov(reg_tmp, OC_);
+        sub(reg_tmp, reg_oc_offset);
+        cmp(reg_tmp, reg_len);
+        cmovg(reg_tmp, reg_len);
+        sub(reg_len, reg_tmp);
+
+        Label prologue_loop, prologue_loop_tail, prologue_loop_end;
+        cmp(reg_tmp, vlen);
+        jle(prologue_loop_tail, T_NEAR);
+        L(prologue_loop); {
+            compute(0, 0, false);
+            advance_ptrs_imm(vlen);
+            sub(reg_tmp, vlen);
+            cmp(reg_tmp, vlen);
+            jge(prologue_loop, T_NEAR);
+        }
+
+        L(prologue_loop_tail);
+        mov(reg_rem_mask_short, 1);
+        // cl == reg_tmp because reg_tmp <= vlen here
+        shl(reg_rem_mask_short, cl);
+        sub(reg_rem_mask_short, 1);
+        jz(prologue_loop_end, T_NEAR);
+
+        kmovq(kreg_rem_mask_short, reg_rem_mask_short);
+        compute(0, 0, true);
+        advance_ptrs_reg(reg_tmp);
 
-    const auto rmode = conf_.attr()->round_mode_;
+        L(prologue_loop_end);
+        rewind_ptrs();
+    }
+    L(prologue_end);
+
+    // Main loop
+    Label main_loop_end;
+    {
+        cmp(reg_len, OC_);
+        jle(main_loop_end, T_NEAR);
+
+        Label main_loop;
+        L(main_loop); {
+            size_t OC_loop, OC_tail;
+            if (OC_ < max_unroll * vlen) {
+                // Fully unroll small loops
+                OC_loop = 0;
+                OC_tail = OC_;
+            }
+            else {
+                OC_loop = vlen * def_unroll;
+                OC_tail = OC_ % OC_loop;
+            }
+
+            assert(!!OC_loop || !!OC_tail);
+
+            if (OC_tail % vlen) {
+                int vlen_tail = OC_tail % vlen;
+                unsigned tail_mask = (1 << vlen_tail) - 1;
+                mov(reg_tmp, tail_mask);
+                kmovq(kreg_rem_mask_short, reg_tmp);
+            }
+
+            if (OC_loop) {
+                mov(reg_tmp, rnd_dn(OC_, OC_loop));
+                Label oc_loop;
+                L(oc_loop); {
+                    for (size_t offset = 0; offset < OC_loop; offset += vlen)
+                        compute(offset, offset / vlen, false);
+                    advance_ptrs_imm(OC_loop);
+                    sub(reg_tmp, OC_loop);
+                    jnz(oc_loop);
+                }
+            }
+
+            if (OC_tail) {
+                for (size_t offset = 0; offset < OC_tail; offset += vlen) {
+                    bool use_mask = (offset + vlen) > OC_tail;
+                    compute(offset, offset / vlen, use_mask);
+                }
+                advance_ptrs_imm(OC_tail);
+            }
+
+            rewind_ptrs();
+            sub(reg_len, OC_);
+            cmp(reg_len, OC_);
+            jge(main_loop, T_NEAR);
+        }
+    }
+    L(main_loop_end);
+
+    // Epilogue loop
+    Label epilogue_end;
+    {
+        cmp(reg_len, 0);
+        je(epilogue_end, T_NEAR);
+
+        Label epilogue_loop, epilogue_loop_tail;
+        cmp(reg_len, vlen);
+        jle(epilogue_loop_tail, T_NEAR);
+        L(epilogue_loop); {
+            compute(0, 0, false);
+            sub(reg_len, vlen);
+            advance_ptrs_imm(vlen);
+            cmp(reg_len, vlen);
+            jge(epilogue_loop, T_NEAR);
+        }
+
+        L(epilogue_loop_tail);
+        mov(reg_tmp, reg_len); // reg_tmp is rcx, and we need cl for the shift
+        mov(reg_rem_mask_short, 1);
+        shl(reg_rem_mask_short, cl); // reg_tmp == rcx and reg_tail < vlen
+        sub(reg_rem_mask_short, 1);
+        jz(epilogue_end, T_NEAR);
+        kmovq(kreg_rem_mask_short, reg_rem_mask_short);
+        compute(0, 0, true);
+    }
 
-    const bool use_fast_path = true
-        && scale_idx_mult == 0
-        && jcp.ngroups == 1
-        && !jcp.with_bias;
-    const float fast_path_alpha = scales[0] / jcp.wei_adj_scale;
+    L(epilogue_end);
 
-    const auto &post_ops = conf_.attr()->post_ops_;
+    postamble();
+
+    ker_ = getCode<decltype(ker_)>();
+}
+
+template <data_type_t src_type, data_type_t dst_type>
+void _gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>::pp_ker_t::operator ()
+    (dst_data_t *dst, const acc_data_t *acc, const char *bias,
+        const float *scales, float nslope, float sum_scale, float signed_scale,
+        int g, size_t start, size_t end)
+{
+    using math::get_bias;
+
+    if (end <= start)
+        return;
+
+    if (ker_) {
+        // JIT
+        ker_args args;
+        size_t oc_offset = start % OC_;
+        size_t os_offset = start / OC_;
+        args.acc = acc + start;
+        args.dst = dst + os_offset * dst_os_stride_ + oc_offset;
+        args.bias = bias + (g * jcp_.oc + oc_offset) * bias_data_type_size_;
+        args.scales = scales + scale_idx_mult_ * (g * jcp_.oc + oc_offset);
+        args.nslope = nslope;
+        args.sum_scale = sum_scale;
+        args.signed_scale = signed_scale;
+        args.len = end - start;
+        args.oc_offset = oc_offset;
+        ker_(&args);
+    }
+    else {
+        // Fallback
+        const size_t first_oc = start % OC_;
+        const size_t last_oc = (end - 1) % OC_;
+        const size_t first_os = start / OC_;
+        const size_t last_os = (end - 1) / OC_;
+        for (size_t os = first_os; os <= last_os; os++) {
+            const size_t start_oc = (os == first_os) ? first_oc : 0;
+            const size_t end_oc = (os == last_os) ? last_oc : OC_ - 1;
+            for (size_t oc = start_oc; oc <= end_oc; oc++) {
+                const size_t acc_off = os * jcp_.oc + oc;
+                const size_t dst_off = os * dst_os_stride_ + oc;
+
+                float d = (float)(acc[acc_off]);
+                if (jcp_.signed_input)
+                    d *= signed_scale;
+
+                if (do_bias_)
+                    d += get_bias(bias, g * jcp_.oc + oc,
+                        bias_data_type_);
+
+                d *= scales[(g * jcp_.oc + oc) * scale_idx_mult_];
+                if (do_sum_)
+                    d += sum_scale * dst[dst_off];
+                if (do_relu_ && d < 0)
+                    d *= nslope;
+                dst[dst_off] = qz_a1b0<float, dst_data_t>()(d, rmode_);
+            }
+        }
+    }
+};
+
+template <data_type_t src_type, data_type_t dst_type>
+void _gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>::
+execute_forward_thr(const int ithr, const int nthr, const src_data_t *src_base,
+        const wei_data_t *wei_base, const char *bia_base, dst_data_t *dst_base,
+        const memory_tracking::grantor_t &scratchpad) const {
+    const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_;
+
+    const auto src_md = memory_desc_wrapper(pd()->src_pd());
+    const size_t src_mb_stride = src_md.blk_off(1);
+    const size_t src_g_stride = src_md.blk_off(0, 1) * jcp.ic;
+
+    const auto wei_md = memory_desc_wrapper(pd()->weights_pd(0));
+    const size_t wei_g_stride = pd()->with_groups() ? wei_md.blk_off(1) : 0;
+
+    const auto dst_md = memory_desc_wrapper(pd()->dst_pd());
+    const size_t dst_mb_stride = dst_md.blk_off(1);
+    const size_t dst_g_stride = dst_md.blk_off(0, 1) * jcp.oc;
+
+    const float *scales = pd()->attr()->output_scales_.scales_;
+
+    const auto &post_ops = pd()->attr()->post_ops_;
     const bool do_sum = post_ops.contain(primitive_kind::sum, 0);
     const float sum_scale = do_sum ? post_ops.entry_[0].sum.scale : 0;
 
-    float nslope = jcp.with_relu ? jcp.relu_negative_slope : 0;
-    int entry_idx = -1;
+    float nslope = 0;
     for (int idx = 0; idx < post_ops.len_; ++idx) {
         const auto &e = post_ops.entry_[idx];
         if (e.is_relu(true, false)) {
-            entry_idx = idx;
             nslope = e.eltwise.alpha;
             break;
         }
     }
-    const bool do_relu = jcp.with_relu || (entry_idx >= 0);
-
-    uint8_t *_col = (uint8_t *)scratchpad;
-    ptrdiff_t offset = (ptrdiff_t)jcp.im2col_sz * sizeof(uint8_t) * jcp.nthr;
-    acc_data_t *_acc = (acc_data_t *)(scratchpad + offset);
 
-    uint8_t *col = _col + (ptrdiff_t)ithr * jcp.im2col_sz;
-    acc_data_t *acc = _acc + (ptrdiff_t)ithr * jcp.os * jcp.oc;
+    auto col = scratchpad.get<uint8_t>(key_conv_gemm_col)
+        + (ptrdiff_t)ithr * jcp.im2col_sz;
+    auto acc = scratchpad.get<acc_data_t>(key_conv_int_dat_in_acc_dt)
+        + (ptrdiff_t)ithr * jcp.os * jcp.oc;
 
-    offset = (ptrdiff_t)jcp.ngroups * jcp.ks * jcp.ic * jcp.oc;
+    const ptrdiff_t offset = (ptrdiff_t)jcp.ngroups * jcp.ks * jcp.ic * jcp.oc;
     const int32_t *_wei_comp = (const int32_t *)(wei_base + offset);
 
     int n{0}, g{0};
@@ -147,62 +595,40 @@ void _gemm_x8s8s32x_convolution_fwd_t<with_relu, src_type,
         const int M = jcp.oc;
         const int K = jcp.ks * jcp.ic;
         const int N = jcp.os;
-        const CBLAS_OFFSET offsetc
-                = jcp.signed_input ? CblasColOffset : CblasFixOffset;
+        const int LD = M * jcp.ngroups;
         const int8_t off_a = 0, off_b = 0;
         const int32_t off_c = 0;
+        const float onef = 1.0, zerof = 0.0;
+
+        mkldnn_gemm_s8u8s32("N", "N", jcp.signed_input ? "C" : "F",
+                &M, &N, &K, &onef, wei, &LD, &off_a,
+                jcp.im2col_sz ? col : (uint8_t *)src, &K, &off_b,
+                &zerof, acc, &M, jcp.signed_input ? wei_comp : &off_c);
+
+        parallel(0, [&](int ithr, int nthr) {
+            size_t start, end;
+            balance211((size_t)jcp.os * jcp.oc, nthr, ithr, start, end);
+            (*pp_ker_)(dst, acc, bia_base, scales, nslope, sum_scale,
+                    jcp.signed_input ? 1.f / jcp.wei_adj_scale : 1.f,
+                    g, start, end);
+        });
 
-        cblas_gemm_s8u8s32(CblasColMajor, CblasNoTrans, CblasNoTrans, offsetc,
-                M, N, K, 1.0f, wei, M * jcp.ngroups, off_a,
-                jcp.im2col_sz ? col : (uint8_t *)src, K, off_b, 0.0f, acc, M,
-                jcp.signed_input ? wei_comp : &off_c);
-
-        if (use_fast_path) {
-            auto body = [&](int o) {
-                float d = fast_path_alpha * acc[o] + sum_scale * dst[o];
-                if (do_relu && d < 0) d *= nslope;
-                dst[o] = qz_a1b0<float, dst_data_t>()(d, rmode);
-            };
-
-#           if MKLDNN_THR == MKLDNN_THR_OMP && _OPENMP >= 201307
-#           pragma omp parallel for simd
-            for (int o = 0; o < jcp.os * jcp.oc; ++o) body(o);
-#           else
-            parallel_nd(jcp.os * jcp.oc, body);
-#           endif
-        } else {
-            parallel_nd(jcp.os, jcp.oc, [&](const int os, const int oc) {
-                const size_t acc_off = os * jcp.oc + oc;
-                float d = (float)acc[acc_off];
-                if (jcp.signed_input)
-                    d /= jcp.wei_adj_scale;
-
-                if (jcp.with_bias)
-                    d += get_bias(g * jcp.oc + oc);
-
-                d *= scales[(g * jcp.oc + oc) * scale_idx_mult];
-
-                const size_t dst_off = os * dst_os_stride + oc;
-                if (do_sum) d += sum_scale * dst[dst_off];
-                if (do_relu && d < 0) d *= nslope;
-                dst[dst_off] = qz_a1b0<float, dst_data_t>()(d, rmode);
-            });
-        }
         nd_iterator_step(n, jcp.mb, g, jcp.ngroups);
     }
-#endif
 }
 
 template <data_type_t dst_type>
-void _gemm_u8s8s32x_convolution_bwd_data_t<dst_type>::execute_backward_data() {
+void _gemm_u8s8s32x_convolution_bwd_data_t<dst_type>::
+execute_backward_data() const {
     auto diff_dst_base = reinterpret_cast<const diff_dst_data_t *>
             (this->input_memory(0));
     auto wei_base = reinterpret_cast<const wei_data_t *>(this->input_memory(1));
     auto bia_base = reinterpret_cast<const char *>(this->input_memory(2));
     auto diff_src_base = reinterpret_cast<diff_src_data_t *>(this->memory());
 
-    jit_gemm_conv_conf_t &jcp = this->conf_.jcp_;
-    char *scratchpad = (char *)this->scratchpad_->get();
+    auto scratchpad = this->scratchpad();
+
+    const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_;
 
     parallel(jcp.nthr, [&](const int ithr, const int nthr) {
         execute_backward_data_thr(ithr, nthr, diff_dst_base, wei_base,
@@ -211,53 +637,36 @@ void _gemm_u8s8s32x_convolution_bwd_data_t<dst_type>::execute_backward_data() {
 }
 
 template <data_type_t dst_type>
-void _gemm_u8s8s32x_convolution_bwd_data_t<dst_type>
-::execute_backward_data_thr(const int ithr, const int nthr,
+void _gemm_u8s8s32x_convolution_bwd_data_t<dst_type>::
+execute_backward_data_thr(const int ithr, const int nthr,
         const diff_dst_data_t *diff_dst_base, const wei_data_t *wei_base,
-        const char *bia_base, diff_src_data_t *diff_src_base, char *scratchpad)
+        const char *bia_base, diff_src_data_t *diff_src_base,
+        const memory_tracking::grantor_t &scratchpad) const
 {
-#if USE_MKL_IGEMM
-    jit_gemm_conv_conf_t &jcp = this->conf_.jcp_;
+    const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_;
 
-    const auto diff_dst_md = memory_desc_wrapper(conf_.diff_dst_pd());
+    const auto diff_dst_md = memory_desc_wrapper(pd()->diff_dst_pd());
     const size_t diff_dst_mb_stride = diff_dst_md.blk_off(1);
     const size_t diff_dst_g_stride = diff_dst_md.blk_off(0, 1) * jcp.oc;
 
-    const auto wei_md = memory_desc_wrapper(conf_.weights_pd(0));
-    const size_t wei_g_stride = conf_.with_groups() ? wei_md.blk_off(1) : 0;
+    const auto wei_md = memory_desc_wrapper(pd()->weights_pd(0));
+    const size_t wei_g_stride = pd()->with_groups() ? wei_md.blk_off(1) : 0;
 
-    const auto diff_src_md = memory_desc_wrapper(conf_.diff_src_pd());
+    const auto diff_src_md = memory_desc_wrapper(pd()->diff_src_pd());
     const size_t diff_src_mb_stride = diff_src_md.blk_off(1);
     const size_t diff_src_g_stride = diff_src_md.blk_off(0, 1) * jcp.ic;
     const size_t diff_src_os_stride = diff_src_md.blk_off(0, 0, 0, 1);
 
-    auto get_bias = [=, &bia_base](size_t off) -> acc_data_t {
-#       define CASE(dt) case dt: return (acc_data_t)\
-        (*((const prec_traits<dt>::type *)bia_base + off))
-        switch (conf_.desc()->bias_desc.data_type) {
-        CASE(data_type::s8);
-        CASE(data_type::u8);
-        CASE(data_type::s32);
-        CASE(data_type::f32);
-        default: assert(!"unimplemented");
-        }
-#       undef CASE
-        return 0;
-    };
-
     /* scale_idx_mult = 1 for per_oc scales and 0, otherwise */
-    const int scale_idx_mult = conf_.attr()->output_scales_.mask_ == (1 << 1);
-    const float *scales = conf_.attr()->output_scales_.scales_;
-    const auto rmode = conf_.attr()->round_mode_;
+    const int scale_idx_mult = pd()->attr()->output_scales_.mask_ == (1 << 1);
+    const float *scales = pd()->attr()->output_scales_.scales_;
+    const auto rmode = pd()->attr()->round_mode_;
     const size_t work_amount = jcp.ngroups * jcp.mb;
 
-    acc_data_t *_col = (acc_data_t *)scratchpad;
-    ptrdiff_t offset = (ptrdiff_t)jcp.im2col_sz
-                                    * sizeof(acc_data_t) * jcp.nthr;
-    acc_data_t *_acc = (acc_data_t *)(scratchpad + offset);
-
-    acc_data_t *col = _col + (ptrdiff_t)ithr * jcp.im2col_sz;
-    acc_data_t *acc = _acc + (ptrdiff_t)ithr * jcp.is * jcp.ic;
+    auto col = scratchpad.get<acc_data_t>(key_conv_gemm_col)
+        + (ptrdiff_t)ithr * jcp.im2col_sz;
+    auto acc = scratchpad.get<acc_data_t>(key_conv_int_dat_in_acc_dt)
+        + (ptrdiff_t)ithr * jcp.is * jcp.ic;
 
     int n{0}, g{0};
     size_t start = 0, end = 0;
@@ -277,11 +686,12 @@ void _gemm_u8s8s32x_convolution_bwd_data_t<dst_type>
         const int K = jcp.oc;
         const int8_t off_a = 0, off_b = 0;
         const int32_t off_c = 0;
+        const float onef = 1.0, zerof = 0.0;
+        const int LD = K * jcp.ngroups;
 
-        cblas_gemm_s8u8s32(CblasColMajor, CblasTrans, CblasNoTrans,
-                CblasFixOffset, M, N, K, 1., wei, K * jcp.ngroups, off_a,
-                diff_dst, K * jcp.ngroups, off_b, 0., jcp.im2col_sz ? col
-                : acc, M, &off_c);
+        mkldnn_gemm_s8u8s32("T", "N", "F", &M, &N, &K, &onef,
+                wei, &LD, &off_a, diff_dst, &LD, &off_b,
+                &zerof, jcp.im2col_sz ? col : acc, &M, &off_c);
 
         if (jcp.im2col_sz)
             jit_gemm_convolution_utils::col2im_s32(jcp, col, acc);
@@ -289,7 +699,8 @@ void _gemm_u8s8s32x_convolution_bwd_data_t<dst_type>
         parallel_nd(jcp.is, jcp.ic, [&](int is, int ic) {
             float d = (float)acc[is * jcp.ic + ic];
             if (jcp.with_bias)
-                d += get_bias(g * jcp.ic + ic);
+                d += get_bias(bia_base, g * jcp.ic + ic,
+                        pd()->desc()->bias_desc.data_type);
             d *= scales[(g * jcp.ic + ic) * scale_idx_mult];
             const size_t diff_src_off = is * diff_src_os_stride + ic;
             diff_src[diff_src_off] =
@@ -297,28 +708,19 @@ void _gemm_u8s8s32x_convolution_bwd_data_t<dst_type>
         });
         nd_iterator_step(n, jcp.mb, g, jcp.ngroups);
     }
-#endif
 }
 
 using namespace data_type;
 
-template struct _gemm_x8s8s32x_convolution_fwd_t<true, u8, f32>;
-template struct _gemm_x8s8s32x_convolution_fwd_t<true, u8, s32>;
-template struct _gemm_x8s8s32x_convolution_fwd_t<true, u8, s8>;
-template struct _gemm_x8s8s32x_convolution_fwd_t<true, u8, u8>;
-template struct _gemm_x8s8s32x_convolution_fwd_t<false, u8, f32>;
-template struct _gemm_x8s8s32x_convolution_fwd_t<false, u8, s32>;
-template struct _gemm_x8s8s32x_convolution_fwd_t<false, u8, s8>;
-template struct _gemm_x8s8s32x_convolution_fwd_t<false, u8, u8>;
-
-template struct _gemm_x8s8s32x_convolution_fwd_t<true, s8, f32>;
-template struct _gemm_x8s8s32x_convolution_fwd_t<true, s8, s32>;
-template struct _gemm_x8s8s32x_convolution_fwd_t<true, s8, s8>;
-template struct _gemm_x8s8s32x_convolution_fwd_t<true, s8, u8>;
-template struct _gemm_x8s8s32x_convolution_fwd_t<false, s8, f32>;
-template struct _gemm_x8s8s32x_convolution_fwd_t<false, s8, s32>;
-template struct _gemm_x8s8s32x_convolution_fwd_t<false, s8, s8>;
-template struct _gemm_x8s8s32x_convolution_fwd_t<false, s8, u8>;
+template struct _gemm_x8s8s32x_convolution_fwd_t<u8, f32>;
+template struct _gemm_x8s8s32x_convolution_fwd_t<u8, s32>;
+template struct _gemm_x8s8s32x_convolution_fwd_t<u8, s8>;
+template struct _gemm_x8s8s32x_convolution_fwd_t<u8, u8>;
+
+template struct _gemm_x8s8s32x_convolution_fwd_t<s8, f32>;
+template struct _gemm_x8s8s32x_convolution_fwd_t<s8, s32>;
+template struct _gemm_x8s8s32x_convolution_fwd_t<s8, s8>;
+template struct _gemm_x8s8s32x_convolution_fwd_t<s8, u8>;
 
 template struct _gemm_u8s8s32x_convolution_bwd_data_t<f32>;
 template struct _gemm_u8s8s32x_convolution_bwd_data_t<s32>;