1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
19 #include "cpu_engine.hpp"
20 #include "cpu_memory.hpp"
21 #include "type_helpers.hpp"
22 #include "verbose.hpp"
24 #include "cpu_concat.hpp"
25 #include "cpu_sum.hpp"
27 #include "cpu/rnn/ref_rnn.hpp"
29 #include "cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp"
30 #include "cpu/jit_avx512_common_1x1_convolution.hpp"
31 #include "cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp"
32 #include "cpu/jit_avx512_common_convolution_winograd.hpp"
33 #include "cpu/jit_avx512_core_x8s8s32x_convolution.hpp"
34 #include "cpu/jit_avx512_common_convolution.hpp"
35 #include "cpu/jit_avx2_1x1_convolution.hpp"
36 #include "cpu/jit_sse42_1x1_convolution.hpp"
37 #include "cpu/jit_avx2_convolution.hpp"
38 #include "cpu/jit_sse42_convolution.hpp"
39 #include "cpu/gemm_convolution.hpp"
40 #include "cpu/gemm_x8s8s32x_convolution.hpp"
41 #include "cpu/ref_convolution.hpp"
42 #include "cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp"
43 #include "cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp"
44 #include "cpu/ref_deconvolution.hpp"
45 #include "cpu/ref_shuffle.hpp"
46 #include "cpu/jit_uni_eltwise.hpp"
47 #include "cpu/ref_eltwise.hpp"
48 #include "cpu/ref_softmax.hpp"
49 #include "cpu/jit_uni_pooling.hpp"
50 #include "cpu/jit_uni_i8i8_pooling.hpp"
51 #include "cpu/ref_pooling.hpp"
52 #include "cpu/nchw_pooling.hpp"
53 #include "cpu/nhwc_pooling.hpp"
54 #include "cpu/jit_avx512_common_lrn.hpp"
55 #include "cpu/jit_uni_lrn.hpp"
56 #include "cpu/ref_lrn.hpp"
57 #include "cpu/jit_uni_batch_normalization.hpp"
58 #include "cpu/ref_batch_normalization.hpp"
59 #include "cpu/ncsp_batch_normalization.hpp"
60 #include "cpu/nspc_batch_normalization.hpp"
61 #include "cpu/ref_inner_product.hpp"
62 #include "cpu/gemm_inner_product.hpp"
63 #include "cpu/gemm_x8s8s32x_inner_product.hpp"
64 #include "cpu/jit_uni_dw_convolution.hpp"
65 #include "cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp"
66 #include "cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp"
67 #include "cpu/jit_uni_roi_pooling.hpp"
68 #include "cpu/jit_uni_softmax.hpp"
69 #include "cpu/ref_roi_pooling.hpp"
70 #include "cpu/jit_uni_depthwise.hpp"
71 #include "cpu/ref_depthwise.hpp"
72 #include "cpu/jit_uni_x8s8s32x_convolution.hpp"
73 #include "cpu/jit_uni_x8s8s32x_dw_convolution.hpp"
74 #include "cpu/jit_sse42_i8i8_pooling.hpp"
75 #include "cpu/jit_uni_planar_convolution.hpp"
76 #include "cpu/jit_uni_binary_convolution.hpp"
77 #include "cpu/ref_binary_convolution.hpp"
78 #include "cpu/jit_uni_binarization.hpp"
79 #include "cpu/ref_binarization.hpp"
85 using namespace mkldnn::impl::status;
87 status_t cpu_engine_t::memory_primitive_desc_create(memory_pd_t **pd,
88 const memory_desc_t *desc) {
89 return safe_ptr_assign<memory_pd_t>(*pd,
90 new cpu_memory_t::pd_t(this, desc));
93 status_t cpu_engine_t::view_primitive_desc_create(view_pd_t **view_pd,
94 const memory_pd_t *memory_pd, const dims_t dims,
95 const dims_t offsets) {
96 assert(memory_pd->engine() == this);
97 cpu_view_t::pd_t *cpu_vpd = nullptr;
98 status_t status = cpu_view_t::pd_t::create(&cpu_vpd,
99 (const cpu_memory_t::pd_t *)memory_pd, dims, offsets);
100 if (status != success) return status;
105 using pd_create_f = mkldnn::impl::engine_t::primitive_desc_create_f;
108 using namespace mkldnn::impl::data_type;
110 #define INSTANCE(...) &primitive_desc_t::create<__VA_ARGS__::pd_t>
111 static const pd_create_f cpu_impl_list[] = {
113 INSTANCE(ref_rnn_fwd_f32_t),
114 INSTANCE(ref_rnn_fwd_u8s8_t),
115 INSTANCE(ref_rnn_bwd_f32_t),
117 INSTANCE(jit_avx512_common_planar_convolution_fwd_t),
118 INSTANCE(jit_avx512_common_dw_convolution_fwd_t),
119 INSTANCE(jit_avx512_common_dw_convolution_bwd_data_t),
120 INSTANCE(jit_avx512_common_dw_convolution_bwd_weights_t),
121 INSTANCE(jit_avx512_common_1x1_convolution_fwd_f32_t),
122 INSTANCE(jit_avx512_common_1x1_convolution_bwd_data_f32_t),
123 INSTANCE(jit_avx512_common_1x1_convolution_bwd_weights_t),
124 INSTANCE(jit_avx512_common_1x1_convolution_fwd_s16s16s32_t),
125 INSTANCE(jit_avx512_common_1x1_convolution_bwd_data_s16s16s32_t),
126 INSTANCE(jit_avx512_core_fp32_wino_conv_2x3_fwd_t),
127 INSTANCE(jit_avx512_core_fp32_wino_conv_4x3_fwd_t),
128 INSTANCE(jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t),
129 INSTANCE(jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t),
130 INSTANCE(jit_avx512_common_convolution_winograd_fwd_t),
131 INSTANCE(jit_avx512_common_convolution_winograd_bwd_data_t),
132 INSTANCE(jit_avx512_common_convolution_winograd_bwd_weights_t),
133 INSTANCE(jit_avx512_common_convolution_fwd_t<f32>),
134 INSTANCE(jit_avx512_common_convolution_bwd_data_t<f32>),
135 INSTANCE(jit_avx512_common_convolution_bwd_weights_t<f32>),
136 INSTANCE(jit_avx2_planar_convolution_fwd_t),
137 INSTANCE(jit_avx2_dw_convolution_fwd_t),
138 INSTANCE(jit_avx2_dw_convolution_bwd_data_t),
139 INSTANCE(jit_avx2_dw_convolution_bwd_weights_t),
140 INSTANCE(jit_avx2_1x1_convolution_fwd_t),
141 INSTANCE(jit_avx2_1x1_convolution_bwd_data_t),
142 INSTANCE(jit_avx2_1x1_convolution_bwd_weights_t),
143 INSTANCE(jit_sse42_dw_convolution_fwd_t),
144 INSTANCE(jit_sse42_dw_convolution_bwd_data_t),
145 INSTANCE(jit_sse42_dw_convolution_bwd_weights_t),
146 INSTANCE(jit_sse42_1x1_convolution_fwd_t),
147 INSTANCE(jit_avx2_convolution_fwd_t),
148 INSTANCE(jit_avx2_convolution_bwd_data_t),
149 INSTANCE(jit_avx2_convolution_bwd_weights_t),
150 INSTANCE(jit_sse42_convolution_fwd_t),
151 INSTANCE(gemm_convolution_fwd_t),
152 INSTANCE(gemm_convolution_bwd_data_t),
153 INSTANCE(gemm_convolution_bwd_weights_t),
154 INSTANCE(ref_convolution_fwd_t<f32>),
155 INSTANCE(ref_convolution_bwd_data_t<f32, f32, f32, f32>),
156 INSTANCE(ref_convolution_bwd_weights_t<f32, f32, f32, f32>),
158 INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<f32>),
159 INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<s32>),
160 INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<s8>),
161 INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<u8>),
162 INSTANCE(jit_avx512_common_convolution_fwd_t<s16, s16, s32>),
163 INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8,f32>),
164 INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8,s32>),
165 INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8,u8>),
166 INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8,s8>),
167 INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8,f32>),
168 INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8,s32>),
169 INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8,u8>),
170 INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8,s8>),
171 INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<u8,f32>),
172 INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<u8,s32>),
173 INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<u8,u8>),
174 INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<u8,s8>),
175 INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<s8,f32>),
176 INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<s8,s32>),
177 INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<s8,u8>),
178 INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<s8,s8>),
179 INSTANCE(jit_avx512_common_convolution_bwd_data_t<s16, s16, s32>),
180 INSTANCE(jit_avx512_common_convolution_bwd_weights_t<s16, s16, s32>),
181 INSTANCE(jit_avx2_x8s8s32x_dw_convolution_fwd_t<u8,f32>),
182 INSTANCE(jit_avx2_x8s8s32x_dw_convolution_fwd_t<u8,s32>),
183 INSTANCE(jit_avx2_x8s8s32x_dw_convolution_fwd_t<u8,u8>),
184 INSTANCE(jit_avx2_x8s8s32x_dw_convolution_fwd_t<u8,s8>),
185 INSTANCE(jit_sse42_x8s8s32x_dw_convolution_fwd_t<u8,f32>),
186 INSTANCE(jit_sse42_x8s8s32x_dw_convolution_fwd_t<u8,s32>),
187 INSTANCE(jit_sse42_x8s8s32x_dw_convolution_fwd_t<u8,u8>),
188 INSTANCE(jit_sse42_x8s8s32x_dw_convolution_fwd_t<u8,s8>),
189 INSTANCE(jit_avx2_x8s8s32x_convolution_fwd_t<u8,f32>),
190 INSTANCE(jit_avx2_x8s8s32x_convolution_fwd_t<u8,s32>),
191 INSTANCE(jit_avx2_x8s8s32x_convolution_fwd_t<u8,u8>),
192 INSTANCE(jit_avx2_x8s8s32x_convolution_fwd_t<u8,s8>),
193 INSTANCE(jit_avx2_x8s8s32x_convolution_fwd_t<s8,f32>),
194 INSTANCE(jit_avx2_x8s8s32x_convolution_fwd_t<s8,s32>),
195 INSTANCE(jit_avx2_x8s8s32x_convolution_fwd_t<s8,u8>),
196 INSTANCE(jit_avx2_x8s8s32x_convolution_fwd_t<s8,s8>),
197 INSTANCE(jit_sse42_x8s8s32x_convolution_fwd_t<u8,f32>),
198 INSTANCE(jit_sse42_x8s8s32x_convolution_fwd_t<u8,s32>),
199 INSTANCE(jit_sse42_x8s8s32x_convolution_fwd_t<u8,u8>),
200 INSTANCE(jit_sse42_x8s8s32x_convolution_fwd_t<u8,s8>),
201 INSTANCE(jit_sse42_x8s8s32x_convolution_fwd_t<s8,f32>),
202 INSTANCE(jit_sse42_x8s8s32x_convolution_fwd_t<s8,s32>),
203 INSTANCE(jit_sse42_x8s8s32x_convolution_fwd_t<s8,u8>),
204 INSTANCE(jit_sse42_x8s8s32x_convolution_fwd_t<s8,s8>),
205 INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<u8, s32>),
206 INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<u8, u8>),
207 INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<u8, s8>),
208 INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<u8, f32>),
209 INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<s8, s32>),
210 INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<s8, u8>),
211 INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<s8, s8>),
212 INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<s8, f32>),
213 INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t<s32>),
214 INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t<u8>),
215 INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t<s8>),
216 INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t<f32>),
217 INSTANCE(ref_convolution_fwd_t<s16, s16, s32, s32>),
218 INSTANCE(ref_convolution_fwd_t<u8, s8, f32, s32>),
219 INSTANCE(ref_convolution_fwd_t<u8, s8, s32, s32>),
220 INSTANCE(ref_convolution_fwd_t<u8, s8, s8, s32>),
221 INSTANCE(ref_convolution_fwd_t<u8, s8, u8, s32>),
222 INSTANCE(ref_convolution_bwd_data_t<s32, s16, s16, s32>),
223 INSTANCE(ref_convolution_bwd_data_t<f32, s8, u8, s32>),
224 INSTANCE(ref_convolution_bwd_data_t<s32, s8, u8, s32>),
225 INSTANCE(ref_convolution_bwd_data_t<s8, s8, u8, s32>),
226 INSTANCE(ref_convolution_bwd_data_t<u8, s8, u8, s32>),
227 INSTANCE(ref_convolution_bwd_weights_t<s16, s32, s16, s32>),
229 INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<u8,f32>),
230 INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<u8,s32>),
231 INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<u8,u8>),
232 INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<u8,s8>),
233 INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<s8,f32>),
234 INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<s8,s32>),
235 INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<s8,u8>),
236 INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<s8,s8>),
237 INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<u8,s32>),
238 INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<u8,u8>),
239 INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<u8,s8>),
240 INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<u8,f32>),
241 INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<s8,s32>),
242 INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<s8,u8>),
243 INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<s8,s8>),
244 INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<s8,f32>),
245 INSTANCE(ref_deconvolution_bwd_weights_t),
246 INSTANCE(ref_deconvolution_bwd_data_t),
247 INSTANCE(ref_deconvolution_fwd_t),
249 INSTANCE(ref_shuffle_t<4>), /* f32 or s32 */
250 INSTANCE(ref_shuffle_t<1>), /* s8 or u8 */
252 INSTANCE(jit_uni_eltwise_fwd_t<avx512_common>),
253 INSTANCE(jit_uni_eltwise_bwd_t<avx512_common>),
254 INSTANCE(jit_uni_eltwise_fwd_t<avx2>),
255 INSTANCE(jit_uni_eltwise_bwd_t<avx2>),
256 INSTANCE(jit_uni_eltwise_fwd_t<sse42>),
257 INSTANCE(jit_uni_eltwise_bwd_t<sse42>),
258 INSTANCE(ref_eltwise_fwd_t<f32>),
259 INSTANCE(ref_eltwise_bwd_t<f32>),
261 INSTANCE(jit_uni_depthwise_fwd_t<avx512_common>),
262 INSTANCE(jit_uni_depthwise_fwd_t<avx2>),
263 INSTANCE(jit_uni_depthwise_fwd_t<sse42>),
264 INSTANCE(ref_depthwise_fwd_t<f32>),
266 INSTANCE(ref_eltwise_fwd_t<s32>),
267 INSTANCE(ref_eltwise_fwd_t<s16>),
268 INSTANCE(ref_eltwise_fwd_t<s8>),
269 INSTANCE(ref_eltwise_fwd_t<u8>),
270 INSTANCE(ref_eltwise_bwd_t<s32>),
271 INSTANCE(ref_eltwise_bwd_t<s16>),
273 INSTANCE(jit_uni_softmax_fwd_t<avx512_common>),
274 INSTANCE(jit_uni_softmax_fwd_t<avx2>),
275 INSTANCE(jit_uni_softmax_fwd_t<sse42>),
276 INSTANCE(ref_softmax_fwd_t<f32>),
277 INSTANCE(ref_softmax_bwd_t<f32>),
279 INSTANCE(jit_uni_pooling_fwd_t<avx512_common>),
280 INSTANCE(jit_uni_pooling_bwd_t<avx512_common>),
281 INSTANCE(jit_uni_pooling_fwd_t<avx>),
282 INSTANCE(jit_uni_pooling_bwd_t<avx>),
283 INSTANCE(jit_uni_pooling_fwd_t<sse42>),
284 INSTANCE(jit_uni_pooling_bwd_t<sse42>),
285 INSTANCE(nchw_pooling_fwd_t<f32>),
286 INSTANCE(nchw_pooling_bwd_t<f32>),
287 INSTANCE(nhwc_pooling_fwd_t<f32>),
288 INSTANCE(nhwc_pooling_bwd_t<f32>),
289 INSTANCE(ref_pooling_fwd_t<f32>),
290 INSTANCE(ref_pooling_bwd_t<f32>),
292 INSTANCE(jit_uni_i8i8_pooling_fwd_t<avx512_core>),
293 INSTANCE(jit_uni_i8i8_pooling_fwd_t<avx2>),
294 INSTANCE(jit_sse42_i8i8_pooling_fwd_t),
295 INSTANCE(ref_pooling_fwd_t<s32>),
296 INSTANCE(ref_pooling_fwd_t<s16, s32>),
297 INSTANCE(ref_pooling_fwd_t<s8, s32>),
298 INSTANCE(ref_pooling_fwd_t<u8, s32>),
299 INSTANCE(ref_pooling_bwd_t<s32>),
300 INSTANCE(ref_pooling_bwd_t<s16, s32>),
302 INSTANCE(jit_avx512_common_lrn_fwd_t),
303 INSTANCE(jit_avx512_common_lrn_bwd_t),
304 INSTANCE(jit_uni_lrn_fwd_t<avx2>),
305 INSTANCE(jit_uni_lrn_bwd_t<avx2>),
306 INSTANCE(jit_uni_lrn_fwd_t<sse42>),
307 INSTANCE(ref_lrn_fwd_t<f32>),
308 INSTANCE(ref_lrn_bwd_t<f32>),
309 /* batch normalization */
310 INSTANCE(jit_uni_batch_normalization_fwd_t<avx512_common>),
311 INSTANCE(jit_uni_batch_normalization_bwd_t<avx512_common>),
312 INSTANCE(jit_uni_batch_normalization_fwd_t<avx2>),
313 INSTANCE(jit_uni_batch_normalization_bwd_t<avx2>),
314 INSTANCE(jit_uni_batch_normalization_fwd_t<sse42>),
315 INSTANCE(jit_uni_batch_normalization_bwd_t<sse42>),
316 INSTANCE(ncsp_batch_normalization_fwd_t),
317 INSTANCE(ncsp_batch_normalization_bwd_t),
318 INSTANCE(nspc_batch_normalization_fwd_t),
319 INSTANCE(nspc_batch_normalization_bwd_t),
320 INSTANCE(ref_batch_normalization_fwd_t<f32>),
321 INSTANCE(ref_batch_normalization_bwd_t<f32>),
323 INSTANCE(gemm_inner_product_fwd_t<f32>),
324 INSTANCE(gemm_inner_product_bwd_data_t<f32>),
325 INSTANCE(gemm_inner_product_bwd_weights_t<f32>),
326 INSTANCE(ref_inner_product_fwd_t<f32>),
327 INSTANCE(ref_inner_product_bwd_data_t<f32, f32, f32, f32>),
328 INSTANCE(ref_inner_product_bwd_weights_t<f32>),
329 /* inner product (int) */
330 INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<u8, u8>),
331 INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<u8, s8>),
332 INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<u8, s32>),
333 INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<u8, f32>),
334 INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<s8, u8>),
335 INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<s8, s8>),
336 INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<s8, s32>),
337 INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<s8, f32>),
338 INSTANCE(ref_inner_product_fwd_t<u8, s8, u8, s32>),
339 INSTANCE(ref_inner_product_fwd_t<u8, s8, s8, s32>),
340 INSTANCE(ref_inner_product_fwd_t<u8, s8, s32, s32>),
341 INSTANCE(ref_inner_product_fwd_t<u8, s8, f32, s32>),
342 INSTANCE(ref_inner_product_fwd_t<s16, s16, s32, s32>),
343 INSTANCE(ref_inner_product_bwd_data_t<s32, s16, s16, s32>),
345 INSTANCE(jit_uni_roi_pooling_fwd_t<avx512_common>),
346 INSTANCE(jit_uni_roi_pooling_fwd_t<avx2>),
347 INSTANCE(jit_uni_roi_pooling_fwd_t<sse42>),
348 INSTANCE(ref_roi_pooling_fwd_t<data_type::f32>),
349 /* binary convolution */
350 // INSTANCE(jit_uni_binary_convolution_fwd_t<avx512_common>),
351 INSTANCE(jit_uni_binary_convolution_fwd_t<avx2>),
352 INSTANCE(jit_uni_binary_convolution_fwd_t<sse42>),
353 INSTANCE(ref_binary_convolution_fwd_t),
355 INSTANCE(jit_uni_binarization_fwd_t<avx512_common>),
356 INSTANCE(jit_uni_binarization_fwd_t<avx2>),
357 INSTANCE(jit_uni_binarization_fwd_t<sse42>),
358 INSTANCE(ref_binarization_fwd_t<f32>),
365 const pd_create_f* cpu_engine_t::get_implementation_list() const {
366 return cpu_impl_list;
369 cpu_engine_factory_t engine_factory;
372 // XXX: this is a huge hammer. This disables all and any msan checks on
373 // primitive outputs.
375 // A proper approach would be an implementation-specific unpoisoning.
376 void unpoison_outputs(primitive_t *p)
378 for(auto o: p->outputs()) {
379 assert(o->kind() == primitive_kind::memory);
381 o->get_data_handle(&p);
382 size_t s = ((memory_pd_t *)o->pd())->get_size();
388 status_t cpu_engine_t::submit(primitive_t *p, event_t *e,
389 event_vector &prerequisites) {
390 /* FIXME: this should live in primitive execute function... */
391 if (mkldnn_verbose()->level) {
392 double ms = get_msec();
394 ms = get_msec() - ms;
395 printf("mkldnn_verbose,exec,%s,%g\n", p->pd()->info(), ms);
409 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s