#define JIT_AVX512_COMMON_1x1_CONV_KERNEL_HPP
#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+
#include "jit_generator.hpp"
#include "jit_primitive_conf.hpp"
#include "jit_uni_eltwise.hpp"
struct jit_avx512_common_1x1_conv_kernel : public jit_generator {
jit_avx512_common_1x1_conv_kernel(jit_1x1_conv_conf_t ajcp,
- const primitive_attr_t &attr) : jcp(ajcp), attr_(attr)
+ const primitive_attr_t &attr)
+ : jcp(ajcp), attr_(attr)
{
this->generate();
jit_ker = (void (*)(jit_1x1_conv_call_s *)) this->getCode();
const primitive_attr_t &attr);
static status_t init_conf(jit_1x1_conv_conf_t &jcp,
- const convolution_desc_t &cd,
- const memory_desc_wrapper &src_d,
- const memory_desc_wrapper &weights_d,
- const memory_desc_wrapper &dst_d,
- const primitive_attr_t &attr,
- bool with_relu, float relu_negative_slope,
- int nthreads, bool reduce_src);
+ const convolution_desc_t &cd,
+ const memory_desc_wrapper &src_d,
+ const memory_desc_wrapper &weights_d,
+ const memory_desc_wrapper &dst_d,
+ const primitive_attr_t &attr,
+ int nthreads, bool reduce_src);
- static status_t init_conf(jit_1x1_conv_conf_t &jcp,
- const convolution_desc_t &cd,
- const memory_desc_wrapper &src_d,
- const memory_desc_wrapper &weights_d,
- const memory_desc_wrapper &dst_d,
- const primitive_attr_t &attr,
- int nthreads, bool reduce_src)
- {
- return init_conf(jcp, cd, src_d, weights_d, dst_d, attr, false, 0.0,
- nthreads, reduce_src);
- }
+ static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+ const jit_1x1_conv_conf_t &jcp);
jit_1x1_conv_conf_t jcp;
const primitive_attr_t &attr_;
private:
using reg64_t = const Xbyak::Reg64;
using zmm_t = const Xbyak::Zmm;
- using mask_t = const Xbyak::Opmask;
reg64_t reg_bcast_data = r8;
reg64_t reg_load_data = r10;
reg64_t reg_reduce_pos_flag = rax;
reg64_t reg_output_stride = r13;
reg64_t reg_bias_data = r12;
+ reg64_t reg_relu_ns = r13;
reg64_t reg_bcast_loop_work = aux1_reg_bcast_data;
Xbyak::Zmm vreg_bcast = Xbyak::Zmm(31);
void generate();
static void balance(jit_1x1_conv_conf_t &jcp, int nthreads);
};
+
}
}
}