#define JIT_AVX512_COMMON_CONV_KERNEL_F32_HPP
#include "c_types_map.hpp"
-#include "cpu_memory.hpp"
+#include "memory_tracking.hpp"
+#include "cpu_memory.hpp"
#include "jit_generator.hpp"
#include "jit_primitive_conf.hpp"
#include "jit_uni_eltwise.hpp"
namespace impl {
namespace cpu {
-struct jit_avx512_common_conv_fwd_kernel : public jit_generator {
+template<typename Vmm>
+struct _jit_avx512_common_conv_fwd_kernel : public jit_generator {
- jit_avx512_common_conv_fwd_kernel(jit_conv_conf_t ajcp,
- const primitive_attr_t &attr) : jcp(ajcp), attr_(attr)
+ _jit_avx512_common_conv_fwd_kernel(jit_conv_conf_t ajcp,
+ const primitive_attr_t &attr)
+ : jcp(ajcp), attr_(attr)
{
generate();
- jit_ker = (void (*)(jit_conv_call_s *))getCode();
+ jit_ker_ = (void (*)(jit_conv_call_s *))getCode();
}
- ~jit_avx512_common_conv_fwd_kernel() {
+ ~_jit_avx512_common_conv_fwd_kernel() {
for (auto inj : eltwise_injectors)
delete inj;
eltwise_injectors.clear();
depthwise_injectors.clear();
}
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_conv_fwd_kernel)
-
- static bool post_ops_ok(jit_conv_conf_t &jcp,
- const primitive_attr_t &attr);
- static status_t init_conf(jit_conv_conf_t &jcp,
- const convolution_desc_t &cd,
- cpu_memory_t::pd_t &src_pd,
- cpu_memory_t::pd_t &weights_pd,
- cpu_memory_t::pd_t &dst_pd,
- cpu_memory_t::pd_t &bias_pd,
- const primitive_attr_t &attr,
- int nthreads,
- bool with_relu,
- float relu_negative_slope);
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_fwd_kernel)
jit_conv_conf_t jcp;
const primitive_attr_t &attr_;
- void (*jit_ker)(jit_conv_call_s *);
+ void (*jit_ker_)(jit_conv_call_s *);
private:
using reg64_t = const Xbyak::Reg64;
reg64_t reg_long_offt = r11;
reg64_t reg_out_long_offt = r14;
- inline Xbyak::Zmm zmm_ker(int i_ic) {
+ inline Vmm vmm_ker(int i_ic) {
assert(i_ic < 4);
- return Xbyak::Zmm(ker_reg_base_idx + i_ic);
+ return Vmm(ker_reg_base_idx + i_ic);
}
- inline Xbyak::Zmm zmm_out(int i_ur, int i_oc) {
+ inline Vmm vmm_out(int i_ur, int i_oc) {
int idx = i_ur + i_oc * jcp.ur_w;
assert(idx < ker_reg_base_idx);
- return Xbyak::Zmm(idx);
+ return Vmm(idx);
}
- inline Xbyak::Zmm zmm_inp(int i_ic, int nb_x_blocking) {
+ inline Vmm vmm_inp(int i_ic, int nb_x_blocking) {
int idx = i_ic + nb_x_blocking * jcp.ur_w;
assert(idx < 31);
- return Xbyak::Zmm(idx);
+ return Vmm(idx);
}
Xbyak::Reg64 imm_addr64 = r15;
- Xbyak::Zmm zmm_wei = Xbyak::Zmm(31);
+ Vmm vmm_wei = Vmm(31);
reg64_t reg_d_weights = imm_addr64;
reg64_t reg_d_bias = reg_kj;
void generate();
- inline void vpXdpwssd(Xbyak::Zmm zmm1, Xbyak::Zmm zmm2,
- const Xbyak::Address& op) {
- if (jcp.ver == ver_4vnni)
- vp4dpwssd(zmm1, zmm2, op);
- else
- vpdpwssd(zmm1, zmm2, op);
- }
-
- inline void vadd(Xbyak::Zmm zmm, const Xbyak::Operand& op) {
+ inline void vadd(Vmm vmm, const Xbyak::Operand& op) {
if (jcp.ver == ver_4vnni || jcp.ver == ver_vnni)
- vpaddd(zmm, zmm, op);
+ vpaddd(vmm, vmm, op);
else
- vaddps(zmm, zmm, op);
- }
-
- inline void vcmp(Xbyak::Opmask kmask,
- Xbyak::Zmm zmm_src1, Xbyak::Zmm zmm_src2, const unsigned char cmp) {
- if (jcp.ver == ver_4vnni || jcp.ver == ver_vnni)
- vpcmpd(kmask, zmm_src1, zmm_src2, cmp);
- else
- vcmpps(kmask, zmm_src1, zmm_src2, cmp);
- }
-
- inline void vmul(Xbyak::Zmm zmm_dst, Xbyak::Opmask kmask,
- Xbyak::Zmm zmm_src1, Xbyak::Zmm zmm_src2) {
- if (jcp.ver == ver_4vnni || jcp.ver == ver_vnni)
- vpmulld(zmm_dst | kmask, zmm_src1, zmm_src2);
- else
- vmulps(zmm_dst | kmask, zmm_src1, zmm_src2);
+ vaddps(vmm, vmm, op);
}
inline size_t get_output_offset(int oi, int n_oc_block) {
}
};
+struct jit_avx512_common_conv_fwd_kernel {
+
+ jit_avx512_common_conv_fwd_kernel(jit_conv_conf_t ajcp,
+ const primitive_attr_t &attr) :
+ jit_ker(nullptr),
+ zmm_kernel_(nullptr),
+ xmm_kernel_(nullptr) {
+ int ch_block = ajcp.is_depthwise ? ajcp.ch_block : ajcp.oc_block;
+ switch (ch_block) {
+ case 16:
+ zmm_kernel_ =
+ new _jit_avx512_common_conv_fwd_kernel<Xbyak::Zmm>(
+ ajcp, attr);
+ jit_ker = zmm_kernel_->jit_ker_;
+ return;
+ case 4:
+ xmm_kernel_ =
+ new _jit_avx512_common_conv_fwd_kernel<Xbyak::Xmm>(
+ ajcp, attr);
+ jit_ker = xmm_kernel_->jit_ker_;
+ return;
+ default:
+ assert(!"invalid channel blocking");
+ }
+ }
+
+ ~jit_avx512_common_conv_fwd_kernel() {
+ delete xmm_kernel_;
+ delete zmm_kernel_;
+ }
+
+ enum {
+ typesize = sizeof(float)
+ };
+
+ static bool post_ops_ok(jit_conv_conf_t &jcp,
+ const primitive_attr_t &attr);
+ static status_t init_conf(jit_conv_conf_t &jcp,
+ const convolution_desc_t &cd,
+ cpu_memory_t::pd_t &src_pd,
+ cpu_memory_t::pd_t &weights_pd,
+ cpu_memory_t::pd_t &dst_pd,
+ cpu_memory_t::pd_t &bias_pd,
+ const primitive_attr_t &attr,
+ int nthreads);
+ static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+ const jit_conv_conf_t &jcp);
+
+ void(*jit_ker)(jit_conv_call_s *);
+ _jit_avx512_common_conv_fwd_kernel<Xbyak::Zmm> *zmm_kernel_;
+ _jit_avx512_common_conv_fwd_kernel<Xbyak::Xmm> *xmm_kernel_;
+};
+
struct jit_avx512_common_conv_bwd_data_kernel_f32: public jit_generator {
jit_avx512_common_conv_bwd_data_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp)
const memory_desc_wrapper &diff_src_d,
const memory_desc_wrapper &weights_d,
const memory_desc_wrapper &diff_dst_d);
+ static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+ const jit_conv_conf_t &jcp);
jit_conv_conf_t jcp;
void (*jit_ker)(jit_conv_call_s *);
const convolution_desc_t &cd, cpu_memory_t::pd_t &src_pd,
cpu_memory_t::pd_t &diff_weights_pd,
cpu_memory_t::pd_t &diff_bias_pd, cpu_memory_t::pd_t &diff_dst_pd);
+ static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+ const jit_conv_conf_t &jcp);
jit_conv_conf_t jcp;
void (*jit_ker)(jit_conv_call_s *);
inline void compute_loop();
void generate();
+
+ static void balance(const jit_conv_conf_t &j, int &nthr, int &nthr_mb,
+ int &nthr_g, int &nthr_oc_b, int &nthr_ic_b);
};
}