#define JIT_SSE42_1x1_CONV_KERNEL_F32_HPP
#include "c_types_map.hpp"
+#include "cpu_memory.hpp"
#include "jit_generator.hpp"
#include "jit_primitive_conf.hpp"
-#include "cpu_memory.hpp"
#include "jit_uni_eltwise.hpp"
#include "jit_uni_depthwise.hpp"
namespace cpu {
struct jit_sse42_1x1_conv_kernel_f32: public jit_generator {
- jit_sse42_1x1_conv_kernel_f32(jit_1x1_conv_conf_t ajcp,
- const primitive_attr_t &attr): jcp(ajcp), attr_(attr) {
+ jit_sse42_1x1_conv_kernel_f32(jit_1x1_conv_conf_t ajcp, jit_conv_conf_t ajcp_dw,
+ const primitive_attr_t &attr)
+ : jcp(ajcp), jcp_dw(ajcp_dw), attr_(attr)
+ {
this->generate();
jit_ker = (void (*)(jit_1x1_conv_call_s *))this->getCode();
}
const memory_desc_wrapper &src_d,
const memory_desc_wrapper &weights_d,
const memory_desc_wrapper &dst_d,
- const primitive_attr_t &attr,
- bool with_relu, float relu_negative_slope);
+ const primitive_attr_t &attr);
- static status_t init_conf(jit_1x1_conv_conf_t &jcp,
- const convolution_desc_t &cd,
- const memory_desc_wrapper &src_d,
- const memory_desc_wrapper &weights_d,
- const memory_desc_wrapper &dst_d,
- const primitive_attr_t &attr)
- {
- return init_conf(jcp, cd, src_d, weights_d, dst_d, attr, false, 0.0);
- }
+ static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+ const jit_1x1_conv_conf_t &jcp, const jit_conv_conf_t &jcp_dw = jit_conv_conf_t());
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse42_1x1_conv_kernel_f32)
jit_1x1_conv_conf_t jcp;
+ jit_conv_conf_t jcp_dw;
const primitive_attr_t &attr_;
void (*jit_ker)(jit_1x1_conv_call_s *);