/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
17 // Ask the C library for its POSIX.1-2001 (200112L) interfaces so that
// posix_memalign, which aligned_malloc calls, is declared. A feature-test
// macro like this only takes effect if it is defined before the first
// system #include.
18 #define _POSIX_C_SOURCE 200112L
/* CHECK(f) and CHECK_TRUE(expr): error-handling macros used throughout this
 * file.
 *
 * CHECK(f) evaluates the mkldnn call `f` exactly once into a local
 * mkldnn_status_t `s`. When that status is not mkldnn_success, it prints
 * "[file:line] error: <call text> returns <status>".
 *
 * CHECK_TRUE(expr) prints "[file:line] <expr text> failed". The line that
 * actually tests `expr` is not in this listing, so presumably the message is
 * printed when `expr` is false; confirm against the full source.
 *
 * NOTE(review): this listing is incomplete. Each line carries its original
 * line number as a prefix (extraction residue, not C). The following
 * original lines are missing:
 *   - 46-49: the rest of CHECK, presumably the failure action and the
 *     closing `} while (0)`;
 *   - 51-52 and 54-57: the test and the tail of CHECK_TRUE.
 * Restore them from the upstream file before changing code here.
 *
 * NOTE(review): because the local is named `s`, an argument that itself
 * mentions `s` (e.g. CHECK(f(s))) would read the macro's own, still
 * uninitialized local. A less common local name would avoid that.
 */
42 #define CHECK(f) do { \
43 mkldnn_status_t s = f; \
44 if (s != mkldnn_success) { \
45 printf("[%s:%d] error: %s returns %d\n", __FILE__, __LINE__, #f, s); \
50 #define CHECK_TRUE(expr) do { \
53 printf("[%s:%d] %s failed\n", __FILE__, __LINE__, #expr); \
58 void *aligned_malloc(size_t size, size_t alignment) {
/* aligned_malloc: return `size` bytes aligned to `alignment`, or NULL.
 *
 * Two allocation paths are visible:
 *   - _aligned_malloc(size, alignment);
 *   - posix_memalign(&p, alignment, size), whose nonzero (failure) result is
 *     mapped to NULL here.
 * The two APIs take size and alignment in opposite orders. Presumably the
 * path is chosen by conditional compilation on lines missing from this
 * listing (original 59, 61-64 and 66-67). Those lines would also hold the
 * declaration of `p` and the closing brace; confirm against the full source.
 *
 * NOTE(review): memory from _aligned_malloc must be released with
 * _aligned_free rather than free(), so callers should use the matching
 * release helper. posix_memalign requires `alignment` to be a power of two
 * and a multiple of sizeof(void *); the callers in this file pass 64.
 */
60 return _aligned_malloc(size, alignment);
65 return !posix_memalign(&p, alignment, size) ? p : NULL;
/* _free: release memory obtained from aligned_malloc.
 *
 * NOTE(review): two definitions of _free appear (original lines 70 and 74).
 * This listing is missing both bodies (lines 71-72 and 75-76) and the lines
 * around them (69, 73, 77). Presumably the two definitions are alternatives
 * under conditional compilation that mirror aligned_malloc's two allocation
 * paths; memory from _aligned_malloc must be freed with _aligned_free.
 * Confirm against the full source.
 *
 * NOTE(review): identifiers beginning with an underscore are reserved at
 * file scope in C. A name such as aligned_free would avoid the clash.
 */
70 void _free(void *ptr) {
74 void _free(void *ptr) {
/* product: multiply together the first `size` entries of `arr` and return
 * the result as a size_t. In this file it sizes allocations from dimension
 * arrays.
 *
 * NOTE(review): this listing is missing the accumulator declaration, the
 * return statement and the closing brace (original lines 80 and 82-83).
 * Presumably the accumulator starts at 1; confirm.
 *
 * NOTE(review): the entries are ptrdiff_t. If the accumulator is a size_t,
 * as the return type suggests, a negative dimension would convert to a huge
 * value. The multiplication is also not checked for overflow. `arr` is only
 * read, so it could be const-qualified.
 */
79 static size_t product(ptrdiff_t *arr, size_t size) {
81 for (size_t i = 0; i < size; ++i) prod *= arr[i];
/* init_data_memory: create a user-format memory primitive that wraps `data`.
 *
 * Steps:
 *   1. Build a memory descriptor from `dim`, `dims`, the data-type argument
 *      and `user_fmt`.
 *   2. Create a memory primitive descriptor for `engine`.
 *   3. Create the memory primitive into *memory with no inputs or outputs.
 *   4. Check that the new primitive has no data handle yet.
 *   5. Point it at `data` with mkldnn_memory_set_data_handle.
 *   6. Check that the handle now equals `data`.
 *
 * The temporary primitive descriptor is destroyed before returning. The
 * memory primitive in *memory is left for the caller to destroy. Failures
 * are reported only through CHECK/CHECK_TRUE, since the function returns
 * void.
 *
 * NOTE(review): the data-type parameter is named `mkldnn_f32`, the same as
 * the mkldnn_f32 data-type constant. Inside this function that name means
 * whatever the caller passed; a name such as `data_type` would be clearer.
 *
 * NOTE(review): this listing is missing the opening brace (presumably
 * original line 88), the declaration of `req` (94-95) and the closing brace
 * (102).
 */
85 static void init_data_memory(uint32_t dim, const ptrdiff_t *dims,
86 mkldnn_memory_format_t user_fmt, mkldnn_data_type_t mkldnn_f32,
87 mkldnn_engine_t engine, float *data, mkldnn_primitive_t *memory)
89 mkldnn_memory_desc_t prim_md;
90 mkldnn_primitive_desc_t user_pd;
91 CHECK(mkldnn_memory_desc_init(&prim_md, dim, dims, mkldnn_f32, user_fmt));
92 CHECK(mkldnn_memory_primitive_desc_create(&user_pd, &prim_md, engine));
93 CHECK(mkldnn_primitive_create(memory, user_pd, NULL, NULL));
// A memory primitive created with NULL inputs is expected to have no buffer
// attached yet; verify that before attaching the caller's `data`.
96 CHECK(mkldnn_memory_get_data_handle(*memory, &req));
97 CHECK_TRUE(req == NULL);
98 CHECK(mkldnn_memory_set_data_handle(*memory, data));
99 CHECK(mkldnn_memory_get_data_handle(*memory, &req));
100 CHECK_TRUE(req == data);
101 CHECK(mkldnn_primitive_desc_destroy(user_pd));
/* prepare_reorder: create an internal memory primitive plus a reorder when
 * the user memory layout differs from the layout a primitive wants.
 *
 * When *user_memory's primitive descriptor is not equal to *prim_memory_pd:
 *   - create *prim_memory from *prim_memory_pd;
 *   - create a reorder primitive descriptor and the reorder primitive in
 *     *reorder, going user -> prim when `dir_is_user_to_prim` is nonzero and
 *     prim -> user otherwise;
 *   - point *prim_memory at `buffer`;
 *   - destroy the reorder primitive descriptor.
 * Returns mkldnn_success; failures go through CHECK.
 *
 * NOTE(review): the branch taken when the descriptors are equal is not in
 * this listing (original lines 139-143). Callers in this file test the
 * outputs against NULL to decide whether a reorder exists, so presumably
 * that branch sets *prim_memory and *reorder to NULL; confirm.
 *
 * NOTE(review): this listing is also missing:
 *   - the `buffer` parameter's declaration (line 110);
 *   - the opening brace (111);
 *   - the remaining create arguments (118, 128, 135);
 *   - the branch braces (129, 136);
 *   - the closing brace (145).
 *
 * NOTE(review): the status of mkldnn_primitive_get_primitive_desc is not
 * checked, unlike every other mkldnn call in this function.
 */
104 mkldnn_status_t prepare_reorder(
105 mkldnn_primitive_t *user_memory, /** in */
106 const_mkldnn_primitive_desc_t *prim_memory_pd, /** in */
107 int dir_is_user_to_prim, /** in: user -> prim or prim -> user */
108 mkldnn_primitive_t *prim_memory, /** out: memory primitive created */
109 mkldnn_primitive_t *reorder, /** out: reorder primitive created */
112 const_mkldnn_primitive_desc_t user_memory_pd;
113 mkldnn_primitive_get_primitive_desc(*user_memory, &user_memory_pd);
115 if (!mkldnn_memory_primitive_desc_equal(user_memory_pd, *prim_memory_pd)) {
116 /* create the internal memory primitive; `buffer` is attached to it at the
 * end of this branch. Creating it with NULL inputs is not relied on to
 * allocate storage: init_data_memory asserts that such a primitive starts
 * with a NULL data handle. */
117 CHECK(mkldnn_primitive_create(prim_memory, *prim_memory_pd,
119 mkldnn_primitive_desc_t reorder_pd;
120 if (dir_is_user_to_prim) {
121 /* the reorder primitive descriptor takes no engine: the engine is already
 * part of the input and output memory primitive descriptors */
123 CHECK(mkldnn_reorder_primitive_desc_create(&reorder_pd,
124 user_memory_pd, *prim_memory_pd));
125 mkldnn_primitive_at_t inputs = { *user_memory, 0 };
126 const_mkldnn_primitive_t outputs[] = { *prim_memory };
127 CHECK(mkldnn_primitive_create(reorder, reorder_pd, &inputs,
// prim -> user direction (the `else` line, original 129, is missing here).
130 CHECK(mkldnn_reorder_primitive_desc_create(&reorder_pd,
131 *prim_memory_pd, user_memory_pd));
132 mkldnn_primitive_at_t inputs = { *prim_memory, 0 };
133 const_mkldnn_primitive_t outputs[] = { *user_memory };
134 CHECK(mkldnn_primitive_create(reorder, reorder_pd, &inputs,
// Attach the caller-provided buffer to the newly created internal memory,
// then release the descriptor now that the reorder primitive exists.
137 CHECK(mkldnn_memory_set_data_handle(*prim_memory, buffer));
138 CHECK(mkldnn_primitive_desc_destroy(reorder_pd));
144 return mkldnn_success;
/* simple_net: build and run a small conv -> relu -> lrn -> pool network on
 * the CPU engine using the mkldnn C API, then release what it created.
 *
 * Flow, as visible in this listing:
 *   1. Create engine 0 of kind mkldnn_cpu and 64-byte-aligned user buffers.
 *   2. Convolution:
 *      - user memories in nchw/oihw/x formats;
 *      - descriptors with mkldnn_any, so the library chooses the layouts;
 *      - reorders (via prepare_reorder) for src and weights when the chosen
 *        layouts differ from the user layouts.
 *   3. ReLU, across-channel LRN (with a workspace memory) and max pooling.
 *      Each takes the previous primitive's dst descriptor as its src. The
 *      pooling dst is reordered back into the nchw user buffer net_dst when
 *      needed.
 *   4. Submit the primitives to an eager stream, wait, then destroy the
 *      descriptors, primitives and buffers.
 * Returns mkldnn_success; mkldnn failures are routed through CHECK.
 *
 * NOTE(review): this listing drops many lines, including:
 *   - closing lines of multi-line calls and initializers;
 *   - parts of several block comments;
 *   - the rest of the LRN descriptor call (original 314);
 *   - the net[] entries between the input reorders and the output reorder
 *     (421, 424-427);
 *   - lines 442-445 and 455-458 of the clean-up, presumably the frees of
 *     net_src, net_dst, conv_weights and conv_bias, none of which is
 *     released in the visible code;
 *   - the closing brace.
 * Restore from the upstream source before changing code here.
 *
 * NOTE(review): no aligned_malloc result is checked for NULL. No visible
 * line writes input values into net_src, conv_weights or conv_bias before
 * the net runs.
 */
147 mkldnn_status_t simple_net() {
149 mkldnn_engine_t engine;
150 CHECK(mkldnn_engine_create(&engine, mkldnn_cpu, 0 /* idx */));
// User-side network input (wrapped as nchw convolution src below) and output
// (wrapped as nchw pooling dst below), both 64-byte aligned. BATCH, IC, OC
// and the spatial sizes are defined on lines not included in this listing.
152 float *net_src = (float *)aligned_malloc(
153 BATCH * IC * CONV_IH * CONV_IW * sizeof(float), 64);
154 float *net_dst = (float *)aligned_malloc(
155 BATCH * OC * POOL_OH * POOL_OW * sizeof(float), 64);
// NOTE(review): the next three lines are the body of a block comment whose
// opening and closing lines (original 157 and 161) are missing here.
158 * {BATCH, IC, CONV_IH, CONV_IW} (x) {OC, IC, CONV_KH, CONV_KW} ->
159 * {BATCH, OC, CONV_OH, CONV_OW}
160 * strides: {CONV_STRIDE, CONV_STRIDE}
// Logical (user-visible) convolution shapes, strides and padding.
// NOTE(review): the weights shape hard-codes an 11x11 kernel, although the
// shape summary above names CONV_KH/CONV_KW. The bias array has 4 slots but
// only the first is used, since it is passed as a 1-D shape.
162 ptrdiff_t conv_user_src_sizes[4] = { BATCH, IC, CONV_IH, CONV_IW };
163 ptrdiff_t conv_user_weights_sizes[4] = { OC, IC, 11, 11 };
164 ptrdiff_t conv_bias_sizes[4] = { OC };
165 ptrdiff_t conv_user_dst_sizes[4] = { BATCH, OC, CONV_OH, CONV_OW };
166 ptrdiff_t conv_strides[2] = { CONV_STRIDE, CONV_STRIDE };
167 ptrdiff_t conv_padding[2] = { CONV_PAD, CONV_PAD };
169 float *conv_src = net_src;
// Weights and bias buffers, sized from their logical shapes.
170 float *conv_weights = (float *)aligned_malloc(
171 product(conv_user_weights_sizes, 4) * sizeof(float), 64);
172 float *conv_bias = (float *)aligned_malloc(
173 product(conv_bias_sizes, 1) * sizeof(float), 64);
175 /* create memory for user data */
176 mkldnn_primitive_t conv_user_src_memory, conv_user_weights_memory,
177 conv_user_bias_memory;
178 init_data_memory(4, conv_user_src_sizes, mkldnn_nchw, mkldnn_f32, engine,
179 conv_src, &conv_user_src_memory);
180 init_data_memory(4, conv_user_weights_sizes, mkldnn_oihw, mkldnn_f32,
181 engine, conv_weights, &conv_user_weights_memory);
182 init_data_memory(1, conv_bias_sizes, mkldnn_x, mkldnn_f32, engine,
183 conv_bias, &conv_user_bias_memory);
185 /* create data descriptors for convolution w/ no specified format */
// mkldnn_any lets the convolution choose the src, weights and dst layouts;
// the bias stays in the plain 1-D mkldnn_x format.
// NOTE(review): the declaration below continues on a missing line
// (original 188), presumably declaring conv_dst_md, which is used further
// down.
187 mkldnn_memory_desc_t conv_src_md, conv_weights_md, conv_bias_md,
189 CHECK(mkldnn_memory_desc_init(&conv_src_md, 4, conv_user_src_sizes,
190 mkldnn_f32, mkldnn_any));
191 CHECK(mkldnn_memory_desc_init(&conv_weights_md, 4, conv_user_weights_sizes,
192 mkldnn_f32, mkldnn_any));
193 CHECK(mkldnn_memory_desc_init(&conv_bias_md, 1, conv_bias_sizes,
194 mkldnn_f32, mkldnn_x));
195 CHECK(mkldnn_memory_desc_init(&conv_dst_md, 4, conv_user_dst_sizes,
196 mkldnn_f32, mkldnn_any));
198 /* create a convolution */
// Direct forward convolution with the same padding on both sides
// (conv_padding is passed twice), using zero padding.
199 mkldnn_convolution_desc_t conv_any_desc;
200 CHECK(mkldnn_convolution_forward_desc_init(&conv_any_desc, mkldnn_forward,
201 mkldnn_convolution_direct, &conv_src_md, &conv_weights_md,
202 &conv_bias_md, &conv_dst_md, conv_strides, conv_padding,
203 conv_padding, mkldnn_padding_zero));
205 mkldnn_primitive_desc_t conv_pd;
206 CHECK(mkldnn_primitive_desc_create(&conv_pd, &conv_any_desc,
209 mkldnn_primitive_t conv_internal_src_memory, conv_internal_weights_memory,
210 conv_internal_dst_memory;
212 /* create memory for dst data; it is not reordered back to user data */
// The buffer size comes from the dst primitive descriptor, i.e. the layout
// the convolution chose, not from conv_user_dst_sizes.
213 const_mkldnn_primitive_desc_t dst_pd
214 = mkldnn_primitive_desc_query_pd(conv_pd, mkldnn_query_dst_pd, 0);
215 CHECK(mkldnn_primitive_create(
216 &conv_internal_dst_memory, dst_pd, NULL, NULL));
217 size_t conv_dst_size = mkldnn_memory_primitive_desc_get_size(dst_pd);
218 float *conv_dst_buffer = (float *)aligned_malloc(conv_dst_size, 64);
219 CHECK(mkldnn_memory_set_data_handle(
220 conv_internal_dst_memory, conv_dst_buffer));
// Reorders for the convolution inputs. prepare_reorder (direction 1, i.e.
// user -> primitive) receives a buffer sized for the chosen layout. These
// buffers are allocated whether or not a reorder turns out to be needed,
// and are freed in the clean-up below.
// NOTE(review): the block comment opened on the next line is presumably
// closed on a missing line (original 223). As listed, it swallows the code
// up to original line 255.
222 /* create reorder primitives between user data and convolution srcs
224 mkldnn_primitive_t conv_reorder_src, conv_reorder_weights;
226 const_mkldnn_primitive_desc_t src_pd = mkldnn_primitive_desc_query_pd(
227 conv_pd, mkldnn_query_src_pd, 0);
228 size_t conv_src_size = mkldnn_memory_primitive_desc_get_size(src_pd);
229 float *conv_src_buffer = (float *)aligned_malloc(conv_src_size, 64);
230 CHECK(prepare_reorder(&conv_user_src_memory, &src_pd, 1,
231 &conv_internal_src_memory, &conv_reorder_src, conv_src_buffer));
233 const_mkldnn_primitive_desc_t weights_pd = mkldnn_primitive_desc_query_pd(
234 conv_pd, mkldnn_query_weights_pd, 0);
235 size_t conv_weights_size
236 = mkldnn_memory_primitive_desc_get_size(weights_pd);
237 float *conv_weights_buffer = (float *)aligned_malloc(conv_weights_size, 64);
238 CHECK(prepare_reorder(&conv_user_weights_memory, &weights_pd, 1,
239 &conv_internal_weights_memory, &conv_reorder_weights,
240 conv_weights_buffer));
242 mkldnn_primitive_t conv_src_memory = conv_internal_src_memory ?
243 conv_internal_src_memory : conv_user_src_memory;
244 mkldnn_primitive_t conv_weights_memory = conv_internal_weights_memory ?
245 conv_internal_weights_memory : conv_user_weights_memory;
247 mkldnn_primitive_at_t conv_srcs[] = {
248 mkldnn_primitive_at(conv_src_memory, 0),
249 mkldnn_primitive_at(conv_weights_memory, 0),
250 mkldnn_primitive_at(conv_user_bias_memory, 0)
253 const_mkldnn_primitive_t conv_dsts[] = { conv_internal_dst_memory };
255 /* finally create a convolution primitive */
256 mkldnn_primitive_t conv;
257 CHECK(mkldnn_primitive_create(&conv, conv_pd, conv_srcs, conv_dsts));
// NOTE(review): the next line is the middle of a block comment whose first
// and last lines (original 259 and 261) are missing here.
260 * {BATCH, OC, CONV_OH, CONV_OW} -> {BATCH, OC, CONV_OH, CONV_OW}
// NOTE(review): negative_slope is passed to the mkldnn_eltwise_relu
// descriptor, presumably as its alpha. If alpha is the slope applied to
// negative inputs, 1.0 makes this ReLU pass negatives through unchanged;
// confirm that is intended (a plain ReLU uses 0).
262 float negative_slope = 1.0f;
264 /* create relu memory descriptor on dst memory descriptor
265 * from previous primitive */
266 const_mkldnn_primitive_desc_t conv_dst_pd = mkldnn_primitive_desc_query_pd(
267 conv_pd, mkldnn_query_dst_pd, 0);
268 const mkldnn_memory_desc_t *relu_src_md =
269 mkldnn_primitive_desc_query_memory_d(conv_dst_pd);
272 mkldnn_eltwise_desc_t relu_desc;
273 CHECK(mkldnn_eltwise_forward_desc_init(&relu_desc, mkldnn_forward,
274 mkldnn_eltwise_relu, relu_src_md, negative_slope, 0));
276 mkldnn_primitive_desc_t relu_pd;
277 CHECK(mkldnn_primitive_desc_create(&relu_pd, &relu_desc, engine, NULL));
279 mkldnn_primitive_t relu_dst_memory;
280 const_mkldnn_primitive_desc_t relu_dst_pd = mkldnn_primitive_desc_query_pd(
281 relu_pd, mkldnn_query_dst_pd, 0);
282 CHECK(mkldnn_primitive_create(&relu_dst_memory, relu_dst_pd, NULL, NULL));
283 size_t relu_dst_size = mkldnn_memory_primitive_desc_get_size(relu_dst_pd);
284 float *relu_dst_buffer = (float *)aligned_malloc(relu_dst_size, 64);
285 CHECK(mkldnn_memory_set_data_handle(relu_dst_memory, relu_dst_buffer));
287 /* finally create a relu primitive */
// ReLU reads the convolution's internal dst memory directly.
288 mkldnn_primitive_t relu;
289 mkldnn_primitive_at_t relu_srcs = { conv_internal_dst_memory, 0 };
290 const_mkldnn_primitive_t relu_dsts[] = { relu_dst_memory };
292 CHECK(mkldnn_primitive_create(&relu, relu_pd, &relu_srcs, relu_dsts));
// NOTE(review): the next line is the middle of a block comment whose other
// lines (original 294 and 296-299) are missing here.
295 * {BATCH, OC, CONV_OH, CONV_OW} -> {BATCH, OC, CONV_OH, CONV_OW}
// LRN hyper-parameters.
// NOTE(review): original lines 302-304 are missing, and the LRN descriptor
// call below continues on a missing line (original 314), so the full
// parameter list is not visible here.
300 uint32_t local_size = 5;
301 float alpha = 0.0001f;
305 /* create lrn memory descriptor on dst memory descriptor
306 * from previous primitive */
307 const mkldnn_memory_desc_t *lrn_src_md =
308 mkldnn_primitive_desc_query_memory_d(relu_dst_pd);
311 mkldnn_lrn_desc_t lrn_desc;
312 CHECK(mkldnn_lrn_forward_desc_init(&lrn_desc, mkldnn_forward,
313 mkldnn_lrn_across_channels, lrn_src_md, local_size,
316 mkldnn_primitive_desc_t lrn_pd;
317 CHECK(mkldnn_primitive_desc_create(&lrn_pd, &lrn_desc, engine, NULL));
319 mkldnn_primitive_t lrn_dst_memory;
320 const_mkldnn_primitive_desc_t lrn_dst_pd = mkldnn_primitive_desc_query_pd(
321 lrn_pd, mkldnn_query_dst_pd, 0);
322 CHECK(mkldnn_primitive_create(&lrn_dst_memory, lrn_dst_pd, NULL, NULL));
323 size_t lrn_dst_size = mkldnn_memory_primitive_desc_get_size(lrn_dst_pd);
324 float *lrn_dst_buffer = (float *)aligned_malloc(lrn_dst_size, 64);
325 CHECK(mkldnn_memory_set_data_handle(lrn_dst_memory, lrn_dst_buffer));
// LRN workspace ("scratch") memory. It is queried from the LRN primitive
// descriptor and passed as the primitive's second output.
327 mkldnn_primitive_t lrn_scratch_memory;
328 const_mkldnn_primitive_desc_t lrn_scratch_pd =
329 mkldnn_primitive_desc_query_pd(lrn_pd, mkldnn_query_workspace_pd, 0);
330 CHECK(mkldnn_primitive_create(&lrn_scratch_memory,
331 lrn_scratch_pd, NULL, NULL));
332 size_t lrn_scratch_size =
333 mkldnn_memory_primitive_desc_get_size(lrn_scratch_pd);
334 float *lrn_scratch_buffer = (float*)aligned_malloc(lrn_scratch_size, 64);
335 CHECK(mkldnn_memory_set_data_handle(lrn_scratch_memory,
336 lrn_scratch_buffer));
338 mkldnn_primitive_at_t lrn_srcs = { relu_dst_memory, 0 };
340 const_mkldnn_primitive_t lrn_dsts[] = { lrn_dst_memory,
341 lrn_scratch_memory };
343 /* finally create a lrn primitive */
344 mkldnn_primitive_t lrn;
345 CHECK(mkldnn_primitive_create(&lrn, lrn_pd, &lrn_srcs, lrn_dsts));
// NOTE(review): the next two lines come from a block comment whose other
// lines (original 347, 349 and 351) are missing here.
348 * {BATCH, OC, CONV_OH, CONV_OW} -> {BATCH, OC, POOL_OH, POOL_OW}
350 * strides: {POOL_STRIDE, POOL_STRIDE}
// Pooling output shape and parameters.
// NOTE(review): the 3x3 kernel is a literal, while the other sizes use named
// constants.
353 ptrdiff_t pool_dst_sizes[4] = { BATCH, OC, POOL_OH, POOL_OW };
354 ptrdiff_t pool_kernel[2] = { 3, 3 };
355 ptrdiff_t pool_strides[2] = { POOL_STRIDE, POOL_STRIDE };
356 ptrdiff_t pool_padding[2] = { POOL_PAD, POOL_PAD };
358 /* create pooling memory descriptor on dst descriptor
359 * from previous primitive */
360 const mkldnn_memory_desc_t *pool_src_md =
361 mkldnn_primitive_desc_query_memory_d(lrn_dst_pd);
363 /* create descriptors for dst pooling data */
364 mkldnn_memory_desc_t pool_dst_md;
365 CHECK(mkldnn_memory_desc_init(
366 &pool_dst_md, 4, pool_dst_sizes, mkldnn_f32, mkldnn_any));
368 /* create memory for user data */
369 mkldnn_primitive_t pool_user_dst_memory;
370 init_data_memory(4, pool_dst_sizes, mkldnn_nchw, mkldnn_f32, engine,
371 net_dst, &pool_user_dst_memory);
373 /* create a pooling */
374 mkldnn_pooling_desc_t pool_desc;
375 CHECK(mkldnn_pooling_forward_desc_init(&pool_desc, mkldnn_forward,
376 mkldnn_pooling_max, pool_src_md, &pool_dst_md, pool_strides,
377 pool_kernel, pool_padding, pool_padding, mkldnn_padding_zero));
379 mkldnn_primitive_desc_t pool_pd;
380 CHECK(mkldnn_primitive_desc_create(&pool_pd, &pool_desc, engine, NULL));
382 /* create memory for workspace */
383 mkldnn_primitive_t pool_indices_memory;
384 const_mkldnn_primitive_desc_t pool_indices_pd =
385 mkldnn_primitive_desc_query_pd(pool_pd, mkldnn_query_workspace_pd, 0);
386 CHECK(mkldnn_primitive_create(&pool_indices_memory,
387 pool_indices_pd, NULL, NULL));
388 size_t pool_indices_size =
389 mkldnn_memory_primitive_desc_get_size(pool_indices_pd);
390 float *pool_indices_buffer = (float*)aligned_malloc(pool_indices_size, 64);
391 CHECK(mkldnn_memory_set_data_handle(pool_indices_memory,
392 pool_indices_buffer));
394 mkldnn_primitive_t pool_dst_memory;
// The pooling dst goes to an internal memory when its chosen layout differs
// from the nchw user memory. prepare_reorder with direction 0
// (primitive -> user) then builds the reorder back into net_dst.
// NOTE(review): the block comment opened on the next line is presumably
// closed on a missing line (original 397). As listed, it swallows the code
// up to original line 414.
396 /* create reorder primitives between user data and pooling dsts
398 mkldnn_primitive_t pool_reorder_dst, pool_internal_dst_memory;
399 const_mkldnn_primitive_desc_t pool_dst_pd =
400 mkldnn_primitive_desc_query_pd(pool_pd, mkldnn_query_dst_pd, 0);
401 size_t pool_dst_size = mkldnn_memory_primitive_desc_get_size(pool_dst_pd);
402 float *pool_dst_buffer = (float *)aligned_malloc(pool_dst_size, 64);
403 CHECK(prepare_reorder(&pool_user_dst_memory, &pool_dst_pd, 0,
404 &pool_internal_dst_memory, &pool_reorder_dst, pool_dst_buffer));
406 mkldnn_primitive_at_t pool_srcs = { lrn_dst_memory, 0 };
408 pool_dst_memory = pool_internal_dst_memory ? pool_internal_dst_memory
409 : pool_user_dst_memory;
411 const_mkldnn_primitive_t pool_dsts[] = { pool_dst_memory,
412 pool_indices_memory };
414 /* finally create a pooling primitive */
415 mkldnn_primitive_t pool;
416 CHECK(mkldnn_primitive_create(&pool, pool_pd, &pool_srcs, pool_dsts));
418 /* build a simple net */
// Execution order:
//   1. input reorders, only when prepare_reorder created them;
//   2. presumably, on missing lines 421 and 424-427, the counter `n` and the
//      conv, relu, lrn and pool primitives;
//   3. the reorder back to the user dst.
// net has room for 10 entries.
420 mkldnn_primitive_t net[10];
422 if (conv_reorder_src) net[n++] = conv_reorder_src;
423 if (conv_reorder_weights) net[n++] = conv_reorder_weights;
428 if (pool_reorder_dst) net[n++] = pool_reorder_dst;
// Submit the n primitives to an eager stream and wait for them to finish.
430 mkldnn_stream_t stream;
431 CHECK(mkldnn_stream_create(&stream, mkldnn_eager));
432 CHECK(mkldnn_stream_submit(stream, n, net, NULL));
433 CHECK(mkldnn_stream_wait(stream, n, NULL));
// Clean-up.
// NOTE(review): internal memories and reorders are destroyed
// unconditionally. This presumably relies on prepare_reorder leaving them
// NULL, and on mkldnn_primitive_destroy accepting NULL, when no reorder was
// created.
436 CHECK(mkldnn_primitive_desc_destroy(conv_pd));
437 CHECK(mkldnn_primitive_desc_destroy(relu_pd));
438 CHECK(mkldnn_primitive_desc_destroy(lrn_pd));
439 CHECK(mkldnn_primitive_desc_destroy(pool_pd));
441 mkldnn_stream_destroy(stream);
446 mkldnn_primitive_destroy(conv_user_src_memory);
447 mkldnn_primitive_destroy(conv_user_weights_memory);
448 mkldnn_primitive_destroy(conv_user_bias_memory);
449 mkldnn_primitive_destroy(conv_internal_src_memory);
450 mkldnn_primitive_destroy(conv_internal_weights_memory);
451 mkldnn_primitive_destroy(conv_internal_dst_memory);
452 mkldnn_primitive_destroy(conv_reorder_src);
453 mkldnn_primitive_destroy(conv_reorder_weights);
454 mkldnn_primitive_destroy(conv);
459 _free(conv_src_buffer);
460 _free(conv_weights_buffer);
461 _free(conv_dst_buffer);
463 mkldnn_primitive_destroy(relu_dst_memory);
464 mkldnn_primitive_destroy(relu);
466 _free(relu_dst_buffer);
468 mkldnn_primitive_destroy(lrn_scratch_memory);
469 mkldnn_primitive_destroy(lrn_dst_memory);
470 mkldnn_primitive_destroy(lrn);
472 _free(lrn_scratch_buffer);
473 _free(lrn_dst_buffer);
475 mkldnn_primitive_destroy(pool_user_dst_memory);
476 mkldnn_primitive_destroy(pool_internal_dst_memory);
477 mkldnn_primitive_destroy(pool_indices_memory);
478 mkldnn_primitive_destroy(pool_reorder_dst);
479 mkldnn_primitive_destroy(pool);
481 _free(pool_dst_buffer);
482 _free(pool_indices_buffer);
// The engine is released last, after everything created on it.
484 mkldnn_engine_destroy(engine);
486 return mkldnn_success;
/* main: run simple_net once and print "passed" if it returned
 * mkldnn_success, or "failed" otherwise. argc and argv are unused.
 *
 * NOTE(review): the rest of main (original lines 492-493, presumably the
 * return statement and closing brace) is missing from this listing.
 *
 * NOTE(review): the only visible return in simple_net returns
 * mkldnn_success, and its errors go through CHECK. Whether "failed" can
 * ever be printed therefore depends on CHECK's failure action, which is
 * also not in this listing.
 */
489 int main(int argc, char **argv) {
490 mkldnn_status_t result = simple_net();
491 printf("%s\n", (result == mkldnn_success) ? "passed" : "failed");