/*******************************************************************************
-* Copyright 2018 Intel Corporation
+* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* limitations under the License.
*******************************************************************************/
+#include "jit_uni_i8i8_pooling.hpp"
+
#include <math.h>
#include "mkldnn_types.h"
#include "jit_generator.hpp"
-#include "jit_uni_i8i8_pooling.hpp"
namespace mkldnn {
namespace impl {
using namespace mkldnn::impl::types;
using namespace alg_kind;
-struct call_params_t {
- const char *src_i8;
- const char *dst_i8;
- size_t kw_range;
- size_t kh_range;
- float idivider;
-};
-
template <cpu_isa_t isa>
-struct jit_uni_i8i8_pool_fwd_ker_t : public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_i8i8_pool_fwd_ker_t)
-
+struct jit_uni_i8i8_pooling_fwd_ker_t: public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_i8i8_pooling_fwd_ker_t)
+
+ struct call_params_t {
+ const char *src_i8;
+ const char *dst_i8;
+ size_t kw_range;
+ size_t kh_range;
+ float idivider;
+ };
+
+ using Vmm = typename cpu_isa_traits<isa>::Vmm;
+ Xmm xreg(int idx) const { return Xmm(idx); }
+ Ymm yreg(int idx) const { return Ymm(xreg(idx).getIdx()); }
+ Vmm vreg(int idx) const { return Vmm(xreg(idx).getIdx()); }
+
+ // Rounding modes for avx2
+ enum:uint8_t { rnd_op_nearest = 0x0 };
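+ // imm8 == 0 for (v)roundps: rounding-select bit clear, RC field 00b,
+ // i.e. round to nearest (even) regardless of MXCSR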
+
+ // In the case of avx2 with i8 (s8/u8) data we need the maskmovdqu
+ // instruction, whose store destination is implicitly hardcoded to rdi.
+ // Windows ABI: abi_param1 is rcx - nothing else to do
+ // Unix ABI: abi_param1 is rdi - copy it to rcx and use rcx as abi_param1
+ Reg64 reg_param = rcx; // Our "unified abi_param1"
Reg64 reg_ptr_src_i8 = r8;
Reg64 reg_ptr_dst_i8 = r9;
+ Reg64 reg_ptr_maskmovdqu_dst = rdi; // store destination - must be rdi
Reg64 ki = r10;
Reg64 kj = r11;
Reg64 aux_reg_src_w = rbx;
Reg64 reg_tmp = rdx;
- Reg64 reg_src_64 = r15;
- Reg32 reg_src_32 = r15d;
- Reg8 reg_src_8 = r15b;
- size_t sizeof_src_dt() const { return data_type_size(jpp.src_dt); }
- size_t sizeof_dst_dt() const { return data_type_size(jpp.dst_dt); }
+ Reg64 reg_mask = r15;
- using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm,
- isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
+ Opmask k_cmp_mask = Opmask(7);
- Xmm xmm_tmp = Xmm(0);
- Vmm vreg_tmp = Vmm(14);
- Vmm vreg_zeros = Vmm(15);
-
- /* max pooling */
- Vmm vmm_src(int jj, int ii) {
- return Vmm(2*jj + ii);
+ Opmask mask(int idx) {
+ return Opmask(6 - idx);
}
- Xmm xmm_src(int jj) {
- return Xmm(2*jj);
- }
+ // ref to any of XYZ-regs via xreg/yreg/vreg functions
+ Xmm xmm_tmp = xreg(0); // temp to init vreg_tmp
+ Vmm vreg_tmp = vreg(0); // max pooling : holds minimum values for data_type
+ Vmm vreg_zeros = vreg(1);
- Vmm vmm_dst(int jj, int ii) {
- return Vmm(2*jj + ii + 2 * jpp.ur_c);
- }
+ // only in case of <isa> == avx2
+ Vmm vreg_mask = vreg(2); // full byte-mask
+ Xmm xreg_mask_lo = xreg(2); // low 128-bit part of byte-mask (alias for xmm part of vreg_mask)
+ Xmm xreg_mask_hi = xreg(3); // "max" - high 128-bit part of byte-mask (stored separately)
+ Xmm xreg_mask_q = xreg(3); // "avg" - 1/4 part of the mask for s8/u8 operations
+ Vmm vreg_mask_q = vreg(3); // "avg" - 1/4 part for non-zero tails
- Xmm xmm_dst(int jj) {
- return Xmm(2*jj + 2 * jpp.ur_c);
- }
+ enum:int {vidx_base = isa == avx2 ? 4 : 2};
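+ // avx2 keeps the byte-mask in vreg(2)/vreg(3), so working vregs start at 4;
+ // avx512 keeps masks in opmask registers, so working vregs start at 2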
+ Vmm base_vr(int idx) const { return vreg(vidx_base + idx); }
- /* avg pooling */
- Vmm vmm_src_s32(int jj, int ii) {
- return Vmm(2*jj + ii);
- }
-
- Xmm xmm_src_s32(int jj, int ii) {
- return Xmm(2*jj + ii);
- }
-
- Vmm vmm_dst_s32(int jj, int ii) {
- return Vmm(2*jj + ii + 2 * jpp.ur_c);
- }
-
- Ymm ymm_dst_s32(int jj, int ii) {
- return Ymm(2*jj + ii + 2 * jpp.ur_c);
- }
+ size_t sizeof_src_dt() const { return data_type_size(jpp.src_dt); }
+ size_t sizeof_dst_dt() const { return data_type_size(jpp.dst_dt); }
- Xmm xmm_dst_s32(int jj, int ii) {
- return Xmm(2*jj + ii + 2 * jpp.ur_c);
- }
+ /* max pooling */
+ Vmm vreg_src(int idx) const { return base_vr(idx); } // [0 .. ur_c-1]
+ Vmm vreg_dst(int idx) const { return base_vr(jpp.ur_c + idx); } // [ur_c .. 2*ur_c-1]
- Vmm vmm_dst_f32(int jj, int ii) {
- return Vmm(2*jj + ii + 4 * jpp.ur_c);
- }
+ /* avg pooling */
+ // s32 is used for processing s8/u8 data,
+ // so we need to account for the size ratio s32/i8 = 4
+ static constexpr data_type_t avg_proc_dt = data_type::s32;
+ enum:int {
+ s32_to_i8_ratio = sizeof(typename prec_traits<avg_proc_dt>::type)
+ / sizeof(typename prec_traits<data_type::u8>::type),
+ max_num_ll = s32_to_i8_ratio
+ };
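+ // E.g. one vreg of s8/u8 values expands into max_num_ll (= 4) vregs of s32
+ // values during "avg" processing; ll indexes these quarters below.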
+ Vmm vreg_src_s32(int jj, int ll) { return base_vr(3*max_num_ll*jj + ll + 0*max_num_ll); } // ll: 0..3 [0..3]
+ Vmm vreg_dst_s32(int jj, int ll) { return base_vr(3*max_num_ll*jj + ll + 1*max_num_ll); } // ll: 0..3 [4..7]
+ Vmm vreg_dst_f32(int jj, int ll) { return base_vr(3*max_num_ll*jj + ll + 2*max_num_ll); } // ll: 0..3 [8..11]
void (*ker_)(const call_params_t *);
jit_pool_conf_t jpp;
void init_tmp_reg();
+ void init_mask();
+
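+ // loads the ll-th quarter of the byte-mask (specialized for avx2; no-op otherwise)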
+ void load_vreg_mask_q(int ll) {};
+
+ void load_src_max_op(int jj, int ll, size_t offset, bool masked, uint64_t msk);
+ void load_src_avg_op(int jj, int ll, size_t offset, bool masked, uint64_t msk);
+ void load_src(int jj, int ll, int c_tail);
- void load_src(int jj, int c_step);
- void store_dst(int jj, int c_step);
+ void store_dst_max_op(int jj, int ll, size_t offset, bool masked, uint64_t msk);
+ void store_dst_avg_op(int jj, int ll, size_t offset, bool masked, uint64_t msk);
+ void store_dst(int jj, int ll, int c_tail);
- void compute_avg_step(int ur_c, int c_step);
- void compute_max_step(int ur_c, int c_step);
- void compute_step(int ur_c, int c_step);
+ void compute_avg_step(int ur_c, int c_tail);
+ void compute_max_op(const int jj);
+ void compute_max_step(int ur_c, int c_tail);
+ void compute_step(int ur_c, int c_tail);
void compute_c_block();
void generate();
const pooling_desc_t &pd, const memory_desc_wrapper &src_d,
const memory_desc_wrapper &dst_d);
- jit_uni_i8i8_pool_fwd_ker_t(const jit_pool_conf_t &jpp_)
+ jit_uni_i8i8_pooling_fwd_ker_t(const jit_pool_conf_t &jpp_)
: jpp(jpp_) {
generate();
ker_ = reinterpret_cast<decltype(ker_)>(const_cast<uint8_t*>(
}
};
+template <>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::load_vreg_mask_q(int ll) {
+
+ // extract ll-th part of mask (ll-th QWORD)
+ vpblendd(vreg_mask_q, vreg_zeros, vreg_mask, 0x3 << ll); // 0x3 - mask for 2 x DWORD
+
+ // Move mask from ll-th pos to 0-th pos
+ if (ll>0)
+ vpermq(vreg_mask_q, vreg_mask_q, ll);
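+ // (vpermq imm bits [1:0] select the source qword written to destination
+ // qword 0; the other destination qwords copy source qword 0, which is zero
+ // here when ll > 0)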
+};
+
+template <>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::load_src_max_op(int jj, int ll,
+ size_t offset, bool masked, uint64_t msk) {
+ using namespace data_type;
+
+ if (masked) {
+ if (jpp.src_dt == s32) {
+ vpblendd(vreg_src(jj), vreg_tmp, ptr[aux_reg_src_w + offset], static_cast<uint8_t>(msk));
+ } else {
+ vpblendvb(vreg_src(jj), vreg_tmp, ptr[aux_reg_src_w + offset], vreg_mask);
+ }
+ } else
+ vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset]);
+};
+
+template <>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::load_src_max_op(int jj, int ll,
+ size_t offset, bool masked, uint64_t msk) {
+ using namespace data_type;
+
+ if (masked) {
+ if (jpp.src_dt == s32)
+ vmovups(vreg_src(jj) | mask(0), ptr[aux_reg_src_w + offset]);
+ else
+ vmovdqu8(vreg_src(jj) | mask(0), ptr[aux_reg_src_w + offset]);
+ } else
+ vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset]);
+};
+
+template <>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::load_src_avg_op(int jj, int ll,
+ size_t offset, bool masked, uint64_t msk) {
+ using namespace data_type;
+
+ // Don't generate useless code
+ if (masked && !msk)
+ return;
+
+ auto load_i8 = [&](bool is_signed, const Vmm& vr_src) {
+
+ // Do we need to use the tail mask?
+ if (masked) {
+
+ // load ll-th part of mask into vreg_mask_q
+ load_vreg_mask_q(ll);
+
+ // Load by mask from mem into register vr_src
+ vpblendvb(vr_src, vreg_zeros, ptr[aux_reg_src_w + offset], vreg_mask_q);
+
+ // Conversion s8/u8 -> s32
+ if (is_signed)
+ vpmovsxbd(vr_src, vr_src);
+ else
+ vpmovzxbd(vr_src, vr_src);
+ } else {
+
+ // Load from mem into vr_src with conversion
+ if (is_signed)
+ vpmovsxbd(vr_src, ptr[aux_reg_src_w + offset]);
+ else
+ vpmovzxbd(vr_src, ptr[aux_reg_src_w + offset]);
+ }
+ };
+
+ switch (jpp.src_dt) {
+ case s32:
+ if (masked)
+ vpblendd(vreg_src_s32(jj, ll), vreg_zeros, ptr[aux_reg_src_w + offset],
+ static_cast<uint8_t>(msk));
+ else
+ vmovups(vreg_src_s32(jj, ll), ptr[aux_reg_src_w + offset]);
+ break;
+ case s8:
+ load_i8(true, vreg_src_s32(jj, ll));
+ break;
+ case u8:
+ load_i8(false, vreg_src_s32(jj, ll));
+ break;
+ default: assert(!"unsupported src data type");
+ }
+};
+
+template <>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::load_src_avg_op(int jj, int ll,
+ size_t offset, bool masked, uint64_t msk) {
+ using namespace data_type;
+
+ // Don't generate useless code
+ if (masked && !msk)
+ return;
+
+ const Vmm& vr_src = masked ?
+ vreg_src_s32(jj, ll) | mask(ll) :
+ vreg_src_s32(jj, ll);
+
+ switch (jpp.src_dt) {
+ case s32:
+ vmovups(vr_src, ptr[aux_reg_src_w + offset]);
+ break;
+ case s8:
+ vpmovsxbd(vr_src, ptr[aux_reg_src_w + offset]);
+ break;
+ case u8:
+ vpmovzxbd(vr_src, ptr[aux_reg_src_w + offset]);
+ break;
+ default: assert(!"unsupported src data type");
+ }
+};
+
template <cpu_isa_t isa>
-void jit_uni_i8i8_pool_fwd_ker_t<isa>::load_src(int jj, int c_step) {
+void jit_uni_i8i8_pooling_fwd_ker_t<isa>::load_src(int jj, int ll, int c_tail) {
using namespace data_type;
- int repeats = isa == sse42 && c_step != 1 ? 2 : 1;
+ int c_block = jpp.c_block;
+ int ur_c = jpp.ur_c;
+
switch (jpp.alg) {
case pooling_max: {
- auto offset = jj*c_step*sizeof_src_dt();
- if (c_step == jpp.c_block) {
- for (int ii = 0; ii < repeats; ii++)
- uni_vmovups(vmm_src(jj, ii), ptr[aux_reg_src_w + offset + (jpp.c_block / 2) * ii * sizeof_src_dt()]);
- } else if (c_step == 1) {
- if (jpp.src_dt == s32) {
- movsd(xmm_src(jj), ptr[aux_reg_src_w + offset]);
- } else {
- mov(reg_src_8, ptr[aux_reg_src_w + offset]);
- movq(xmm_src(jj), reg_src_64);
- }
- }
+ auto offset = jj*c_block*sizeof_src_dt();
+ bool masked = jj == ur_c - 1 && c_tail;
+ load_src_max_op(jj, ll, offset, masked, jpp.tail[0]);
break;
}
case pooling_avg_include_padding:
case pooling_avg_exclude_padding: {
- auto offset = jj*c_step*sizeof_src_dt();
- switch (jpp.src_dt) {
- case s32:
- if (c_step == jpp.c_block) {
- for (int ii = 0; ii < repeats; ii++)
- uni_vmovups(vmm_src_s32(jj, ii), ptr[aux_reg_src_w + offset + (jpp.c_block / 2) * ii * sizeof_src_dt()]);
- } else if (c_step == 1) {
- movsd(xmm_src_s32(jj, 0), ptr[aux_reg_src_w + offset]);
- }
- break;
- case s8:
- if (c_step == jpp.c_block) {
- for (int ii = 0; ii < repeats; ii++) {
- if (isa == sse42)
- movd(xmm_src_s32(jj, ii), ptr[aux_reg_src_w + offset + (jpp.c_block / 2) * ii * sizeof_src_dt()]);
- else
- movq(xmm_src_s32(jj, ii), ptr[aux_reg_src_w + offset + (jpp.c_block / 2) * ii * sizeof_src_dt()]);
-
- uni_vpmovsxbd(vmm_src_s32(jj, ii), xmm_src_s32(jj, ii));
- }
- } else if (c_step == 1) {
- movsx(reg_src_32, ptr[aux_reg_src_w + offset]);
- movq(xmm_src_s32(jj, 0), reg_src_64);
- }
- break;
- case u8:
- if (c_step == jpp.c_block) {
- for (int ii = 0; ii < repeats; ii++) {
- if (isa == sse42)
- movd(xmm_src_s32(jj, ii), ptr[aux_reg_src_w + offset + (jpp.c_block / 2) * ii * sizeof_src_dt()]);
- else
- movq(xmm_src_s32(jj, ii), ptr[aux_reg_src_w + offset + (jpp.c_block / 2) * ii * sizeof_src_dt()]);
-
- uni_vpmovzxbd(vmm_src_s32(jj, ii), xmm_src_s32(jj, ii));
- }
- } else if (c_step == 1) {
- movzx(reg_src_32, ptr[aux_reg_src_w + offset]);
- movq(xmm_src_s32(jj, 0), reg_src_64);
- }
- break;
- default: assert(!"unsupported src data type");
- }
+ auto offset = (ll*(c_block/max_num_ll) + jj*c_block)*sizeof_src_dt();
+ bool masked = jj == ur_c - 1 && c_tail;
+ load_src_avg_op(jj, ll, offset, masked, jpp.tail[ll]);
break;
}
default: assert(!"unsupported algorithm");
}
}
+template <>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::store_dst_max_op(int jj, int ll,
+ size_t offset, bool masked, uint64_t msk) {
+ using namespace data_type;
+
+ int c_block = jpp.c_block;
+
+ if (masked) {
+ switch (jpp.src_dt) {
+ case s32:
+ vpmaskmovd(ptr[reg_ptr_dst_i8 + offset], vreg_mask, vreg_dst(jj));
+ break;
+ case s8:
+ case u8: {
+ // Store low half by mask (bytes 0...15)
+ lea(reg_ptr_maskmovdqu_dst, ptr[reg_ptr_dst_i8 + offset]);
+ maskmovdqu(vreg_dst(jj), xreg_mask_lo);
+
+ // Do we need to store high half (bytes 16...31) ?
+ const uint64_t low_mask = (1ULL << (c_block/2))-1;
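+ // (c_block == 32 for s8/u8 on avx2, so low_mask covers bytes 0..15; any
+ // tail bit above it means the upper xmm half must be stored as well)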
+ if (msk & ~low_mask) {
+ vextracti128(Xmm(vreg_dst(jj).getIdx()), vreg_dst(jj), 1);
+ add(reg_ptr_maskmovdqu_dst, c_block / 2);
+ maskmovdqu(vreg_dst(jj), xreg_mask_hi);
+ }
+ } break;
+ default: assert(!"unsupported src data type");
+ }
+ } else
+ vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj));
+}
+
+template <>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::store_dst_max_op(int jj, int ll,
+ size_t offset, bool masked, uint64_t msk) {
+ using namespace data_type;
+
+ if (masked) {
+ switch (jpp.src_dt) {
+ case s32:
+ vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj) | mask(0));
+ break;
+ case s8:
+ case u8:
+ vmovdqu8(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj) | mask(0));
+ break;
+ default: assert(!"unsupported src data type");
+ }
+ } else
+ vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj));
+}
+
+template <>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::store_dst_avg_op(int jj, int ll,
+ size_t offset, bool masked, uint64_t msk){
+ using namespace data_type;
+
+ // Don't generate useless code
+ if (masked && !msk)
+ return;
+
+ auto s32_to_i8 = [&](bool is_signed, const Vmm& vr_dst) {
+
+ // conversion: s32 -> s16/u16 : {8 x s32}{8 x 0} -> {16 x s16/u16}
+ // Result QWORDs (qw0, qw1) permuted: {qw0, 0, qw1, 0}
+ if (is_signed)
+ vpackssdw(vr_dst, vr_dst, vreg_zeros);
+ else
+ vpackusdw(vr_dst, vr_dst, vreg_zeros);
+
+ // Permute qwords to restore original order
+ // {qw0, 0, qw1, 0} -> {qw0, qw1, 0, 0}
+ vpermq(vr_dst, vr_dst, 0x58);
+
+ // conversion: s16/u16 -> s8/u8 : {16 x s16/u16}{16 x 0} -> {32 x s8/u8}
+ // Target QWORD qw = {8 x s8/u8} has proper position: {qw, xx, xx, xx}
+ if (is_signed)
+ vpacksswb(vr_dst, vr_dst, vreg_zeros);
+ else
+ vpackuswb(vr_dst, vr_dst, vreg_zeros);
+
+ };
+
+ auto store_i8 = [&](bool is_signed, bool is_masked, const Vmm& vr_dst) {
+
+ // Conversion s32 -> s8/u8
+ s32_to_i8(is_signed, vr_dst);
+
+ // Do we need to use the tail mask?
+ if (is_masked) {
+ // load ll-th part of mask into vreg_mask_q
+ load_vreg_mask_q(ll);
+ }
+
+ // store 8 bytes
+ lea(reg_ptr_maskmovdqu_dst, ptr[reg_ptr_dst_i8 + offset]);
+ maskmovdqu(vr_dst, xreg_mask_q);
+ };
+
+ switch (jpp.dst_dt) {
+ case s32:
+ if (masked) {
+ vpmaskmovd(ptr[reg_ptr_dst_i8 + offset], vreg_mask, vreg_dst_s32(jj, ll));
+ } else
+ vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst_s32(jj, ll));
+ break;
+ case s8:
+ store_i8(true, masked, vreg_dst_s32(jj, ll));
+ break;
+ case u8:
+ store_i8(false, masked, vreg_dst_s32(jj, ll));
+ break;
+ default: assert(!"unsuppotred dst data_type");
+ }
+}
+
+template <>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::store_dst_avg_op(int jj, int ll,
+ size_t offset, bool masked, uint64_t msk) {
+ using namespace data_type;
+
+ // Don't generate useless code
+ if (masked && !msk)
+ return;
+
+ const Vmm& vr_dst = masked ?
+ vreg_dst_s32(jj, ll) | mask(ll) :
+ vreg_dst_s32(jj, ll);
+
+ switch (jpp.dst_dt) {
+ case s32:
+ vmovups(ptr[reg_ptr_dst_i8 + offset], vr_dst);
+ break;
+ case s8:
+ vpmovdb(ptr[reg_ptr_dst_i8 + offset], vr_dst);
+ break;
+ case u8:
+ vpmovusdb(ptr[reg_ptr_dst_i8 + offset], vr_dst);
+ break;
+ default: assert(!"unsupported dst data_type");
+ }
+}
+
+
template <cpu_isa_t isa>
-void jit_uni_i8i8_pool_fwd_ker_t<isa>::store_dst(int jj, int c_step) {
+void jit_uni_i8i8_pooling_fwd_ker_t<isa>::store_dst(int jj, int ll,
+ int c_tail) {
using namespace data_type;
- int repeats = isa == sse42 && c_step != 1 ? 2 : 1;
+ int c_block = jpp.c_block;
+ int ur_c = jpp.ur_c;
+
switch(jpp.alg) {
case pooling_max: {
- auto offset = jj*c_step*sizeof_dst_dt();
- if (c_step == jpp.c_block) {
- for (int ii = 0; ii < repeats; ii++)
- uni_vmovups(ptr[reg_ptr_dst_i8 + offset + (jpp.c_block / 2) * ii * sizeof_dst_dt()], vmm_dst(jj, ii));
- } else if (c_step == 1) {
- if (jpp.src_dt == s32) {
- movq(reg_src_64, xmm_dst(jj));
- mov(ptr[reg_ptr_dst_i8 + offset], reg_src_32);
- } else {
- movq(reg_src_64, xmm_dst(jj));
- mov(ptr[reg_ptr_dst_i8 + offset], reg_src_8);
- }
- }
+ auto offset = jj*c_block*sizeof_dst_dt();
+ bool masked = jj == ur_c - 1 && c_tail;
+ store_dst_max_op(jj, ll, offset, masked, jpp.tail[ll]);
break;
}
case pooling_avg_include_padding:
case pooling_avg_exclude_padding: {
- auto offset = jj*c_step*sizeof_dst_dt();
- switch (jpp.dst_dt) {
- case s32:
- if (c_step == jpp.c_block) {
- for (int ii = 0; ii < repeats; ii++)
- uni_vmovups(ptr[reg_ptr_dst_i8 + offset + (jpp.c_block / 2) * ii * sizeof_dst_dt()], vmm_dst_s32(jj, ii));
- } else if (c_step == 1) {
- movq(reg_src_64, xmm_dst_s32(jj, 0));
- mov(ptr[reg_ptr_dst_i8 + offset], reg_src_32);
- }
- break;
- case s8:
- if (c_step == jpp.c_block) {
- for (int ii = 0; ii < repeats; ii++) {
- uni_vpackssdw(vmm_dst_s32(jj, ii), vmm_dst_s32(jj, ii), vmm_dst_s32(jj, ii));
-
- if (isa != sse42)
- vpermq(ymm_dst_s32(jj, ii), ymm_dst_s32(jj, ii), 0x08);
-
- uni_vpacksswb(xmm_dst_s32(jj, ii), xmm_dst_s32(jj, ii), xmm_dst_s32(jj, ii));
-
- if (isa != sse42)
- movq(ptr[reg_ptr_dst_i8 + offset + (jpp.c_block / 2) * ii * sizeof_dst_dt()], xmm_dst_s32(jj, ii));
- else
- movd(ptr[reg_ptr_dst_i8 + offset + (jpp.c_block / 2) * ii * sizeof_dst_dt()], xmm_dst_s32(jj, ii));
- }
- } else if (c_step == 1) {
- vpackssdw(vmm_dst_s32(jj, 0), vmm_dst_s32(jj, 0), vmm_dst_s32(jj, 0));
- vpacksswb(xmm_dst_s32(jj, 0), xmm_dst_s32(jj, 0), xmm_dst_s32(jj, 0));
- movq(reg_src_64, xmm_dst_s32(jj, 0));
- mov(ptr[reg_ptr_dst_i8 + offset], reg_src_8);
- }
- break;
- case u8:
- if (c_step == jpp.c_block) {
- for (int ii = 0; ii < repeats; ii++) {
- uni_vpackusdw(vmm_dst_s32(jj, ii), vmm_dst_s32(jj, ii), vmm_dst_s32(jj, ii));
-
- if (isa != sse42)
- vpermq(ymm_dst_s32(jj, ii), ymm_dst_s32(jj, ii), 0x08);
-
- uni_vpackuswb(xmm_dst_s32(jj, ii), xmm_dst_s32(jj, ii), xmm_dst_s32(jj, ii));
-
- if (isa != sse42)
- movq(ptr[reg_ptr_dst_i8 + offset + (jpp.c_block / 2) * ii * sizeof_dst_dt()], xmm_dst_s32(jj, ii));
- else
- movd(ptr[reg_ptr_dst_i8 + offset + (jpp.c_block / 2) * ii * sizeof_dst_dt()], xmm_dst_s32(jj, ii));
- }
- } else if (c_step == 1) {
- vpackusdw(vmm_dst_s32(jj, 0), vmm_dst_s32(jj, 0), vmm_dst_s32(jj, 0));
- vpackuswb(xmm_dst_s32(jj, 0), xmm_dst_s32(jj, 0), xmm_dst_s32(jj, 0));
- movq(reg_src_64, xmm_dst_s32(jj, 0));
- mov(ptr[reg_ptr_dst_i8 + offset], reg_src_8);
- }
- break;
- default: assert(!"unsuppotred dst data_type");
- }
+ auto offset = (ll*(c_block/max_num_ll) + jj*c_block)*sizeof_dst_dt();
+ bool masked = jj == ur_c - 1 && c_tail;
+ store_dst_avg_op(jj, ll, offset, masked, jpp.tail[ll]);
break;
}
default: assert(!"unsupported pooling algorithm");
}
}
+template <>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::compute_max_op(const int jj)
+{
+ using namespace data_type;
+ switch (jpp.src_dt) {
+ case s32:
+ vpmaxsd(vreg_dst(jj), vreg_dst(jj), vreg_src(jj));
+ break;
+ case s8:
+ vpmaxsb(vreg_dst(jj), vreg_dst(jj), vreg_src(jj));
+ break;
+ case u8:
+ vpmaxub(vreg_dst(jj), vreg_dst(jj), vreg_src(jj));
+ break;
+ default: assert(!"unsupported src data type");
+ }
+}
+
+template <>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::compute_max_op(const int jj)
+{
+ using namespace data_type;
+
+ // Compare
+ switch (jpp.src_dt) {
+ case s32:
+ vpcmpd(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os);
+ break;
+ case s8:
+ vpcmpb(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os);
+ break;
+ case u8:
+ vpcmpub(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os);
+ break;
+ default: assert(!"unsupported src data type");
+ }
+
+ // move max values into vreg_dst
+ if (jpp.src_dt == s32)
+ vpblendmd(vreg_dst(jj) | k_cmp_mask, vreg_dst(jj), vreg_src(jj));
+ else
+ vpblendmb(vreg_dst(jj) | k_cmp_mask, vreg_dst(jj), vreg_src(jj));
+}
+
+
template <cpu_isa_t isa>
-void jit_uni_i8i8_pool_fwd_ker_t<isa>::compute_max_step(int ur_c, int c_step)
+void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_max_step(int ur_c, int c_tail)
{
Label l_kw, l_kh;
int iw = jpp.iw;
int c = jpp.c;
- int repeats = isa == sse42 && c_step != 1 ? 2 : 1;
-
- for (int jj = 0; jj < ur_c; jj++) {
- for (int ii = 0; ii < repeats; ii++) {
- uni_vmovups(vmm_dst(jj, ii), vreg_tmp);
- }
- }
+ for (int jj = 0; jj < ur_c; jj++)
+ vmovups(vreg_dst(jj), vreg_tmp);
mov(aux_reg_src_h, reg_ptr_src_i8);
L(l_kw);
{
for (int jj = 0; jj < ur_c; jj++) {
- load_src(jj, c_step);
-
- for (int ii = 0; ii < repeats; ii++) {
- if (jpp.src_dt == data_type::s32) {
- uni_vpmaxsd(vmm_dst(jj, ii), vmm_dst(jj, ii), vmm_src(jj, ii));
- } else {
- if (jpp.src_dt == data_type::s8)
- uni_vpmaxsb(vmm_dst(jj, ii), vmm_dst(jj, ii), vmm_src(jj, ii));
- else
- uni_vpmaxub(vmm_dst(jj, ii), vmm_dst(jj, ii), vmm_src(jj, ii));
- }
- }
+ load_src(jj, 0, c_tail);
+ compute_max_op(jj);
}
add(aux_reg_src_w, c * sizeof_src_dt());
inc(ki);
}
for (int jj = 0; jj < ur_c; jj++)
- store_dst(jj, c_step);
+ store_dst(jj, 0, c_tail);
}
template <cpu_isa_t isa>
-void jit_uni_i8i8_pool_fwd_ker_t<isa>::compute_avg_step(int ur_c, int c_step)
+void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_avg_step(int ur_c, int c_tail)
{
using namespace data_type;
int iw = jpp.iw;
int c = jpp.c;
- int repeats = isa == sse42 && c_step != 1 ? 2 : 1;
+ const int num_ll = data_type_size(avg_proc_dt)/data_type_size(jpp.src_dt);
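+ // num_ll == 1 for s32 sources and 4 for s8/u8 sources (expanded to s32)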
for (int jj = 0; jj < ur_c; jj++) {
- for (int ii = 0; ii < repeats; ii++) {
- uni_vpxor(vmm_src_s32(jj, ii), vmm_src_s32(jj, ii), vmm_src_s32(jj, ii));
- uni_vpxor(vmm_dst_s32(jj, ii), vmm_dst_s32(jj, ii), vmm_dst_s32(jj, ii));
+ for (int ll = 0; ll < num_ll; ll++) {
+ bool masked = jj == ur_c - 1 && c_tail;
+ size_t msk = jpp.tail[ll];
+ if (!(masked && !msk)) {
+ uni_vpxor(vreg_src_s32(jj, ll), vreg_src_s32(jj, ll), vreg_src_s32(jj, ll));
+ uni_vpxor(vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll));
+ }
}
}
L(l_kw);
{
for (int jj = 0; jj < ur_c; jj++) {
- load_src(jj, c_step);
-
- for (int ii = 0; ii < repeats; ii++) {
- uni_vpaddd(vmm_dst_s32(jj, ii), vmm_dst_s32(jj, ii), vmm_src_s32(jj, ii));
+ for (int ll = 0; ll < num_ll; ll++) {
+ bool masked = jj == ur_c - 1 && c_tail;
+ size_t msk = jpp.tail[ll];
+ if (!(masked && !msk)) {
+ load_src(jj, ll, c_tail);
+ vpaddd(vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll),
+ vreg_src_s32(jj, ll));
+ }
}
}
add(aux_reg_src_w, c * sizeof_src_dt());
}
for (int jj = 0; jj < ur_c; jj++) {
- for (int ii = 0; ii < repeats; ii++) {
- uni_vcvtdq2ps(vmm_dst_f32(jj, ii), vmm_dst_s32(jj, ii));
+ for (int ll = 0; ll < num_ll; ll++) {
+ bool masked = jj == ur_c - 1 && c_tail;
+ size_t msk = jpp.tail[ll];
+ if (!(masked && !msk)) {
- if (isa == sse42)
- mulps(vmm_dst_f32(jj, ii), vreg_tmp);
- else
- vfmadd132ps(vmm_dst_f32(jj, ii), vreg_zeros, vreg_tmp);
+ vcvtdq2ps(vreg_dst_f32(jj, ll), vreg_dst_s32(jj, ll));
+ vfmadd132ps(vreg_dst_f32(jj, ll), vreg_zeros, vreg_tmp);
- uni_vcvtps2dq(vmm_dst_s32(jj, ii), vmm_dst_f32(jj, ii));
- }
+ if (isa == avx2) {
+ uni_vroundps(vreg_dst_f32(jj, ll), vreg_dst_f32(jj, ll), rnd_op_nearest);
+ vcvtps2dq(vreg_dst_s32(jj, ll), vreg_dst_f32(jj, ll));
+ } else if (isa >= avx512_common) {
+ // AVX512: EVEX-embedded static rounding (round-to-nearest, SAE) via T_rn_sae
+ vcvtps2dq(vreg_dst_s32(jj, ll) | T_rn_sae, vreg_dst_f32(jj, ll));
+ }
- store_dst(jj, c_step);
+ store_dst(jj, ll, c_tail);
+ }
+ }
}
}
template <cpu_isa_t isa>
-void jit_uni_i8i8_pool_fwd_ker_t<isa>::compute_step(int ur_c, int c_step) {
+void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_step(int ur_c, int c_tail) {
switch (jpp.alg) {
case pooling_max:
- compute_max_step(ur_c, c_step); break;
+ compute_max_step(ur_c, c_tail); break;
case pooling_avg_include_padding:
case pooling_avg_exclude_padding:
- compute_avg_step(ur_c, c_step); break;
+ compute_avg_step(ur_c, c_tail); break;
default: assert(!"unsupported pooling algorithm");
}
}
template <cpu_isa_t isa>
-void jit_uni_i8i8_pool_fwd_ker_t<isa>::compute_c_block() {
+void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_c_block(){
Label l_main_loop;
- Label l_tail_loop;
- Label exit;
+ int nb_c = jpp.nb_c;
+ int c_block = jpp.c_block;
int ur_c = jpp.ur_c;
+ int ur_c_tail = jpp.ur_c_tail;
+ int c_steps = nb_c / ur_c;
+ int c_tail = jpp.c_tail;
xor_(c_iter, c_iter);
+ if (c_steps > 0) {
+ L(l_main_loop); {
+ compute_step(ur_c, 0);
+ add(reg_ptr_src_i8, ur_c*c_block*sizeof_src_dt());
+ add(reg_ptr_dst_i8, ur_c*c_block*sizeof_dst_dt());
+ inc(c_iter);
+ cmp(c_iter, c_steps);
+ jl(l_main_loop, T_NEAR);
+ }
+ }
- L(l_main_loop);
- {
- cmp(c_iter, jpp.c - ur_c * jpp.c_block);
- jg(l_tail_loop, T_NEAR);
+ if (ur_c_tail != 0) {
+ compute_step(ur_c_tail, c_tail);
+ }
+}
- compute_step(ur_c, jpp.c_block);
+template<>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::init_mask() {
+ using namespace data_type;
+ using cpu_isa = cpu_isa_traits<avx2>;
+
+ // AVX2 mask initialization: mask stored in Ymm-regs
+ auto init = [&](uint64_t bit_mask, bool init_mask_q) {
+ const size_t QW_PER_VREG = cpu_isa::vlen / sizeof(uint64_t);
+
+ uint64_t vmask[QW_PER_VREG];
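+ // Build per-element masks: each selected element gets its most significant
+ // bit set (vpblendvb/maskmovdqu test the MSB of every byte; vpmaskmovd
+ // tests the MSB of every dword).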
+ for (size_t i = 0; i < QW_PER_VREG; i++){
+
+ uint64_t qw_vmask=0ULL;
+ const size_t DBITS = 8*sizeof_src_dt();
+ const uint64_t VMSK = 1ULL << (DBITS-1);
+ const size_t D_PER_QW = (8*sizeof(qw_vmask))/DBITS;
+ for (size_t j = 0; j < D_PER_QW; j++) {
+ if (bit_mask & 1)
+ qw_vmask |= VMSK << DBITS * j;
+ bit_mask >>= 1;
+ }
+ vmask[i] = qw_vmask;
+ }
- add(reg_ptr_src_i8, ur_c * jpp.c_block * sizeof_src_dt());
- add(reg_ptr_dst_i8, ur_c * jpp.c_block * sizeof_dst_dt());
- add(c_iter, ur_c * jpp.c_block);
- jmp(l_main_loop);
- }
+ // Put QWORDS with target mask into xmm regs
+ const int xdst_i[QW_PER_VREG] = {
+ xreg_mask_lo.getIdx(),
+ xreg_mask_lo.getIdx(),
+ xreg_mask_hi.getIdx(),
+ xreg_mask_hi.getIdx()
+ };
+ const int xsrc_i[QW_PER_VREG] = {
+ vreg_zeros.getIdx(), // insert 0-th qword into zeros -> {qw0, 0}
+ xreg_mask_lo.getIdx(), // merge 1-st qword with 0-th -> {qw0, qw1}
+ vreg_zeros.getIdx(),
+ xreg_mask_hi.getIdx()
+ };
+ const uint8 qw_dst_idx[QW_PER_VREG] = {0, 1, 0, 1}; // qword index in 128-bit xreg
+
+ for (size_t i = 0; i < QW_PER_VREG; i++) {
+ mov(reg_mask, vmask[i]);
+ vpinsrq(Xmm(xdst_i[i]), Xmm(xsrc_i[i]), reg_mask, qw_dst_idx[i]);
+ }
- L(l_tail_loop);
- {
- cmp(c_iter, jpp.c - ur_c);
- jg(exit, T_NEAR);
+ // Merge Low (xreg_mask_lo alias for vreg_mask.xreg)
+ // and High (xreg_mask_hi) into full vreg_mask
+ // vreg_mask -> {xreg_mask_hi, vreg_mask.xreg}
+ vinserti128(vreg_mask, vreg_mask, xreg_mask_hi, 1);
- compute_step(ur_c, 1);
+ // Keep only low qword of mask in xreg_mask_q
+ if (init_mask_q) {
+ mov(reg_mask, vmask[0]);
+ vpinsrq(xreg_mask_q, Xmm(vreg_zeros.getIdx()), reg_mask, 0);
+ }
+ };
- add(reg_ptr_src_i8, ur_c * sizeof_src_dt());
- add(reg_ptr_dst_i8, ur_c * sizeof_dst_dt());
- add(c_iter, ur_c);
- jmp(l_tail_loop);
+ uint64_t tail_mask = (1ULL << jpp.c_tail) - 1;
+ switch (jpp.alg) {
+ case pooling_max:
+ // For "max" we need mask only in case of non-zero tail
+ if (tail_mask)
+ init(tail_mask, false);
+ break;
+ case pooling_avg_include_padding:
+ case pooling_avg_exclude_padding:
+ // For "avg" we need mask:
+ // - s32 - in case of the non-zero tail
+ // - s8/u8 - irrespective of the tail
+ switch (jpp.src_dt) {
+ case s32:
+ if (tail_mask)
+ init(tail_mask, false);
+ break;
+ case s8:
+ case u8:
+ init(tail_mask ? tail_mask : ~0ULL, tail_mask == 0);
+ break;
+ default: assert(!"unsupported src data type");
+ }
+ break;
+ default: assert(!"unsupported pooling algorithm");
}
+}
+
+template<>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::init_mask() {
- L(exit);
+ for (int ll = 0; ll < max_num_ll; ll++) {
+ mov(reg_mask, jpp.tail[ll]);
+ kmovq(mask(ll), reg_mask);
+ }
}
template <cpu_isa_t isa>
-void jit_uni_i8i8_pool_fwd_ker_t<isa>::init_tmp_reg() {
+void jit_uni_i8i8_pooling_fwd_ker_t<isa>::init_tmp_reg() {
using namespace data_type;
switch (jpp.alg) {
case pooling_avg_include_padding:
case pooling_avg_exclude_padding:
- mov(reg_tmp, ptr[abi_param1 + offsetof(call_params_t, idivider)]);
+ mov(reg_tmp, ptr[reg_param + offsetof(call_params_t, idivider)]);
movq(xmm_tmp, reg_tmp);
- uni_vpbroadcastd(vreg_tmp, xmm_tmp);
+ vpbroadcastd(vreg_tmp, xmm_tmp);
break;
case pooling_max:
switch (jpp.src_dt) {
}
movq(xmm_tmp, reg_tmp);
- if (jpp.src_dt == s32) {
- uni_vpbroadcastd(vreg_tmp, xmm_tmp);
- } else {
- if (isa == avx2) {
- vpbroadcastb(vreg_tmp, xmm_tmp);
- } else {
- movups(vreg_tmp, xmm_tmp);
- uni_vpxor(xmm_tmp, xmm_tmp, xmm_tmp);
- pshufb(vreg_tmp, xmm_tmp);
- }
- }
+ if (jpp.src_dt == s32)
+ vpbroadcastd(vreg_tmp, xmm_tmp);
+ else
+ vpbroadcastb(vreg_tmp, xmm_tmp);
break;
default: assert(!"unsupported pooling algorithm");
}
}
template <cpu_isa_t isa>
-void jit_uni_i8i8_pool_fwd_ker_t<isa>::generate() {
+void jit_uni_i8i8_pooling_fwd_ker_t<isa>::generate() {
preamble();
+#if !defined(_WIN32)
+ // Always use rcx as abi_param1 -
+ // see the note about maskmovdqu near reg_param.
+ mov(rcx, rdi);
+#endif
+
# define READ_PARAM(reg, field) \
- mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
+ mov(reg, ptr[reg_param + offsetof(call_params_t, field)])
READ_PARAM(reg_ptr_src_i8, src_i8);
READ_PARAM(reg_ptr_dst_i8, dst_i8);
READ_PARAM(reg_kw, kw_range);
# undef READ_PARAM
- init_tmp_reg();
-
uni_vpxor(vreg_zeros, vreg_zeros, vreg_zeros);
+ init_mask();
+
+ init_tmp_reg();
+
compute_c_block();
postamble();
}
template <cpu_isa_t isa>
-status_t jit_uni_i8i8_pool_fwd_ker_t<isa>::init_conf(jit_pool_conf_t &jpp,
+status_t jit_uni_i8i8_pooling_fwd_ker_t<isa>::init_conf(jit_pool_conf_t &jpp,
const pooling_desc_t &pd, const memory_desc_wrapper &src_d,
const memory_desc_wrapper &dst_d) {
- if (!mayiuse(isa)) {
+ if (!mayiuse(isa))
return status::unimplemented;
- }
jpp.mb = src_d.dims()[0];
jpp.c = src_d.dims()[1];
jpp.src_dt = pd.src_desc.data_type;
jpp.dst_dt = pd.dst_desc.data_type;
- jpp.c_block = jpp.alg == pooling_max ? 32 / (jpp.src_dt == data_type::s32 ? 4 : 1) : 8;
+ // data_type items per one vreg on the <isa>
+ // isa == avx2 : 32 bytes -> 32 for s8/u8, 8 for s32
+ // isa == avx512* : 64 bytes -> 64 for s8/u8, 16 for s32
+ int simd_w = cpu_isa_traits<isa>::vlen / data_type_size(jpp.src_dt);
+
+ jpp.c_block = simd_w;
jpp.c_tail = jpp.c % jpp.c_block;
jpp.nb_c = jpp.c / jpp.c_block;
jpp.ur_c = 1;
- jpp.ur_c_tail = jpp.nb_c - (jpp.nb_c / jpp.ur_c)*jpp.ur_c + (jpp.c_tail != 0);
+ jpp.ur_c_tail = jpp.nb_c - (jpp.nb_c / jpp.ur_c)*jpp.ur_c +
+ (jpp.c_tail != 0);
+
+ size_t tail_mask = (1ULL << jpp.c_tail) - 1;
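+ // Example (illustrative): avx512_core, u8 src, c = 100 -> c_block = 64,
+ // c_tail = 36, tail_mask = 0xFFFFFFFFF; for "avg" below (msk_gran = 16)
+ // this splits into tail[0..3] = {0xFFFF, 0xFFFF, 0xF, 0x0}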
+
+ switch (jpp.alg) {
+ case pooling_max:
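+ // "max" keeps data in the source type (one vreg per block),
+ // so a single tail mask is sufficient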
+ jpp.tail[0] = tail_mask;
+ jpp.tail[1] = 0;
+ jpp.tail[2] = 0;
+ jpp.tail[3] = 0;
+ break;
+ case pooling_avg_include_padding:
+ case pooling_avg_exclude_padding: {
+ // avg_proc_dt (s32) defines granularity (because u8/s8 processed as s32)
+ // avx2 : 8, avx512 : 16
+ const size_t msk_gran = cpu_isa_traits<isa>::vlen / data_type_size(avg_proc_dt);
+ const size_t msk_msk = (1ULL << msk_gran) - 1;
+ size_t m = tail_mask;
+ for (size_t ll = 0; ll < max_num_ll; ll++) {
+ jpp.tail[ll] = m & msk_msk;
+ m = m >> msk_gran;
+ }
+ break;
+ }
+ default: return status::unimplemented;
+ }
return status::success;
}
template <cpu_isa_t isa>
status_t jit_uni_i8i8_pooling_fwd_t<isa>::pd_t::jit_conf() {
- return jit_uni_i8i8_pool_fwd_ker_t<isa>::init_conf(jpp_,
+ return jit_uni_i8i8_pooling_fwd_ker_t<isa>::init_conf(jpp_,
desc_, src_pd_.desc(), dst_pd_.desc());
}
template <cpu_isa_t isa>
-jit_uni_i8i8_pooling_fwd_t<isa>::jit_uni_i8i8_pooling_fwd_t(const pd_t *pd,
+jit_uni_i8i8_pooling_fwd_t<isa>::
+jit_uni_i8i8_pooling_fwd_t(const pd_t *apd,
const input_vector &inputs, const output_vector &outputs)
- : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd), ker_(nullptr) {
- ker_ = new jit_uni_i8i8_pool_fwd_ker_t<isa>(conf_.jpp_);
-}
+ : cpu_primitive_t(apd, inputs, outputs), ker_(nullptr)
+{ ker_ = new jit_uni_i8i8_pooling_fwd_ker_t<isa>(pd()->jpp_); }
template <cpu_isa_t isa>
-jit_uni_i8i8_pooling_fwd_t<isa>::~jit_uni_i8i8_pooling_fwd_t() {
- delete ker_;
-}
+jit_uni_i8i8_pooling_fwd_t<isa>::
+~jit_uni_i8i8_pooling_fwd_t() { delete ker_; }
template <cpu_isa_t isa>
-void jit_uni_i8i8_pooling_fwd_t<isa>::execute_forward() {
+void jit_uni_i8i8_pooling_fwd_t<isa>::execute_forward() const {
auto src_i8 = reinterpret_cast<const char *>(input_memory(0));
auto dst_i8 = reinterpret_cast<char *>(memory());
- const memory_desc_wrapper src_d(conf_.src_pd());
- const memory_desc_wrapper dst_d(conf_.dst_pd());
+ const memory_desc_wrapper src_d(pd()->src_pd());
+ const memory_desc_wrapper dst_d(pd()->dst_pd());
- const auto &jpp = conf_.jpp_;
+ const auto &jpp = pd()->jpp_;
parallel_nd(jpp.mb, jpp.oh, jpp.ow,
- [&](int n, int oh, int ow) {
- const int ih = nstl::max(oh * jpp.stride_h - jpp.t_pad, 0);
- const int iw = nstl::max(ow * jpp.stride_w - jpp.l_pad, 0);
+ [&](int n, int oh, int ow) {
+ const int ih = nstl::max(oh*jpp.stride_h - jpp.t_pad, 0);
+ const int iw = nstl::max(ow*jpp.stride_w - jpp.l_pad, 0);
const int kh_start = nstl::max(0, jpp.t_pad - oh * jpp.stride_h);
const int kh_end = nstl::min(jpp.kh,
- jpp.ih + jpp.t_pad - oh * jpp.stride_h);
+ jpp.ih + jpp.t_pad - oh * jpp.stride_h);
const int kw_start = nstl::max(0, jpp.l_pad - ow * jpp.stride_w);
const int kw_end = nstl::min(jpp.kw,
- jpp.iw + jpp.l_pad - ow * jpp.stride_w);
+ jpp.iw + jpp.l_pad - ow * jpp.stride_w);
- auto p = call_params_t();
+ auto p = typename jit_uni_i8i8_pooling_fwd_ker_t<isa>::call_params_t();
p.src_i8 = &src_i8[
- src_d.blk_off(n, 0, ih, iw) * src_d.data_type_size()];
+ src_d.blk_off(n, 0, ih, iw) * src_d.data_type_size()];
p.dst_i8 = &dst_i8[
- dst_d.blk_off(n, 0, oh, ow) * dst_d.data_type_size()];
- p.kw_range = (size_t) (kw_end - kw_start);
- p.kh_range = (size_t) (kh_end - kh_start);
+ dst_d.blk_off(n, 0, oh, ow) * dst_d.data_type_size()];
+ p.kw_range = (size_t)(kw_end - kw_start);
+ p.kh_range = (size_t)(kh_end - kh_start);
p.idivider = 1.0f / ((jpp.alg == pooling_avg_exclude_padding) ?
- p.kh_range * p.kw_range : jpp.kw * jpp.kh);
+ p.kh_range*p.kw_range : jpp.kw*jpp.kh);
ker_->ker_(&p);
});
}
+// Explicit instantiation only for supported <isa> values.
+//
+template struct jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>;
+template struct jit_uni_i8i8_pooling_fwd_t<avx512_core>;
+
+template struct jit_uni_i8i8_pooling_fwd_ker_t<avx2>;
template struct jit_uni_i8i8_pooling_fwd_t<avx2>;
-template struct jit_uni_i8i8_pooling_fwd_t<sse42>;
}
}