Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / cpu_engine.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <assert.h>
18
19 #include "cpu_engine.hpp"
20 #include "cpu_memory.hpp"
21 #include "type_helpers.hpp"
22 #include "verbose.hpp"
23
24 #include "cpu_concat.hpp"
25 #include "cpu_sum.hpp"
26
27 #include "cpu/rnn/ref_rnn.hpp"
28
29 #include "cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp"
30 #include "cpu/jit_avx512_common_1x1_convolution.hpp"
31 #include "cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp"
32 #include "cpu/jit_avx512_common_convolution_winograd.hpp"
33 #include "cpu/jit_avx512_core_x8s8s32x_convolution.hpp"
34 #include "cpu/jit_avx512_common_convolution.hpp"
35 #include "cpu/jit_avx2_1x1_convolution.hpp"
36 #include "cpu/jit_sse42_1x1_convolution.hpp"
37 #include "cpu/jit_avx2_convolution.hpp"
38 #include "cpu/jit_sse42_convolution.hpp"
39 #include "cpu/gemm_convolution.hpp"
40 #include "cpu/gemm_x8s8s32x_convolution.hpp"
41 #include "cpu/ref_convolution.hpp"
42 #include "cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp"
43 #include "cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp"
44 #include "cpu/ref_deconvolution.hpp"
45 #include "cpu/ref_shuffle.hpp"
46 #include "cpu/jit_uni_eltwise.hpp"
47 #include "cpu/ref_eltwise.hpp"
48 #include "cpu/ref_softmax.hpp"
49 #include "cpu/jit_uni_pooling.hpp"
50 #include "cpu/jit_uni_i8i8_pooling.hpp"
51 #include "cpu/ref_pooling.hpp"
52 #include "cpu/nchw_pooling.hpp"
53 #include "cpu/nhwc_pooling.hpp"
54 #include "cpu/jit_avx512_common_lrn.hpp"
55 #include "cpu/jit_uni_lrn.hpp"
56 #include "cpu/ref_lrn.hpp"
57 #include "cpu/jit_uni_batch_normalization.hpp"
58 #include "cpu/ref_batch_normalization.hpp"
59 #include "cpu/ncsp_batch_normalization.hpp"
60 #include "cpu/nspc_batch_normalization.hpp"
61 #include "cpu/ref_inner_product.hpp"
62 #include "cpu/gemm_inner_product.hpp"
63 #include "cpu/gemm_x8s8s32x_inner_product.hpp"
64 #include "cpu/jit_uni_dw_convolution.hpp"
65 #include "cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp"
66 #include "cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp"
67 #include "cpu/jit_uni_roi_pooling.hpp"
68 #include "cpu/jit_uni_softmax.hpp"
69 #include "cpu/ref_roi_pooling.hpp"
70 #include "cpu/jit_uni_depthwise.hpp"
71 #include "cpu/ref_depthwise.hpp"
72 #include "cpu/jit_uni_x8s8s32x_convolution.hpp"
73 #include "cpu/jit_uni_x8s8s32x_dw_convolution.hpp"
74 #include "cpu/jit_sse42_i8i8_pooling.hpp"
75 #include "cpu/jit_uni_planar_convolution.hpp"
76 #include "cpu/jit_uni_binary_convolution.hpp"
77 #include "cpu/ref_binary_convolution.hpp"
78 #include "cpu/jit_uni_binarization.hpp"
79 #include "cpu/ref_binarization.hpp"
80
81 namespace mkldnn {
82 namespace impl {
83 namespace cpu {
84
85 using namespace mkldnn::impl::status;
86
87 status_t cpu_engine_t::memory_primitive_desc_create(memory_pd_t **pd,
88         const memory_desc_t *desc) {
89     return safe_ptr_assign<memory_pd_t>(*pd,
90             new cpu_memory_t::pd_t(this, desc));
91 }
92
93 status_t cpu_engine_t::view_primitive_desc_create(view_pd_t **view_pd,
94             const memory_pd_t *memory_pd, const dims_t dims,
95             const dims_t offsets) {
96     assert(memory_pd->engine() == this);
97     cpu_view_t::pd_t *cpu_vpd = nullptr;
98     status_t status = cpu_view_t::pd_t::create(&cpu_vpd,
99             (const cpu_memory_t::pd_t *)memory_pd, dims, offsets);
100     if (status != success) return status;
101     *view_pd = cpu_vpd;
102     return success;
103 }
104
// Primitive-descriptor creation function type exposed by engine_t; also the
// element type of the implementation table below.
using pd_create_f = mkldnn::impl::engine_t::primitive_desc_create_f;

namespace {
using namespace mkldnn::impl::data_type;

// INSTANCE(impl) expands to &primitive_desc_t::create<impl::pd_t>, i.e. the
// creation function for implementation `impl`'s primitive descriptor. The
// macro is variadic so template arguments containing commas (e.g. <u8,f32>)
// pass through as a single argument.
#define INSTANCE(...) &primitive_desc_t::create<__VA_ARGS__::pd_t>

// Table of every CPU primitive implementation, returned by
// cpu_engine_t::get_implementation_list() and terminated by a nullptr entry.
// Entries are grouped by primitive kind; within a group they go from JIT
// kernels (widest ISA first) down to the generic ref_* fallbacks.
// NOTE(review): presumably the caller walks this table in order and picks the
// first implementation whose pd_t accepts the descriptor, so ordering acts as
// a priority list -- confirm against engine_t before reordering entries.
static const pd_create_f cpu_impl_list[] = {
    /* RNN */
    INSTANCE(ref_rnn_fwd_f32_t),
    INSTANCE(ref_rnn_fwd_u8s8_t),
    INSTANCE(ref_rnn_bwd_f32_t),
    /* conv */
    INSTANCE(jit_avx512_common_planar_convolution_fwd_t),
    INSTANCE(jit_avx512_common_dw_convolution_fwd_t),
    INSTANCE(jit_avx512_common_dw_convolution_bwd_data_t),
    INSTANCE(jit_avx512_common_dw_convolution_bwd_weights_t),
    INSTANCE(jit_avx512_common_1x1_convolution_fwd_f32_t),
    INSTANCE(jit_avx512_common_1x1_convolution_bwd_data_f32_t),
    INSTANCE(jit_avx512_common_1x1_convolution_bwd_weights_t),
    INSTANCE(jit_avx512_common_1x1_convolution_fwd_s16s16s32_t),
    INSTANCE(jit_avx512_common_1x1_convolution_bwd_data_s16s16s32_t),
    INSTANCE(jit_avx512_core_fp32_wino_conv_2x3_fwd_t),
    INSTANCE(jit_avx512_core_fp32_wino_conv_4x3_fwd_t),
    INSTANCE(jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t),
    INSTANCE(jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t),
    INSTANCE(jit_avx512_common_convolution_winograd_fwd_t),
    INSTANCE(jit_avx512_common_convolution_winograd_bwd_data_t),
    INSTANCE(jit_avx512_common_convolution_winograd_bwd_weights_t),
    INSTANCE(jit_avx512_common_convolution_fwd_t<f32>),
    INSTANCE(jit_avx512_common_convolution_bwd_data_t<f32>),
    INSTANCE(jit_avx512_common_convolution_bwd_weights_t<f32>),
    INSTANCE(jit_avx2_planar_convolution_fwd_t),
    INSTANCE(jit_avx2_dw_convolution_fwd_t),
    INSTANCE(jit_avx2_dw_convolution_bwd_data_t),
    INSTANCE(jit_avx2_dw_convolution_bwd_weights_t),
    INSTANCE(jit_avx2_1x1_convolution_fwd_t),
    INSTANCE(jit_avx2_1x1_convolution_bwd_data_t),
    INSTANCE(jit_avx2_1x1_convolution_bwd_weights_t),
    INSTANCE(jit_sse42_dw_convolution_fwd_t),
    INSTANCE(jit_sse42_dw_convolution_bwd_data_t),
    INSTANCE(jit_sse42_dw_convolution_bwd_weights_t),
    INSTANCE(jit_sse42_1x1_convolution_fwd_t),
    INSTANCE(jit_avx2_convolution_fwd_t),
    INSTANCE(jit_avx2_convolution_bwd_data_t),
    INSTANCE(jit_avx2_convolution_bwd_weights_t),
    INSTANCE(jit_sse42_convolution_fwd_t),
    INSTANCE(gemm_convolution_fwd_t),
    INSTANCE(gemm_convolution_bwd_data_t),
    INSTANCE(gemm_convolution_bwd_weights_t),
    INSTANCE(ref_convolution_fwd_t<f32>),
    INSTANCE(ref_convolution_bwd_data_t<f32, f32, f32, f32>),
    INSTANCE(ref_convolution_bwd_weights_t<f32, f32, f32, f32>),
    /* conv (int) */
    INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<f32>),
    INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<s32>),
    INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<s8>),
    INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<u8>),
    INSTANCE(jit_avx512_common_convolution_fwd_t<s16, s16, s32>),
    INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8,f32>),
    INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8,s32>),
    INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8,u8>),
    INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8,s8>),
    INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8,f32>),
    INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8,s32>),
    INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8,u8>),
    INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8,s8>),
    INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<u8,f32>),
    INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<u8,s32>),
    INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<u8,u8>),
    INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<u8,s8>),
    INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<s8,f32>),
    INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<s8,s32>),
    INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<s8,u8>),
    INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<s8,s8>),
    INSTANCE(jit_avx512_common_convolution_bwd_data_t<s16, s16, s32>),
    INSTANCE(jit_avx512_common_convolution_bwd_weights_t<s16, s16, s32>),
    INSTANCE(jit_avx2_x8s8s32x_dw_convolution_fwd_t<u8,f32>),
    INSTANCE(jit_avx2_x8s8s32x_dw_convolution_fwd_t<u8,s32>),
    INSTANCE(jit_avx2_x8s8s32x_dw_convolution_fwd_t<u8,u8>),
    INSTANCE(jit_avx2_x8s8s32x_dw_convolution_fwd_t<u8,s8>),
    INSTANCE(jit_sse42_x8s8s32x_dw_convolution_fwd_t<u8,f32>),
    INSTANCE(jit_sse42_x8s8s32x_dw_convolution_fwd_t<u8,s32>),
    INSTANCE(jit_sse42_x8s8s32x_dw_convolution_fwd_t<u8,u8>),
    INSTANCE(jit_sse42_x8s8s32x_dw_convolution_fwd_t<u8,s8>),
    INSTANCE(jit_avx2_x8s8s32x_convolution_fwd_t<u8,f32>),
    INSTANCE(jit_avx2_x8s8s32x_convolution_fwd_t<u8,s32>),
    INSTANCE(jit_avx2_x8s8s32x_convolution_fwd_t<u8,u8>),
    INSTANCE(jit_avx2_x8s8s32x_convolution_fwd_t<u8,s8>),
    INSTANCE(jit_avx2_x8s8s32x_convolution_fwd_t<s8,f32>),
    INSTANCE(jit_avx2_x8s8s32x_convolution_fwd_t<s8,s32>),
    INSTANCE(jit_avx2_x8s8s32x_convolution_fwd_t<s8,u8>),
    INSTANCE(jit_avx2_x8s8s32x_convolution_fwd_t<s8,s8>),
    INSTANCE(jit_sse42_x8s8s32x_convolution_fwd_t<u8,f32>),
    INSTANCE(jit_sse42_x8s8s32x_convolution_fwd_t<u8,s32>),
    INSTANCE(jit_sse42_x8s8s32x_convolution_fwd_t<u8,u8>),
    INSTANCE(jit_sse42_x8s8s32x_convolution_fwd_t<u8,s8>),
    INSTANCE(jit_sse42_x8s8s32x_convolution_fwd_t<s8,f32>),
    INSTANCE(jit_sse42_x8s8s32x_convolution_fwd_t<s8,s32>),
    INSTANCE(jit_sse42_x8s8s32x_convolution_fwd_t<s8,u8>),
    INSTANCE(jit_sse42_x8s8s32x_convolution_fwd_t<s8,s8>),
    INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<u8, s32>),
    INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<u8, u8>),
    INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<u8, s8>),
    INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<u8, f32>),
    INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<s8, s32>),
    INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<s8, u8>),
    INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<s8, s8>),
    INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<s8, f32>),
    INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t<s32>),
    INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t<u8>),
    INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t<s8>),
    INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t<f32>),
    INSTANCE(ref_convolution_fwd_t<s16, s16, s32, s32>),
    INSTANCE(ref_convolution_fwd_t<u8, s8, f32, s32>),
    INSTANCE(ref_convolution_fwd_t<u8, s8, s32, s32>),
    INSTANCE(ref_convolution_fwd_t<u8, s8, s8, s32>),
    INSTANCE(ref_convolution_fwd_t<u8, s8, u8, s32>),
    INSTANCE(ref_convolution_bwd_data_t<s32, s16, s16, s32>),
    INSTANCE(ref_convolution_bwd_data_t<f32, s8, u8, s32>),
    INSTANCE(ref_convolution_bwd_data_t<s32, s8, u8, s32>),
    INSTANCE(ref_convolution_bwd_data_t<s8, s8, u8, s32>),
    INSTANCE(ref_convolution_bwd_data_t<u8, s8, u8, s32>),
    INSTANCE(ref_convolution_bwd_weights_t<s16, s32, s16, s32>),
    /* deconv */
    INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<u8,f32>),
    INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<u8,s32>),
    INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<u8,u8>),
    INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<u8,s8>),
    INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<s8,f32>),
    INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<s8,s32>),
    INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<s8,u8>),
    INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<s8,s8>),
    INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<u8,s32>),
    INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<u8,u8>),
    INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<u8,s8>),
    INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<u8,f32>),
    INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<s8,s32>),
    INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<s8,u8>),
    INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<s8,s8>),
    INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<s8,f32>),
    INSTANCE(ref_deconvolution_bwd_weights_t),
    INSTANCE(ref_deconvolution_bwd_data_t),
    INSTANCE(ref_deconvolution_fwd_t),
    /* shuffle */
    INSTANCE(ref_shuffle_t<4>), /* f32 or s32 */
    INSTANCE(ref_shuffle_t<1>), /* s8 or u8 */
    /* eltwise */
    INSTANCE(jit_uni_eltwise_fwd_t<avx512_common>),
    INSTANCE(jit_uni_eltwise_bwd_t<avx512_common>),
    INSTANCE(jit_uni_eltwise_fwd_t<avx2>),
    INSTANCE(jit_uni_eltwise_bwd_t<avx2>),
    INSTANCE(jit_uni_eltwise_fwd_t<sse42>),
    INSTANCE(jit_uni_eltwise_bwd_t<sse42>),
    INSTANCE(ref_eltwise_fwd_t<f32>),
    INSTANCE(ref_eltwise_bwd_t<f32>),
    /* depthwise */
    INSTANCE(jit_uni_depthwise_fwd_t<avx512_common>),
    INSTANCE(jit_uni_depthwise_fwd_t<avx2>),
    INSTANCE(jit_uni_depthwise_fwd_t<sse42>),
    INSTANCE(ref_depthwise_fwd_t<f32>),
    /* eltwise (int) */
    INSTANCE(ref_eltwise_fwd_t<s32>),
    INSTANCE(ref_eltwise_fwd_t<s16>),
    INSTANCE(ref_eltwise_fwd_t<s8>),
    INSTANCE(ref_eltwise_fwd_t<u8>),
    INSTANCE(ref_eltwise_bwd_t<s32>),
    INSTANCE(ref_eltwise_bwd_t<s16>),
    /* softmax */
    INSTANCE(jit_uni_softmax_fwd_t<avx512_common>),
    INSTANCE(jit_uni_softmax_fwd_t<avx2>),
    INSTANCE(jit_uni_softmax_fwd_t<sse42>),
    INSTANCE(ref_softmax_fwd_t<f32>),
    INSTANCE(ref_softmax_bwd_t<f32>),
    /* pool */
    INSTANCE(jit_uni_pooling_fwd_t<avx512_common>),
    INSTANCE(jit_uni_pooling_bwd_t<avx512_common>),
    INSTANCE(jit_uni_pooling_fwd_t<avx>),
    INSTANCE(jit_uni_pooling_bwd_t<avx>),
    INSTANCE(jit_uni_pooling_fwd_t<sse42>),
    INSTANCE(jit_uni_pooling_bwd_t<sse42>),
    INSTANCE(nchw_pooling_fwd_t<f32>),
    INSTANCE(nchw_pooling_bwd_t<f32>),
    INSTANCE(nhwc_pooling_fwd_t<f32>),
    INSTANCE(nhwc_pooling_bwd_t<f32>),
    INSTANCE(ref_pooling_fwd_t<f32>),
    INSTANCE(ref_pooling_bwd_t<f32>),
    /* pool (int) */
    INSTANCE(jit_uni_i8i8_pooling_fwd_t<avx512_core>),
    INSTANCE(jit_uni_i8i8_pooling_fwd_t<avx2>),
    INSTANCE(jit_sse42_i8i8_pooling_fwd_t),
    INSTANCE(ref_pooling_fwd_t<s32>),
    INSTANCE(ref_pooling_fwd_t<s16, s32>),
    INSTANCE(ref_pooling_fwd_t<s8, s32>),
    INSTANCE(ref_pooling_fwd_t<u8, s32>),
    INSTANCE(ref_pooling_bwd_t<s32>),
    INSTANCE(ref_pooling_bwd_t<s16, s32>),
    /* lrn */
    INSTANCE(jit_avx512_common_lrn_fwd_t),
    INSTANCE(jit_avx512_common_lrn_bwd_t),
    INSTANCE(jit_uni_lrn_fwd_t<avx2>),
    INSTANCE(jit_uni_lrn_bwd_t<avx2>),
    INSTANCE(jit_uni_lrn_fwd_t<sse42>),
    INSTANCE(ref_lrn_fwd_t<f32>),
    INSTANCE(ref_lrn_bwd_t<f32>),
    /* batch normalization */
    INSTANCE(jit_uni_batch_normalization_fwd_t<avx512_common>),
    INSTANCE(jit_uni_batch_normalization_bwd_t<avx512_common>),
    INSTANCE(jit_uni_batch_normalization_fwd_t<avx2>),
    INSTANCE(jit_uni_batch_normalization_bwd_t<avx2>),
    INSTANCE(jit_uni_batch_normalization_fwd_t<sse42>),
    INSTANCE(jit_uni_batch_normalization_bwd_t<sse42>),
    INSTANCE(ncsp_batch_normalization_fwd_t),
    INSTANCE(ncsp_batch_normalization_bwd_t),
    INSTANCE(nspc_batch_normalization_fwd_t),
    INSTANCE(nspc_batch_normalization_bwd_t),
    INSTANCE(ref_batch_normalization_fwd_t<f32>),
    INSTANCE(ref_batch_normalization_bwd_t<f32>),
    /* inner product */
    INSTANCE(gemm_inner_product_fwd_t<f32>),
    INSTANCE(gemm_inner_product_bwd_data_t<f32>),
    INSTANCE(gemm_inner_product_bwd_weights_t<f32>),
    INSTANCE(ref_inner_product_fwd_t<f32>),
    INSTANCE(ref_inner_product_bwd_data_t<f32, f32, f32, f32>),
    INSTANCE(ref_inner_product_bwd_weights_t<f32>),
    /* inner product (int) */
    INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<u8, u8>),
    INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<u8, s8>),
    INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<u8, s32>),
    INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<u8, f32>),
    INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<s8, u8>),
    INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<s8, s8>),
    INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<s8, s32>),
    INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<s8, f32>),
    INSTANCE(ref_inner_product_fwd_t<u8, s8, u8, s32>),
    INSTANCE(ref_inner_product_fwd_t<u8, s8, s8, s32>),
    INSTANCE(ref_inner_product_fwd_t<u8, s8, s32, s32>),
    INSTANCE(ref_inner_product_fwd_t<u8, s8, f32, s32>),
    INSTANCE(ref_inner_product_fwd_t<s16, s16, s32, s32>),
    INSTANCE(ref_inner_product_bwd_data_t<s32, s16, s16, s32>),
    /* roi pooling */
    INSTANCE(jit_uni_roi_pooling_fwd_t<avx512_common>),
    INSTANCE(jit_uni_roi_pooling_fwd_t<avx2>),
    INSTANCE(jit_uni_roi_pooling_fwd_t<sse42>),
    INSTANCE(ref_roi_pooling_fwd_t<data_type::f32>),
    /* binary convolution */
    // NOTE(review): the AVX-512 binary convolution entry is disabled; the
    // reason is not visible in this file.
//    INSTANCE(jit_uni_binary_convolution_fwd_t<avx512_common>),
    INSTANCE(jit_uni_binary_convolution_fwd_t<avx2>),
    INSTANCE(jit_uni_binary_convolution_fwd_t<sse42>),
    INSTANCE(ref_binary_convolution_fwd_t),
    /* binarization */
    INSTANCE(jit_uni_binarization_fwd_t<avx512_common>),
    INSTANCE(jit_uni_binarization_fwd_t<avx2>),
    INSTANCE(jit_uni_binarization_fwd_t<sse42>),
    INSTANCE(ref_binarization_fwd_t<f32>),
    /* eol: sentinel marking the end of the table */
    nullptr,
};
#undef INSTANCE
}
364
365 const pd_create_f* cpu_engine_t::get_implementation_list() const {
366     return cpu_impl_list;
367 }
368
// File-scope CPU engine factory object (external linkage).
// NOTE(review): presumably referenced elsewhere (e.g. by the engine creation
// path) to instantiate CPU engines -- confirm before renaming or removing.
cpu_engine_factory_t engine_factory;
370
371 namespace {
372 // XXX: this is a huge hammer. This disables all and any msan checks on
373 // primitives outputs.
374 //
375 // A proper approach would be an implementation-specific unpoisoning.
376 void unpoison_outputs(primitive_t *p)
377 {
378     for(auto o: p->outputs()) {
379         assert(o->kind() == primitive_kind::memory);
380         void *p;
381         o->get_data_handle(&p);
382         size_t s = ((memory_pd_t *)o->pd())->get_size();
383         msan_unpoison(p, s);
384     }
385 }
386 }
387
388 status_t cpu_engine_t::submit(primitive_t *p, event_t *e,
389         event_vector &prerequisites) {
390     /* FIXME: this should live in primitive execute function... */
391     if (mkldnn_verbose()->level) {
392         double ms = get_msec();
393         p->execute(e);
394         ms = get_msec() - ms;
395         printf("mkldnn_verbose,exec,%s,%g\n", p->pd()->info(), ms);
396         fflush(0);
397     } else {
398         p->execute(e);
399     }
400     if (msan_enabled)
401         unpoison_outputs(p);
402     return success;
403 }
404
405 }
406 }
407 }
408
409 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s