#define JIT_AVX2_1x1_CONV_KERNEL_F32_HPP
#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+
+#include "cpu_memory.hpp"
#include "jit_generator.hpp"
#include "jit_primitive_conf.hpp"
-#include "cpu_memory.hpp"
#include "jit_uni_eltwise.hpp"
#include "jit_uni_depthwise.hpp"
struct jit_avx2_1x1_conv_kernel_f32: public jit_generator {
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_1x1_conv_kernel_f32)
- jit_avx2_1x1_conv_kernel_f32(jit_1x1_conv_conf_t ajcp,
- const primitive_attr_t &attr): jcp(ajcp), attr_(attr)
+ jit_avx2_1x1_conv_kernel_f32(jit_1x1_conv_conf_t ajcp, jit_conv_conf_t ajcp_dw,
+ const primitive_attr_t &attr)
+ : jcp(ajcp), jcp_dw(ajcp_dw), attr_(attr)
{
this->generate();
jit_ker = (void (*)(jit_1x1_conv_call_s *))this->getCode();
const memory_desc_wrapper &src_d,
const memory_desc_wrapper &weights_d,
const memory_desc_wrapper &dst_d,
- const primitive_attr_t &attr,
- bool with_relu, float relu_negative_slope);
+ const primitive_attr_t &attr);
- static status_t init_conf(jit_1x1_conv_conf_t &jcp,
- const convolution_desc_t &cd,
- const memory_desc_wrapper &src_d,
- const memory_desc_wrapper &weights_d,
- const memory_desc_wrapper &dst_d,
- const primitive_attr_t &attr)
- {
- return init_conf(jcp, cd, src_d, weights_d, dst_d, attr, false, 0.0);
- }
+ static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+ const jit_1x1_conv_conf_t &jcp, const jit_conv_conf_t &jcp_dw = jit_conv_conf_t());
jit_1x1_conv_conf_t jcp;
+ jit_conv_conf_t jcp_dw;
const primitive_attr_t &attr_;
void (*jit_ker)(jit_1x1_conv_call_s *);
int stack_space_needed = 8;
ymm_t vreg_bcast = ymm_t(15);
- Xbyak::Ymm vmask = Xbyak::Ymm(14);
+ ymm_t vtmp = ymm_t(14);
void generate_bcast_loop(int load_loop_blk);
void generate_reduce_loop(int load_loop_blk, int ur);