#define JIT_AVX512_CORE_X8S8S32X_1X1_CONV_KERNEL_HPP
#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+
#include "jit_generator.hpp"
#include "jit_primitive_conf.hpp"
+#include "jit_uni_eltwise.hpp"
+#include "jit_uni_depthwise.hpp"
namespace mkldnn {
namespace impl {
jit_ker = (void (*)(jit_1x1_conv_call_s *)) this->getCode();
}
+ ~jit_avx512_core_x8s8s32x_1x1_conv_kernel() {
+ // Both injector vectors store raw pointers that are deleted here, so
+ // each element is freed before its vector is emptied.
+ for (auto *eltwise_injector : eltwise_injectors) {
+ delete eltwise_injector;
+ }
+ eltwise_injectors.clear();
+
+ for (auto *depthwise_injector : depthwise_injectors) {
+ delete depthwise_injector;
+ }
+ depthwise_injectors.clear();
+ }
+
static bool post_ops_ok(jit_1x1_conv_conf_t &jcp,
const primitive_attr_t &attr);
static status_t init_conf(jit_1x1_conv_conf_t &jcp,
- const convolution_desc_t &cd,
- const memory_desc_wrapper &src_d,
- const memory_desc_wrapper &weights_d,
- const memory_desc_wrapper &dst_d,
- const memory_desc_wrapper &bias_d,
- const primitive_attr_t &attr,
- bool with_relu, float relu_negative_slope,
- int nthreads, bool reduce_src);
+ const convolution_desc_t &cd,
+ const memory_desc_wrapper &src_d,
+ const memory_desc_wrapper &weights_d,
+ const memory_desc_wrapper &dst_d,
+ const memory_desc_wrapper &bias_d,
+ const primitive_attr_t &attr,
+ int nthreads, bool reduce_src);
- static status_t init_conf(jit_1x1_conv_conf_t &jcp,
- const convolution_desc_t &cd,
- const memory_desc_wrapper &src_d,
- const memory_desc_wrapper &weights_d,
- const memory_desc_wrapper &dst_d,
- const memory_desc_wrapper &bias_d,
- const primitive_attr_t &attr,
- int nthreads, bool reduce_src)
- {
- return init_conf(jcp, cd, src_d, weights_d, dst_d, bias_d, attr, false,
- 0.0, nthreads, reduce_src);
- }
- bool maybe_relu(int position);
+ static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+ const jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr);
jit_1x1_conv_conf_t jcp;
const primitive_attr_t &attr_;
void (*jit_ker)(jit_1x1_conv_call_s *);
private:
+ nstl::vector<jit_uni_eltwise_injector_f32<avx512_common>*> eltwise_injectors; // owning raw pointers; each element is deleted in the destructor
+ nstl::vector<jit_uni_depthwise_injector_f32<avx512_common>*> depthwise_injectors; // owning raw pointers; each element is deleted in the destructor
+
using reg64_t = const Xbyak::Reg64;
using zmm_t = const Xbyak::Zmm;
using mask_t = const Xbyak::Opmask;
reg64_t aux_reg_output_data = abi_not_param1;
reg64_t reduce_loop_iter = abi_param1;
+ // Each name below is initialized from an existing register member, so it
+ // is a second name for that member's register, not an additional one.
+ reg64_t reg_d_weights = aux_reg_bcast_data;
+ reg64_t reg_d_bias = reduce_loop_iter;
+ reg64_t reg_oc_off = aux_reg_load_data;
+
reg64_t reg_last_load = r8;
mask_t ktail_mask = k6;
int reg_bcast_data_off = 16;
int reg_load_data_off = 24;
int reg_ptr_sum_scale_off = 32;
- int reg_last_load_off = 40;
- int reg_comp_data_off = 48;
- int stack_space_needed = 56;
+ int reg_comp_data_off = 40; // takes the value formerly held by the removed reg_last_load_off
+ int stack_space_needed = 48; // 8 past the highest offset above (40); was 56 before the removal
void bcast_loop(int load_loop_blk);
void reduce_loop(int load_loop_blk, int ur, int substep, bool wraparound);
void generate();
- static void balance(jit_1x1_conv_conf_t &jcp, int nthreads);
void cvt2ps(data_type_t type_in, zmm_t zmm_in, const Xbyak::Operand &op,
bool mask_flag);
};
+
}
}
}