Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / cpu_engine.cpp
index 104ce88..738725d 100644 (file)
@@ -24,7 +24,7 @@
 #include "cpu_concat.hpp"
 #include "cpu_sum.hpp"
 
-#include "cpu/ref_rnn.hpp"
+#include "cpu/rnn/ref_rnn.hpp"
 
 #include "cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp"
 #include "cpu/jit_avx512_common_1x1_convolution.hpp"
 #include "cpu/gemm_convolution.hpp"
 #include "cpu/gemm_x8s8s32x_convolution.hpp"
 #include "cpu/ref_convolution.hpp"
-#include "cpu/jit_avx512_core_u8s8s32x_deconvolution.hpp"
+#include "cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp"
+#include "cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp"
 #include "cpu/ref_deconvolution.hpp"
 #include "cpu/ref_shuffle.hpp"
 #include "cpu/jit_uni_eltwise.hpp"
 #include "cpu/ref_eltwise.hpp"
 #include "cpu/ref_softmax.hpp"
 #include "cpu/jit_uni_pooling.hpp"
-#include "cpu/jit_avx512_core_i8i8_pooling.hpp"
+#include "cpu/jit_uni_i8i8_pooling.hpp"
 #include "cpu/ref_pooling.hpp"
 #include "cpu/nchw_pooling.hpp"
 #include "cpu/nhwc_pooling.hpp"
@@ -59,7 +60,7 @@
 #include "cpu/nspc_batch_normalization.hpp"
 #include "cpu/ref_inner_product.hpp"
 #include "cpu/gemm_inner_product.hpp"
-#include "cpu/gemm_u8s8s32x_inner_product.hpp"
+#include "cpu/gemm_x8s8s32x_inner_product.hpp"
 #include "cpu/jit_uni_dw_convolution.hpp"
 #include "cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp"
 #include "cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp"
 #include "cpu/jit_uni_depthwise.hpp"
 #include "cpu/ref_depthwise.hpp"
 #include "cpu/jit_uni_x8s8s32x_convolution.hpp"
-#include "cpu/jit_uni_x8s8s32x_1x1_convolution.hpp"
 #include "cpu/jit_uni_x8s8s32x_dw_convolution.hpp"
-#include "cpu/jit_uni_i8i8_pooling.hpp"
+#include "cpu/jit_sse42_i8i8_pooling.hpp"
+#include "cpu/jit_uni_planar_convolution.hpp"
+#include "cpu/jit_uni_binary_convolution.hpp"
+#include "cpu/ref_binary_convolution.hpp"
+#include "cpu/jit_uni_binarization.hpp"
+#include "cpu/ref_binarization.hpp"
 
 namespace mkldnn {
 namespace impl {
@@ -105,9 +110,11 @@ using namespace mkldnn::impl::data_type;
 #define INSTANCE(...) &primitive_desc_t::create<__VA_ARGS__::pd_t>
 static const pd_create_f cpu_impl_list[] = {
     /* RNN */
-    INSTANCE(ref_rnn_fwd_t),
-    INSTANCE(ref_rnn_bwd_t),
+    INSTANCE(ref_rnn_fwd_f32_t),
+    INSTANCE(ref_rnn_fwd_u8s8_t),
+    INSTANCE(ref_rnn_bwd_f32_t),
     /* conv */
+    INSTANCE(jit_avx512_common_planar_convolution_fwd_t),
     INSTANCE(jit_avx512_common_dw_convolution_fwd_t),
     INSTANCE(jit_avx512_common_dw_convolution_bwd_data_t),
     INSTANCE(jit_avx512_common_dw_convolution_bwd_weights_t),
@@ -126,6 +133,7 @@ static const pd_create_f cpu_impl_list[] = {
     INSTANCE(jit_avx512_common_convolution_fwd_t<f32>),
     INSTANCE(jit_avx512_common_convolution_bwd_data_t<f32>),
     INSTANCE(jit_avx512_common_convolution_bwd_weights_t<f32>),
+    INSTANCE(jit_avx2_planar_convolution_fwd_t),
     INSTANCE(jit_avx2_dw_convolution_fwd_t),
     INSTANCE(jit_avx2_dw_convolution_bwd_data_t),
     INSTANCE(jit_avx2_dw_convolution_bwd_weights_t),
@@ -194,14 +202,14 @@ static const pd_create_f cpu_impl_list[] = {
     INSTANCE(jit_sse42_x8s8s32x_convolution_fwd_t<s8,s32>),
     INSTANCE(jit_sse42_x8s8s32x_convolution_fwd_t<s8,u8>),
     INSTANCE(jit_sse42_x8s8s32x_convolution_fwd_t<s8,s8>),
-    INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<false, u8, s32>),
-    INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<false, u8, u8>),
-    INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<false, u8, s8>),
-    INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<false, u8, f32>),
-    INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<false, s8, s32>),
-    INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<false, s8, u8>),
-    INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<false, s8, s8>),
-    INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<false, s8, f32>),
+    INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<u8, s32>),
+    INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<u8, u8>),
+    INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<u8, s8>),
+    INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<u8, f32>),
+    INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<s8, s32>),
+    INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<s8, u8>),
+    INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<s8, s8>),
+    INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<s8, f32>),
     INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t<s32>),
     INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t<u8>),
     INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t<s8>),
@@ -218,10 +226,22 @@ static const pd_create_f cpu_impl_list[] = {
     INSTANCE(ref_convolution_bwd_data_t<u8, s8, u8, s32>),
     INSTANCE(ref_convolution_bwd_weights_t<s16, s32, s16, s32>),
     /* deconv */
-    INSTANCE(_jit_avx512_core_u8s8s32x_deconvolution_fwd_t<s32>),
-    INSTANCE(_jit_avx512_core_u8s8s32x_deconvolution_fwd_t<u8>),
-    INSTANCE(_jit_avx512_core_u8s8s32x_deconvolution_fwd_t<s8>),
-    INSTANCE(_jit_avx512_core_u8s8s32x_deconvolution_fwd_t<f32>),
+    INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<u8,f32>),
+    INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<u8,s32>),
+    INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<u8,u8>),
+    INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<u8,s8>),
+    INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<s8,f32>),
+    INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<s8,s32>),
+    INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<s8,u8>),
+    INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<s8,s8>),
+    INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<u8,s32>),
+    INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<u8,u8>),
+    INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<u8,s8>),
+    INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<u8,f32>),
+    INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<s8,s32>),
+    INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<s8,u8>),
+    INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<s8,s8>),
+    INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<s8,f32>),
     INSTANCE(ref_deconvolution_bwd_weights_t),
     INSTANCE(ref_deconvolution_bwd_data_t),
     INSTANCE(ref_deconvolution_fwd_t),
@@ -269,9 +289,9 @@ static const pd_create_f cpu_impl_list[] = {
     INSTANCE(ref_pooling_fwd_t<f32>),
     INSTANCE(ref_pooling_bwd_t<f32>),
     /* pool (int) */
-    INSTANCE(jit_avx512_core_i8i8_pooling_fwd_t),
+    INSTANCE(jit_uni_i8i8_pooling_fwd_t<avx512_core>),
     INSTANCE(jit_uni_i8i8_pooling_fwd_t<avx2>),
-    INSTANCE(jit_uni_i8i8_pooling_fwd_t<sse42>),
+    INSTANCE(jit_sse42_i8i8_pooling_fwd_t),
     INSTANCE(ref_pooling_fwd_t<s32>),
     INSTANCE(ref_pooling_fwd_t<s16, s32>),
     INSTANCE(ref_pooling_fwd_t<s8, s32>),
@@ -307,69 +327,35 @@ static const pd_create_f cpu_impl_list[] = {
     INSTANCE(ref_inner_product_bwd_data_t<f32, f32, f32, f32>),
     INSTANCE(ref_inner_product_bwd_weights_t<f32>),
     /* inner product (int) */
-    INSTANCE(gemm_u8s8s32x_inner_product_fwd_t<u8>),
-    INSTANCE(gemm_u8s8s32x_inner_product_fwd_t<s8>),
-    INSTANCE(gemm_u8s8s32x_inner_product_fwd_t<s32>),
-    INSTANCE(gemm_u8s8s32x_inner_product_fwd_t<f32>),
+    INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<u8, u8>),
+    INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<u8, s8>),
+    INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<u8, s32>),
+    INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<u8, f32>),
+    INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<s8, u8>),
+    INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<s8, s8>),
+    INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<s8, s32>),
+    INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<s8, f32>),
     INSTANCE(ref_inner_product_fwd_t<u8, s8, u8, s32>),
     INSTANCE(ref_inner_product_fwd_t<u8, s8, s8, s32>),
     INSTANCE(ref_inner_product_fwd_t<u8, s8, s32, s32>),
     INSTANCE(ref_inner_product_fwd_t<u8, s8, f32, s32>),
     INSTANCE(ref_inner_product_fwd_t<s16, s16, s32, s32>),
     INSTANCE(ref_inner_product_bwd_data_t<s32, s16, s16, s32>),
-    /* conv_eltwise */
-    INSTANCE(jit_avx512_common_dw_convolution_relu_t),
-    INSTANCE(jit_avx512_common_convolution_winograd_relu_t),
-    INSTANCE(jit_avx512_common_1x1_convolution_relu_f32_t),
-    INSTANCE(jit_avx512_common_convolution_relu_t<f32>),
-    INSTANCE(jit_avx2_dw_convolution_relu_t),
-    INSTANCE(jit_avx2_1x1_convolution_relu_t),
-    INSTANCE(jit_sse42_dw_convolution_relu_t),
-    INSTANCE(jit_sse42_1x1_convolution_relu_t),
-    INSTANCE(jit_avx2_convolution_relu_t),
-    INSTANCE(jit_sse42_convolution_relu_t),
-    INSTANCE(gemm_convolution_relu_t),
-    INSTANCE(ref_convolution_relu_t<f32>),
-    /* conv_eltwise (int) */
-    INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_relu_t<f32>),
-    INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_relu_t<s32>),
-    INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_relu_t<s8>),
-    INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_relu_t<u8>),
-    INSTANCE(jit_avx512_common_1x1_convolution_relu_s16s16s32_t),
-    INSTANCE(jit_avx512_common_convolution_relu_t<s16, s16, s32>),
-    INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_relu_t<u8,f32>),
-    INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_relu_t<u8,s32>),
-    INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_relu_t<u8,s8>),
-    INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_relu_t<u8,u8>),
-    INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_relu_t<s8,f32>),
-    INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_relu_t<s8,s32>),
-    INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_relu_t<s8,s8>),
-    INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_relu_t<s8,u8>),
-    INSTANCE(jit_avx512_core_x8s8s32x_convolution_relu_t<u8,f32>),
-    INSTANCE(jit_avx512_core_x8s8s32x_convolution_relu_t<u8,s32>),
-    INSTANCE(jit_avx512_core_x8s8s32x_convolution_relu_t<u8,u8>),
-    INSTANCE(jit_avx512_core_x8s8s32x_convolution_relu_t<u8,s8>),
-    INSTANCE(jit_avx512_core_x8s8s32x_convolution_relu_t<s8,f32>),
-    INSTANCE(jit_avx512_core_x8s8s32x_convolution_relu_t<s8,s32>),
-    INSTANCE(jit_avx512_core_x8s8s32x_convolution_relu_t<s8,u8>),
-    INSTANCE(jit_avx512_core_x8s8s32x_convolution_relu_t<s8,s8>),
-    INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<true, u8, s32>),
-    INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<true, u8, u8>),
-    INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<true, u8, s8>),
-    INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<true, u8, f32>),
-    INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<true, s8, s32>),
-    INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<true, s8, u8>),
-    INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<true, s8, s8>),
-    INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<true, s8, f32>),
-    INSTANCE(ref_convolution_relu_t<s16, s16, s32, s32>),
-    INSTANCE(ref_convolution_relu_t<u8, s8, s32, s32>),
-    INSTANCE(ref_convolution_relu_t<u8, s8, s8, s32>),
-    INSTANCE(ref_convolution_relu_t<u8, s8, u8, s32>),
     /* roi pooling */
     INSTANCE(jit_uni_roi_pooling_fwd_t<avx512_common>),
     INSTANCE(jit_uni_roi_pooling_fwd_t<avx2>),
     INSTANCE(jit_uni_roi_pooling_fwd_t<sse42>),
     INSTANCE(ref_roi_pooling_fwd_t<data_type::f32>),
+    /* binary convolution */
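+//    INSTANCE(jit_uni_binary_convolution_fwd_t<avx512_common>),  // NOTE(review): avx512_common variant is left commented out; the reason is not stated in this change -- confirm before enabling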
+    INSTANCE(jit_uni_binary_convolution_fwd_t<avx2>),
+    INSTANCE(jit_uni_binary_convolution_fwd_t<sse42>),
+    INSTANCE(ref_binary_convolution_fwd_t),
+    /* binarization */
+    INSTANCE(jit_uni_binarization_fwd_t<avx512_common>),
+    INSTANCE(jit_uni_binarization_fwd_t<avx2>),
+    INSTANCE(jit_uni_binarization_fwd_t<sse42>),
+    INSTANCE(ref_binarization_fwd_t<f32>),
     /* eol */
     nullptr,
 };