#define CPU_JIT_AVX512_CORE_X8S8S32X_CONV_KERNEL_HPP

#include "c_types_map.hpp"
#include "cpu_memory.hpp"
#include "jit_generator.hpp"
#include "jit_primitive_conf.hpp"
#include "jit_uni_depthwise.hpp"
#include "jit_uni_eltwise.hpp"
#include "memory_tracking.hpp"
namespace mkldnn {
namespace impl {
namespace cpu {
-struct jit_avx512_core_x8s8s32x_fwd_kernel : public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_x8s8s32x_conv_fwd_ker_t)
+template<typename Vmm>
+struct _jit_avx512_core_x8s8s32x_fwd_kernel : public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_core_x8s8s32x_conv_fwd_ker_t)
enum { STATE_FIRST_DST_LOAD = 0x1U };
- jit_avx512_core_x8s8s32x_fwd_kernel(jit_conv_conf_t ajcp,
+ _jit_avx512_core_x8s8s32x_fwd_kernel(jit_conv_conf_t ajcp,
const primitive_attr_t &attr) : jcp(ajcp), attr_(attr)
{
generate();
- jit_ker = (void (*)(jit_conv_call_s *))getCode();
+ jit_ker_ = (void (*)(jit_conv_call_s *))getCode();
+ }
+
+ ~_jit_avx512_core_x8s8s32x_fwd_kernel() {
+ for (auto inj : eltwise_injectors)
+ delete inj;
+ eltwise_injectors.clear();
+
+ for (auto inj : depthwise_injectors)
+ delete inj;
+ depthwise_injectors.clear();
}
- static bool post_ops_ok(jit_conv_conf_t &jcp,
- const primitive_attr_t &attr);
- static status_t init_conf(jit_conv_conf_t &jcp,
- const convolution_desc_t &cd,
- cpu_memory_t::pd_t &src_pd,
- cpu_memory_t::pd_t &weights_pd,
- cpu_memory_t::pd_t &dst_pd,
- cpu_memory_t::pd_t &bias_pd,
- const primitive_attr_t &attr,
- int nthreads,
- bool with_relu = false,
- float relu_negative_slope = 0.);
jit_conv_conf_t jcp;
const primitive_attr_t &attr_;
- void (*jit_ker)(jit_conv_call_s *);
+ void (*jit_ker_)(jit_conv_call_s *);
private:
- using reg64_t = const Xbyak::Reg64;
- using zmm_t = const Xbyak::Zmm;
- using xmm_t = const Xbyak::Xmm;
+ nstl::vector<jit_uni_eltwise_injector_f32<avx512_common>*> eltwise_injectors;
+ nstl::vector<jit_uni_depthwise_injector_f32<avx512_common>*> depthwise_injectors;
+
enum {
typesize = sizeof(float),
ker_reg_base_idx = 28,
+ ker_dw_reg_base_idx = 30,
};
- enum {
+ typedef enum {
no_last_block,
last_ic_block,
last_sp_block,
- };
-
- reg64_t reg_inp = r8;
- reg64_t reg_ker = r9;
- reg64_t reg_out = r10;
- reg64_t aux_reg_inp = r11;
- reg64_t reg_ptr_sum_scale = r11;
- reg64_t aux_reg_ker = r12;
- reg64_t reg_owb = r12;
-
- reg64_t reg_scratch = r14;
- reg64_t reg_kj = rax;
- reg64_t reg_overflow = rax;
- reg64_t reg_ptr_scales = rax;
- reg64_t reg_oi = rbx;
- reg64_t reg_bias = rdx;
- reg64_t reg_compensation = reg_scratch;
- reg64_t reg_kh = abi_not_param1;
- reg64_t param = abi_param1;
- reg64_t reg_tmp = rbp;
- reg64_t imm_addr64 = r15;
- reg64_t reg_oc_blocks = rsi;
- reg64_t reg_icb = reg_bias;
- reg64_t reg_bias_alpha = reg_kh;
-
- Xbyak::Opmask ktail_mask = Xbyak::Opmask(2);
-
- zmm_t zmm_tmp = zmm_t(28);
- zmm_t zmm_one = zmm_t(29);
- zmm_t zmm_scales = zmm_t(30);
- zmm_t zmm_shift = zmm_t(30);
- zmm_t zmm_zero = zmm_t(31);
- zmm_t zmm_wei = zmm_t(31);
-
- zmm_t zmm_out(int i_ur, int i_oc) {
+ } ic_block_t;
+
+ /* data regs */
+ const Xbyak::Reg64 reg_ptr_scales = rax;
+ const Xbyak::Reg64 reg_inp = r8;
+ const Xbyak::Reg64 reg_ker = r9;
+ const Xbyak::Reg64 reg_out = r10;
+ const Xbyak::Reg64 aux_reg_inp = r11;
+ const Xbyak::Reg64 reg_ptr_sum_scale = r11;
+ const Xbyak::Reg64 aux_reg_ker = r12;
+ const Xbyak::Reg64 reg_compensation = r14;
+ /* counter regs */
+ const Xbyak::Reg64 reg_bias_alpha = abi_not_param1;
+ const Xbyak::Reg64 reg_oi = rbx;
+ const Xbyak::Reg64 reg_bias = rdx;
+ const Xbyak::Reg64 reg_oc_blocks = rsi;
+ const Xbyak::Reg64 reg_owb = aux_reg_ker;
+ const Xbyak::Reg64 reg_scratch = reg_compensation;
+ const Xbyak::Reg64 reg_kj = reg_ptr_scales;
+ const Xbyak::Reg64 reg_overflow = reg_ptr_scales;
+ const Xbyak::Reg64 reg_icb = reg_bias;
+
+ const Xbyak::Reg64 reg_d_weights = r15;
+ const Xbyak::Reg64 reg_d_bias = r13;
+
+ const Xbyak::Opmask ktail_mask = Xbyak::Opmask(2);
+ const Xbyak::Opmask kblend_mask = Xbyak::Opmask(3);
+
+ const Vmm vmm_wei = Vmm(31);
+ /* used during bias section of store_output */
+ const Vmm vmm_comp = Vmm(30); // only for signed input
+ const Vmm vmm_bias = Vmm(31);
+ /* used during post_op sum section of store_output */
+ const Vmm vmm_prev_dst = Vmm(31);
+ /* used during write-out section of store_output */
+ const Vmm vmm_zero = Vmm(31);
+
+ /* used in compute_ker (but set during prepare_output) */
+ const Vmm vmm_shift = vmm_comp; // only for signed input
+ /* used in compute_ker (but only for pre-VNNI machines) */
+ const Vmm vmm_tmp = Vmm(28); // not used for depthwise
+ const Vmm vmm_one = Vmm(29); // set at start of kernel, not used for depthwise.
+
+ /* registers use only for depthwise
+ groups are always blocked by 16(padded if needed),
+ hence use only Zmm registers */
+ const Xbyak::Zmm zmm_wei = Xbyak::Zmm(31);
+ Xbyak::Zmm zmm_src;
+ Xbyak::Zmm zmm_permute;
+ Xbyak::Zmm zmm_zero_blend; // used only for fast depthwise
+
+ Vmm vmm_out(int i_ur, int i_oc) {
int idx = i_ur + i_oc * jcp.ur_w;
- assert(idx < ker_reg_base_idx);
- return zmm_t(idx);
+ assert(idx < (jcp.is_depthwise
+ ? ker_dw_reg_base_idx : ker_reg_base_idx));
+ return Vmm(idx);
}
- xmm_t xmm_out(int i_ur, int i_oc) {
+ Xbyak::Zmm zmm_out(int i_ur, int i_oc) {
int idx = i_ur + i_oc * jcp.ur_w;
- assert(idx < ker_reg_base_idx);
- return xmm_t(idx);
+ assert(idx < (jcp.is_depthwise
+ ? ker_dw_reg_base_idx : ker_reg_base_idx));
+ return Xbyak::Zmm(idx);
}
- zmm_t zmm_inp(int i_ic, int nb_x_blocking) {
+ Vmm vmm_inp(int i_ic, int nb_x_blocking) {
int idx = i_ic + nb_x_blocking * jcp.ur_w;
assert(idx < 31);
- return zmm_t(idx);
+ return Vmm(idx);
}
- zmm_t zmm_bias_alpha() {
- return zmm_t(jcp.nb_oc_blocking * jcp.ur_w);
+ Vmm vmm_bias_alpha() {
+ int nb_c_block = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
+ return Vmm(nb_c_block * jcp.ur_w);
}
- xmm_t xmm_bias_alpha() {
- return xmm_t(jcp.nb_oc_blocking * jcp.ur_w);
+ Xbyak::Xmm xmm_bias_alpha() {
+ int nb_c_block = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
+ return Xbyak::Xmm(nb_c_block * jcp.ur_w);
}
int get_ow_start(int ki, int pad_l) {
return nstl::max(0,
* (jcp.dilate_w + 1),
jcp.stride_w));
}
- bool maybe_relu(int position);
+
void prepare_output(int ur_w);
- void store_output(int ur_w, int last_oc_block_flag);
- void compute_ker(int ur_w, int pad_l, int pad_r, int last_ic_block_flag,
- bool h_padded = false);
- void kh_loop(int ur_w, int pad_l, int pad_r, int last_ic_block_flag);
+ void store_output(int ur_w, bool last_oc_block_flag);
+ void compute_ker_dw(
+ int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded);
+ void compute_ker(int ur_w, int pad_l, int pad_r,
+ ic_block_t last_ic_block_flag, bool h_padded = false);
+ void kh_loop(int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag);
void icb_loop(
int ur_w, int pad_l, int pad_r, bool is_last_spatial_block);
void generate();
- void cvt2ps(data_type_t type_in, zmm_t zmm_in, const Xbyak::Operand &op,
+ void cvt2ps(data_type_t type_in, Vmm ymm_in, const Xbyak::Operand &op,
bool mask_flag);
+ const Vmm vmm_mask(const Vmm vmm_in, bool mask_flag, bool store = false);
+};
+
+struct jit_avx512_core_x8s8s32x_fwd_kernel {
+
+ jit_avx512_core_x8s8s32x_fwd_kernel(jit_conv_conf_t ajcp,
+ const primitive_attr_t &attr) :
+ jit_ker(nullptr),
+ zmm_kernel_(nullptr),
+ ymm_kernel_(nullptr),
+ xmm_kernel_(nullptr) {
+ int ch_block = ajcp.is_depthwise ? ajcp.ch_block : ajcp.ic_block;
+ switch (ch_block) {
+ case 16:
+ zmm_kernel_ =
+ new _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Zmm>(
+ ajcp, attr);
+ jit_ker = zmm_kernel_->jit_ker_;
+ return;
+ case 8:
+ ymm_kernel_ =
+ new _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Ymm>(
+ ajcp, attr);
+ jit_ker = ymm_kernel_->jit_ker_;
+ return;
+ case 4:
+ xmm_kernel_ =
+ new _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Xmm>(
+ ajcp, attr);
+ jit_ker = xmm_kernel_->jit_ker_;
+ return;
+ default:
+ assert(!"invalid channel blocking");
+ }
+ }
+
+ ~jit_avx512_core_x8s8s32x_fwd_kernel() {
+ delete xmm_kernel_;
+ delete ymm_kernel_;
+ delete zmm_kernel_;
+ }
+
+ static bool post_ops_ok(jit_conv_conf_t &jcp,
+ const primitive_attr_t &attr);
+
+ static status_t init_conf(jit_conv_conf_t &jcp,
+ const convolution_desc_t &cd,
+ cpu_memory_t::pd_t &src_pd,
+ cpu_memory_t::pd_t &weights_pd,
+ cpu_memory_t::pd_t &dst_pd,
+ cpu_memory_t::pd_t &bias_pd,
+ const primitive_attr_t &attr,
+ int nthreads);
+ static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+ const jit_conv_conf_t &jcp, const primitive_attr_t &attr);
+
+ void (*jit_ker)(jit_conv_call_s *);
+ _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Zmm> *zmm_kernel_;
+ _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Ymm> *ymm_kernel_;
+ _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Xmm> *xmm_kernel_;
};
}