1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 // Required for posix_memalign
18 #define _POSIX_C_SOURCE 200112L
45 mkldnn_status_t s = f; \
46 if (s != mkldnn_success) { \
47 printf("[%s:%d] error: %s returns %d\n", __FILE__, __LINE__, #f, \
/* CHECK_TRUE(expr): assertion-style helper macro. The visible part of its
 * expansion prints "[file:line] <expr> failed" through printf, using the
 * stringified argument. The condition guarding that message and whatever
 * follows it are outside the visible lines. */
53 #define CHECK_TRUE(expr) \
57 printf("[%s:%d] %s failed\n", __FILE__, __LINE__, #expr); \
/* Allocate `size` bytes aligned to `alignment` bytes.
 * Two allocation paths are visible: _aligned_malloc(size, alignment) and
 * posix_memalign(&p, alignment, size). On the posix_memalign path the
 * function returns p when posix_memalign reports success (returns 0) and
 * NULL otherwise. The preprocessor guards that select one of the two paths
 * are outside the visible lines (presumably a Windows/POSIX switch). */
62 void *aligned_malloc(size_t size, size_t alignment) {
64 return _aligned_malloc(size, alignment);
67 return !posix_memalign(&p, alignment, size) ? p : NULL;
/* Release a buffer (presumably one obtained from aligned_malloc).
 * Two definitions with the same signature appear; their bodies and the
 * preprocessor guards that keep only one of them are outside the visible
 * lines.
 * NOTE(review): file-scope identifiers that begin with an underscore are
 * reserved in C; the name is kept because the rest of this file calls it. */
72 void _free(void *ptr) {
76 void _free(void *ptr) {
/* Walks the first `size` entries of `arr`.
 * Presumably multiplies them together and returns the result, as the name
 * and the size_t return type suggest; the accumulator and the return
 * statement are outside the visible lines - TODO confirm. */
81 static size_t product(ptrdiff_t *arr, size_t size)
84 for (size_t i = 0; i < size; ++i)
/* Fill `data` with a deterministic test pattern: every element written by
 * the visible branches receives (float)(flat_index % 1637).
 *   data - destination buffer, indexed by flat offset
 *   dim  - number of dimensions; a `dim == 4` branch is visible, and the
 *          branch above it (condition not visible, presumably dim == 1)
 *          writes dims[0] elements
 *   dims - extent of each dimension
 */
89 static void init_net_data(float *data, uint32_t dim, const ptrdiff_t *dims)
92 for (int i = 0; i < dims[0]; ++i) {
93 data[i] = (float)(i % 1637);
95 } else if (dim == 4) {
96 for (int in = 0; in < dims[0]; ++in) {
97 for (int ic = 0; ic < dims[1]; ++ic) {
98 for (int ih = 0; ih < dims[2]; ++ih) {
99 for (int iw = 0; iw < dims[3]; ++iw) {
/* flat offset of (in, ic, ih, iw) with iw varying fastest; note that the
 * ptrdiff_t arithmetic is narrowed to int when stored in indx */
100 int indx = in * dims[1] * dims[2] * dims[3]
101 + ic * dims[2] * dims[3] + ih * dims[3] + iw;
102 data[indx] = (float)(indx % 1637);
/* Create a memory primitive and attach the buffer `data` to it.
 *   dim, dims  - rank and extents passed to mkldnn_memory_desc_init
 *   user_fmt   - memory format used for the descriptor
 *   data_type  - element type used for the descriptor
 *   engine     - engine the memory primitive descriptor is created on
 *   data       - buffer installed as the primitive's data handle
 *   memory     - out: receives the created memory primitive
 * Every library call here is wrapped in CHECK; the two handle comparisons
 * use CHECK_TRUE. */
110 static void init_data_memory(uint32_t dim, const ptrdiff_t *dims,
111 mkldnn_memory_format_t user_fmt,
112 mkldnn_data_type_t data_type,
113 mkldnn_engine_t engine, float *data,
114 mkldnn_primitive_t *memory)
116 mkldnn_memory_desc_t prim_md;
117 mkldnn_primitive_desc_t user_pd;
118 CHECK(mkldnn_memory_desc_init(&prim_md, dim, dims, data_type, user_fmt));
119 CHECK(mkldnn_memory_primitive_desc_create(&user_pd, &prim_md, engine));
120 CHECK(mkldnn_primitive_create(memory, user_pd, NULL, NULL));
/* the freshly created memory primitive is expected to report a NULL
 * data handle before one is set */
123 CHECK(mkldnn_memory_get_data_handle(*memory, &req));
124 CHECK_TRUE(req == NULL);
/* attach the buffer, then read the handle back to confirm it was stored */
125 CHECK(mkldnn_memory_set_data_handle(*memory, data));
126 CHECK(mkldnn_memory_get_data_handle(*memory, &req));
127 CHECK_TRUE(req == data);
/* the primitive descriptor is destroyed once the primitive exists */
128 CHECK(mkldnn_primitive_desc_destroy(user_pd));
/* Decide whether data has to be converted between the layout of
 * *user_memory and the layout described by *prim_memory_pd, and build what
 * is needed for that conversion.
 * When the two memory primitive descriptors differ, a memory primitive is
 * created for *prim_memory_pd with `buffer` as its data handle, together
 * with a reorder primitive whose direction follows dir_is_user_to_prim.
 * The path taken when the descriptors are equal is outside the visible
 * lines; callers in this file test the two outputs against NULL, so that
 * path presumably sets *prim_memory and *reorder to NULL - TODO confirm.
 * The visible return statement yields mkldnn_success. */
132 prepare_reorder(mkldnn_primitive_t *user_memory, /** in */
133 const_mkldnn_primitive_desc_t *prim_memory_pd, /** in */
134 int dir_is_user_to_prim, /** in: user -> prim or prim -> user */
135 mkldnn_primitive_t *prim_memory, mkldnn_primitive_t
136 *reorder, /** out: reorder primitive created */
139 const_mkldnn_primitive_desc_t user_memory_pd;
/* NOTE(review): the status returned by this call is not wrapped in CHECK,
 * unlike the surrounding library calls */
140 mkldnn_primitive_get_primitive_desc(*user_memory, &user_memory_pd);
142 if (!mkldnn_memory_primitive_desc_equal(user_memory_pd, *prim_memory_pd)) {
143 CHECK(mkldnn_primitive_create(prim_memory, *prim_memory_pd, NULL,
145 CHECK(mkldnn_memory_set_data_handle(*prim_memory, buffer));
147 mkldnn_primitive_desc_t reorder_pd;
148 if (dir_is_user_to_prim) {
149 /* the reorder primitive descriptor doesn't need an engine, because one is
150 * already present in the in- and out- memory primitive descriptors */
151 CHECK(mkldnn_reorder_primitive_desc_create(
152 &reorder_pd, user_memory_pd, *prim_memory_pd));
153 mkldnn_primitive_at_t inputs = { *user_memory, 0 };
154 const_mkldnn_primitive_t outputs[] = { *prim_memory };
155 CHECK(mkldnn_primitive_create(reorder, reorder_pd, &inputs,
/* opposite direction: *prim_memory is the input, *user_memory the output */
158 CHECK(mkldnn_reorder_primitive_desc_create(
159 &reorder_pd, *prim_memory_pd, user_memory_pd));
160 mkldnn_primitive_at_t inputs = { *prim_memory, 0 };
161 const_mkldnn_primitive_t outputs[] = { *user_memory };
162 CHECK(mkldnn_primitive_create(reorder, reorder_pd, &inputs,
165 CHECK(mkldnn_primitive_desc_destroy(reorder_pd));
171 return mkldnn_success;
/* Build and run a small training pipeline on the CPU engine.
 * Forward primitives are queued in net_fwd in this order: optional src and
 * weights reorders, convolution, relu, lrn, max pooling, optional dst
 * reorder. Backward primitives are queued in net_bwd in this order:
 * optional diff_dst reorder, pooling backward, lrn backward, relu backward,
 * optional src and diff_dst reorders, convolution backward-weights,
 * optional diff_weights reorder.
 * Both streams are submitted n_iter (10) times. The visible cleanup section
 * then destroys the primitive descriptors and primitives and frees the
 * intermediate buffers. Returns mkldnn_success at the end; most library
 * calls are wrapped in CHECK.
 * NOTE(review): results of aligned_malloc are used without a NULL check
 * (for example by the memset calls on net_dst and the workspace buffers). */
174 mkldnn_status_t simple_net()
177 mkldnn_engine_t engine;
178 CHECK(mkldnn_engine_create(&engine, mkldnn_cpu, 0 /* idx */));
180 ptrdiff_t net_src_sizes[4] = { BATCH, IC, CONV_IH, CONV_IW };
181 ptrdiff_t net_dst_sizes[4] = { BATCH, OC, POOL_OH, POOL_OW };
184 (float *)aligned_malloc(product(net_src_sizes,4)*sizeof(float), 64);
186 (float *)aligned_malloc(product(net_dst_sizes, 4)*sizeof(float), 64);
188 init_net_data(net_src, 4, net_src_sizes);
189 memset(net_dst, 0, product(net_dst_sizes, 4)*sizeof(float));
191 /*----------------------------------------------------------------------*/
192 /*----------------- Forward Stream -------------------------------------*/
194 * {BATCH, IC, CONV_IH, CONV_IW} (x) {OC, IC, 11, 11} ->
195 * {BATCH, OC, CONV_OH, CONV_OW}
196 * strides: {CONV_STRIDE, CONV_STRIDE}
198 ptrdiff_t *conv_user_src_sizes = net_src_sizes;
199 ptrdiff_t conv_user_weights_sizes[4] = { OC, IC, 11, 11 };
200 ptrdiff_t conv_bias_sizes[4] = { OC };
201 ptrdiff_t conv_user_dst_sizes[4] = { BATCH, OC, CONV_OH, CONV_OW };
202 ptrdiff_t conv_strides[2] = { CONV_STRIDE, CONV_STRIDE };
203 ptrdiff_t conv_padding[2] = { CONV_PAD, CONV_PAD };
205 float *conv_src = net_src;
206 float *conv_weights = (float *)aligned_malloc(
207 product(conv_user_weights_sizes, 4) * sizeof(float), 64);
208 float *conv_bias = (float *)aligned_malloc(
209 product(conv_bias_sizes, 1) * sizeof(float), 64);
211 init_net_data(conv_weights, 4, conv_user_weights_sizes);
212 init_net_data(conv_bias, 1, conv_bias_sizes);
214 /* create memory for user data */
215 mkldnn_primitive_t conv_user_src_memory, conv_user_weights_memory,
216 conv_user_bias_memory;
217 init_data_memory(4, conv_user_src_sizes, mkldnn_nchw, mkldnn_f32, engine,
218 conv_src, &conv_user_src_memory);
219 init_data_memory(4, conv_user_weights_sizes, mkldnn_oihw, mkldnn_f32,
220 engine, conv_weights, &conv_user_weights_memory);
221 init_data_memory(1, conv_bias_sizes, mkldnn_x, mkldnn_f32, engine,
222 conv_bias, &conv_user_bias_memory);
224 /* create data descriptors for convolution w/ no specified format */
225 mkldnn_memory_desc_t conv_src_md, conv_weights_md, conv_bias_md,
227 CHECK(mkldnn_memory_desc_init(
228 &conv_src_md, 4, conv_user_src_sizes, mkldnn_f32, mkldnn_any));
229 CHECK(mkldnn_memory_desc_init(&conv_weights_md, 4, conv_user_weights_sizes,
230 mkldnn_f32, mkldnn_any));
231 CHECK(mkldnn_memory_desc_init(
232 &conv_bias_md, 1, conv_bias_sizes, mkldnn_f32, mkldnn_x));
233 CHECK(mkldnn_memory_desc_init(
234 &conv_dst_md, 4, conv_user_dst_sizes, mkldnn_f32, mkldnn_any));
236 /* create a convolution */
237 mkldnn_convolution_desc_t conv_any_desc;
238 CHECK(mkldnn_convolution_forward_desc_init(
239 &conv_any_desc, mkldnn_forward, mkldnn_convolution_direct,
240 &conv_src_md, &conv_weights_md, &conv_bias_md, &conv_dst_md,
241 conv_strides, conv_padding, conv_padding, mkldnn_padding_zero));
243 mkldnn_primitive_desc_t conv_pd;
244 CHECK(mkldnn_primitive_desc_create(&conv_pd, &conv_any_desc, engine, NULL));
246 mkldnn_primitive_t conv_internal_src_memory, conv_internal_weights_memory,
247 conv_internal_dst_memory;
249 /* create memory for dst data, we don't need to reorder it to user data */
250 const_mkldnn_primitive_desc_t conv_dst_pd
251 = mkldnn_primitive_desc_query_pd(conv_pd, mkldnn_query_dst_pd, 0);
252 CHECK(mkldnn_primitive_create(
253 &conv_internal_dst_memory, conv_dst_pd, NULL, NULL));
254 size_t conv_dst_size = mkldnn_memory_primitive_desc_get_size(conv_dst_pd);
255 float *conv_dst_buffer = (float *)aligned_malloc(conv_dst_size, 64);
256 CHECK(mkldnn_memory_set_data_handle(
257 conv_internal_dst_memory, conv_dst_buffer));
259 /* create reorder primitives between user data and convolution srcs
261 mkldnn_primitive_t conv_reorder_src, conv_reorder_weights;
263 const_mkldnn_primitive_desc_t conv_src_pd
264 = mkldnn_primitive_desc_query_pd(conv_pd, mkldnn_query_src_pd, 0);
265 size_t conv_src_size = mkldnn_memory_primitive_desc_get_size(conv_src_pd);
266 float *conv_src_buffer = (float *)aligned_malloc(conv_src_size, 64);
267 CHECK(prepare_reorder(&conv_user_src_memory, &conv_src_pd, 1,
268 &conv_internal_src_memory, &conv_reorder_src, conv_src_buffer));
270 const_mkldnn_primitive_desc_t conv_weights_pd
271 = mkldnn_primitive_desc_query_pd(
272 conv_pd, mkldnn_query_weights_pd, 0);
273 size_t conv_weights_size
274 = mkldnn_memory_primitive_desc_get_size(conv_weights_pd);
275 float *conv_weights_buffer = (float *)aligned_malloc(conv_weights_size, 64);
276 CHECK(prepare_reorder(&conv_user_weights_memory, &conv_weights_pd, 1,
277 &conv_internal_weights_memory, &conv_reorder_weights,
278 conv_weights_buffer));
280 mkldnn_primitive_t conv_src_memory = conv_internal_src_memory
281 ? conv_internal_src_memory
282 : conv_user_src_memory;
283 mkldnn_primitive_t conv_weights_memory = conv_internal_weights_memory
284 ? conv_internal_weights_memory
285 : conv_user_weights_memory;
287 mkldnn_primitive_at_t conv_srcs[]
288 = { mkldnn_primitive_at(conv_src_memory, 0),
289 mkldnn_primitive_at(conv_weights_memory, 0),
290 mkldnn_primitive_at(conv_user_bias_memory, 0) };
292 const_mkldnn_primitive_t conv_dsts[] = { conv_internal_dst_memory };
294 /* finally create a convolution primitive */
295 mkldnn_primitive_t conv;
296 CHECK(mkldnn_primitive_create(&conv, conv_pd, conv_srcs, conv_dsts));
299 * {BATCH, OC, CONV_OH, CONV_OW} -> {BATCH, OC, CONV_OH, CONV_OW}
301 float negative_slope = 1.0f;
303 /* keep memory format of source same as the format of convolution
304 * output in order to avoid reorder */
305 const mkldnn_memory_desc_t *relu_src_md
306 = mkldnn_primitive_desc_query_memory_d(conv_dst_pd);
308 /* create a relu primitive descriptor */
309 mkldnn_eltwise_desc_t relu_desc;
310 CHECK(mkldnn_eltwise_forward_desc_init(&relu_desc, mkldnn_forward,
311 mkldnn_eltwise_relu, relu_src_md, negative_slope, 0));
313 mkldnn_primitive_desc_t relu_pd;
314 CHECK(mkldnn_primitive_desc_create(&relu_pd, &relu_desc, engine, NULL));
316 /* create relu dst memory primitive */
317 mkldnn_primitive_t relu_dst_memory;
318 const_mkldnn_primitive_desc_t relu_dst_pd
319 = mkldnn_primitive_desc_query_pd(relu_pd, mkldnn_query_dst_pd, 0);
320 CHECK(mkldnn_primitive_create(&relu_dst_memory, relu_dst_pd, NULL, NULL));
321 size_t relu_dst_size = mkldnn_memory_primitive_desc_get_size(relu_dst_pd);
322 float *relu_dst_buffer = (float *)aligned_malloc(relu_dst_size, 64);
323 CHECK(mkldnn_memory_set_data_handle(relu_dst_memory, relu_dst_buffer));
325 /* finally create a relu primitive */
326 mkldnn_primitive_t relu;
327 mkldnn_primitive_at_t relu_srcs = { conv_internal_dst_memory, 0 };
328 const_mkldnn_primitive_t relu_dsts[] = { relu_dst_memory };
330 CHECK(mkldnn_primitive_create(&relu, relu_pd, &relu_srcs, relu_dsts));
333 * {BATCH, OC, CONV_OH, CONV_OW} -> {BATCH, OC, CONV_OH, CONV_OW}
339 uint32_t local_size = 5;
340 float alpha = 0.0001f;
344 /* create lrn src memory descriptor using dst memory descriptor
345 * from previous primitive */
346 const mkldnn_memory_desc_t *lrn_src_md
347 = mkldnn_primitive_desc_query_memory_d(relu_dst_pd);
349 /* create a lrn primitive descriptor */
350 mkldnn_lrn_desc_t lrn_desc;
351 CHECK(mkldnn_lrn_forward_desc_init(&lrn_desc, mkldnn_forward,
352 mkldnn_lrn_across_channels, lrn_src_md,
353 local_size, alpha, beta, k));
355 mkldnn_primitive_desc_t lrn_pd;
356 CHECK(mkldnn_primitive_desc_create(&lrn_pd, &lrn_desc, engine, NULL));
358 /* create primitives for lrn dst and workspace memory */
359 mkldnn_primitive_t lrn_dst_memory, lrn_workspace_memory;
361 const_mkldnn_primitive_desc_t lrn_dst_pd
362 = mkldnn_primitive_desc_query_pd(lrn_pd, mkldnn_query_dst_pd, 0);
363 CHECK(mkldnn_primitive_create(&lrn_dst_memory, lrn_dst_pd, NULL, NULL));
364 size_t lrn_dst_size = mkldnn_memory_primitive_desc_get_size(lrn_dst_pd);
365 float *lrn_dst_buffer = (float *)aligned_malloc(lrn_dst_size, 64);
366 CHECK(mkldnn_memory_set_data_handle(lrn_dst_memory, lrn_dst_buffer));
368 /* create workspace only in training and only for forward primitive*/
369 /* query lrn_pd for workspace, this memory will be shared with forward lrn*/
370 const_mkldnn_primitive_desc_t lrn_workspace_pd
371 = mkldnn_primitive_desc_query_pd(lrn_pd, mkldnn_query_workspace_pd,
373 CHECK(mkldnn_primitive_create(&lrn_workspace_memory, lrn_workspace_pd, NULL,
375 size_t lrn_workspace_size =
376 mkldnn_memory_primitive_desc_get_size(lrn_workspace_pd);
377 float *lrn_workspace_buffer =
378 (float*)aligned_malloc(lrn_workspace_size, 64);
379 memset(lrn_workspace_buffer, 0, lrn_workspace_size);
380 CHECK(mkldnn_memory_set_data_handle(lrn_workspace_memory,
381 lrn_workspace_buffer));
383 mkldnn_primitive_at_t lrn_srcs = { relu_dst_memory, 0 };
385 const_mkldnn_primitive_t lrn_dsts[]
386 = { lrn_dst_memory, lrn_workspace_memory };
388 /* finally create a lrn primitive */
389 mkldnn_primitive_t lrn;
390 CHECK(mkldnn_primitive_create(&lrn, lrn_pd, &lrn_srcs, lrn_dsts));
393 * {BATCH, OC, CONV_OH, CONV_OW} -> {BATCH, OC, POOL_OH, POOL_OW}
395 * strides: {POOL_STRIDE, POOL_STRIDE}
397 ptrdiff_t *pool_dst_sizes = net_dst_sizes;
398 ptrdiff_t pool_kernel[2] = { 3, 3 };
399 ptrdiff_t pool_strides[2] = { POOL_STRIDE, POOL_STRIDE };
400 ptrdiff_t pool_padding[2] = { POOL_PAD, POOL_PAD };
402 /* create pooling src memory descriptor using dst descriptor
403 * from previous primitive */
404 const mkldnn_memory_desc_t *pool_src_md
405 = mkldnn_primitive_desc_query_memory_d(lrn_dst_pd);
407 /* create descriptors for dst pooling data */
408 mkldnn_memory_desc_t pool_dst_md;
409 CHECK(mkldnn_memory_desc_init(&pool_dst_md, 4, pool_dst_sizes, mkldnn_f32,
412 /* create memory for user dst data */
413 mkldnn_primitive_t pool_user_dst_memory;
414 init_data_memory(4, pool_dst_sizes, mkldnn_nchw, mkldnn_f32, engine,
415 net_dst, &pool_user_dst_memory);
417 /* create a pooling primitive descriptor */
418 mkldnn_pooling_desc_t pool_desc;
419 CHECK(mkldnn_pooling_forward_desc_init(
420 &pool_desc, mkldnn_forward, mkldnn_pooling_max, pool_src_md,
421 &pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding,
422 mkldnn_padding_zero));
424 mkldnn_primitive_desc_t pool_pd;
425 CHECK(mkldnn_primitive_desc_create(&pool_pd, &pool_desc, engine, NULL));
427 /* create memory for workspace */
428 mkldnn_primitive_t pool_workspace_memory;
429 const_mkldnn_primitive_desc_t pool_workspace_pd
430 = mkldnn_primitive_desc_query_pd(pool_pd, mkldnn_query_workspace_pd,
432 CHECK(mkldnn_primitive_create(&pool_workspace_memory, pool_workspace_pd,
434 size_t pool_workspace_size =
435 mkldnn_memory_primitive_desc_get_size(pool_workspace_pd);
436 float *pool_workspace_buffer =
437 (float*)aligned_malloc(pool_workspace_size, 64);
438 memset(pool_workspace_buffer, 0, pool_workspace_size);
439 CHECK(mkldnn_memory_set_data_handle(pool_workspace_memory,
440 pool_workspace_buffer));
442 mkldnn_primitive_t pool_dst_memory;
444 /* create reorder primitives between pooling dsts and user format dst
446 mkldnn_primitive_t pool_reorder_dst, pool_internal_dst_memory;
447 const_mkldnn_primitive_desc_t pool_dst_pd
448 = mkldnn_primitive_desc_query_pd(pool_pd, mkldnn_query_dst_pd, 0);
449 size_t pool_dst_size = mkldnn_memory_primitive_desc_get_size(pool_dst_pd);
450 float *pool_dst_buffer = (float *)aligned_malloc(pool_dst_size, 64);
451 CHECK(prepare_reorder(&pool_user_dst_memory, &pool_dst_pd, 0,
452 &pool_internal_dst_memory, &pool_reorder_dst,
455 mkldnn_primitive_at_t pool_srcs = { lrn_dst_memory, 0 };
457 pool_dst_memory = pool_internal_dst_memory ? pool_internal_dst_memory
458 : pool_user_dst_memory;
460 const_mkldnn_primitive_t pool_dsts[]
461 = { pool_dst_memory, pool_workspace_memory };
463 /* finally create a pooling primitive */
464 mkldnn_primitive_t pool;
465 CHECK(mkldnn_primitive_create(&pool, pool_pd, &pool_srcs, pool_dsts));
467 /* build a simple net */
469 mkldnn_primitive_t net_fwd[10];
471 if (conv_reorder_src)
472 net_fwd[n_fwd++] = conv_reorder_src;
473 if (conv_reorder_weights)
474 net_fwd[n_fwd++] = conv_reorder_weights;
475 net_fwd[n_fwd++] = conv;
476 net_fwd[n_fwd++] = relu;
477 net_fwd[n_fwd++] = lrn;
478 net_fwd[n_fwd++] = pool;
479 if (pool_reorder_dst)
480 net_fwd[n_fwd++] = pool_reorder_dst;
482 void *net_output = NULL; // output from forward stream:
484 /*----------------------------------------------------------------------*/
485 /*----------------- Backward Stream -------------------------------------*/
486 /* ... user diff_data ...*/
487 float *net_diff_dst = (float *)aligned_malloc(
488 product(pool_dst_sizes, 4) * sizeof(float), 64);
490 init_net_data(net_diff_dst, 4, pool_dst_sizes);
492 /* create memory primitives for user diff dst data*/
493 mkldnn_primitive_t pool_user_diff_dst_memory;
494 init_data_memory(4, pool_dst_sizes, mkldnn_nchw, mkldnn_f32, engine,
495 net_diff_dst, &pool_user_diff_dst_memory);
497 /* Pooling Backward */
498 /* pooling diff src memory descriptor */
499 const mkldnn_memory_desc_t *pool_diff_src_md
500 = mkldnn_primitive_desc_query_memory_d(lrn_dst_pd);
502 /* pooling diff dst memory descriptor */
503 const mkldnn_memory_desc_t *pool_diff_dst_md
504 = mkldnn_primitive_desc_query_memory_d(pool_dst_pd);
506 /* create backward pooling descriptor */
507 mkldnn_pooling_desc_t pool_bwd_desc;
508 CHECK(mkldnn_pooling_backward_desc_init(
509 &pool_bwd_desc, mkldnn_pooling_max, pool_diff_src_md,
510 pool_diff_dst_md, pool_strides, pool_kernel, pool_padding,
511 pool_padding, mkldnn_padding_zero));
513 /* backward primitive descriptor needs to hint forward descriptor*/
514 mkldnn_primitive_desc_t pool_bwd_pd;
515 CHECK(mkldnn_primitive_desc_create(&pool_bwd_pd, &pool_bwd_desc, engine,
518 /* create reorder primitive between user diff dst and pool diff dst
520 mkldnn_primitive_t pool_diff_dst_memory;
521 mkldnn_primitive_t pool_reorder_diff_dst, pool_internal_diff_dst_memory;
522 const_mkldnn_primitive_desc_t pool_diff_dst_pd
523 = mkldnn_primitive_desc_query_pd(pool_bwd_pd,
524 mkldnn_query_diff_dst_pd, 0);
525 size_t pool_diff_dst_size
526 = mkldnn_memory_primitive_desc_get_size(pool_diff_dst_pd);
527 float *pool_diff_dst_buffer
528 = (float *)aligned_malloc(pool_diff_dst_size, 64);
529 CHECK(prepare_reorder(&pool_user_diff_dst_memory, &pool_diff_dst_pd, 1,
530 &pool_internal_diff_dst_memory,
531 &pool_reorder_diff_dst, pool_diff_dst_buffer));
533 pool_diff_dst_memory = pool_internal_diff_dst_memory
534 ? pool_internal_diff_dst_memory
535 : pool_user_diff_dst_memory;
537 /* create memory primitive for pool diff src data */
538 mkldnn_primitive_t pool_diff_src_memory;
539 const_mkldnn_primitive_desc_t pool_diff_src_pd
540 = mkldnn_primitive_desc_query_pd(pool_bwd_pd,
541 mkldnn_query_diff_src_pd, 0);
542 size_t pool_diff_src_size
543 = mkldnn_memory_primitive_desc_get_size(pool_diff_src_pd);
544 float *pool_diff_src_buffer
545 = (float *)aligned_malloc(pool_diff_src_size, 64);
546 CHECK(mkldnn_primitive_create(
547 &pool_diff_src_memory, pool_diff_src_pd, NULL, NULL));
548 CHECK(mkldnn_memory_set_data_handle(pool_diff_src_memory,
549 pool_diff_src_buffer));
551 mkldnn_primitive_at_t pool_diff_dsts[]
552 = { mkldnn_primitive_at(pool_diff_dst_memory, 0),
553 mkldnn_primitive_at(pool_workspace_memory, 0) };
555 const_mkldnn_primitive_t pool_diff_srcs[] = { pool_diff_src_memory };
557 /* finally create backward pooling primitive */
558 mkldnn_primitive_t pool_bwd;
559 CHECK(mkldnn_primitive_create(&pool_bwd, pool_bwd_pd, pool_diff_dsts,
563 const mkldnn_memory_desc_t *lrn_diff_dst_md
564 = mkldnn_primitive_desc_query_memory_d(pool_diff_src_pd);
566 /* create backward lrn descriptor */
567 mkldnn_lrn_desc_t lrn_bwd_desc;
568 CHECK(mkldnn_lrn_backward_desc_init(
569 &lrn_bwd_desc, mkldnn_lrn_across_channels, lrn_src_md,
570 lrn_diff_dst_md, local_size, alpha, beta, k));
572 mkldnn_primitive_desc_t lrn_bwd_pd;
573 CHECK(mkldnn_primitive_desc_create(&lrn_bwd_pd, &lrn_bwd_desc, engine,
576 /* create memory primitives for lrn diff src */
577 mkldnn_primitive_t lrn_diff_src_memory;
578 const_mkldnn_primitive_desc_t lrn_diff_src_pd
579 = mkldnn_primitive_desc_query_pd(lrn_bwd_pd,
580 mkldnn_query_diff_src_pd, 0);
581 size_t lrn_diff_src_size
582 = mkldnn_memory_primitive_desc_get_size(lrn_diff_src_pd);
583 float *lrn_diff_src_buffer = (float *)aligned_malloc(lrn_diff_src_size, 64);
584 CHECK(mkldnn_primitive_create(&lrn_diff_src_memory, lrn_diff_src_pd, NULL,
586 CHECK(mkldnn_memory_set_data_handle(lrn_diff_src_memory,
587 lrn_diff_src_buffer));
589 mkldnn_primitive_at_t lrn_diff_dsts[]
590 = { mkldnn_primitive_at(relu_dst_memory,
591 0), // lrn_bwd requires src as first input
592 mkldnn_primitive_at(pool_diff_src_memory, 0),
593 mkldnn_primitive_at(lrn_workspace_memory, 0) };
595 const_mkldnn_primitive_t lrn_diff_srcs[] = { lrn_diff_src_memory };
597 /* finally create backward lrn primitive */
598 mkldnn_primitive_t lrn_bwd;
599 CHECK(mkldnn_primitive_create(&lrn_bwd, lrn_bwd_pd, lrn_diff_dsts,
603 const mkldnn_memory_desc_t *relu_diff_dst_md
604 = mkldnn_primitive_desc_query_memory_d(lrn_diff_src_pd);
606 /* create backward relu descriptor */
607 mkldnn_eltwise_desc_t relu_bwd_desc;
608 CHECK(mkldnn_eltwise_backward_desc_init(&relu_bwd_desc,
609 mkldnn_eltwise_relu, relu_diff_dst_md, relu_src_md,
612 mkldnn_primitive_desc_t relu_bwd_pd;
613 CHECK(mkldnn_primitive_desc_create(&relu_bwd_pd, &relu_bwd_desc, engine,
616 /* create memory primitives for relu diff src */
617 mkldnn_primitive_t relu_diff_src_memory;
618 const_mkldnn_primitive_desc_t relu_diff_src_pd
619 = mkldnn_primitive_desc_query_pd(relu_bwd_pd,
620 mkldnn_query_diff_src_pd, 0);
621 size_t relu_diff_src_size
622 = mkldnn_memory_primitive_desc_get_size(relu_diff_src_pd);
623 float *relu_diff_src_buffer
624 = (float *)aligned_malloc(relu_diff_src_size, 64);
626 CHECK(mkldnn_primitive_create(&relu_diff_src_memory, relu_diff_src_pd, NULL,
628 CHECK(mkldnn_memory_set_data_handle(relu_diff_src_memory,
629 relu_diff_src_buffer));
631 mkldnn_primitive_at_t relu_diff_dsts[]
632 = { mkldnn_primitive_at(conv_internal_dst_memory, 0),
633 mkldnn_primitive_at(lrn_diff_src_memory, 0) };
635 const_mkldnn_primitive_t relu_diff_srcs[] = { relu_diff_src_memory };
637 /* finally create backward relu primitive */
638 mkldnn_primitive_t relu_bwd;
/* built from relu_bwd_pd (the backward descriptor), in the same way that
 * pool_bwd, lrn_bwd and conv_bwd_weights use their own backward
 * descriptors; the forward relu_pd would not describe a primitive taking
 * src and diff_dst as inputs */
639 CHECK(mkldnn_primitive_create(&relu_bwd, relu_bwd_pd, relu_diff_dsts,
642 /* Backward convolution with respect to weights */
643 float *conv_diff_bias_buffer = (float *)aligned_malloc(
644 product(conv_bias_sizes, 1) * sizeof(float), 64);
645 float *conv_user_diff_weights_buffer = (float *)aligned_malloc(
646 product(conv_user_weights_sizes, 4) * sizeof(float), 64);
648 /* initialize memory for diff weights in user format */
649 mkldnn_primitive_t conv_user_diff_weights_memory;
/* NOTE(review): the forward user weights above are described with
 * mkldnn_oihw while mkldnn_nchw is used here - presumably the same
 * dimension order, confirm against the library's format definitions */
650 init_data_memory(4, conv_user_weights_sizes, mkldnn_nchw, mkldnn_f32,
651 engine, conv_user_diff_weights_buffer,
652 &conv_user_diff_weights_memory);
654 /* memory descriptors should be in format `any` to allow backward
656 * weights to choose the format it prefers for best performance */
657 mkldnn_memory_desc_t conv_diff_src_md, conv_diff_weights_md,
658 conv_diff_bias_md, conv_diff_dst_md;
659 CHECK(mkldnn_memory_desc_init(
660 &conv_diff_src_md, 4, conv_user_src_sizes, mkldnn_f32, mkldnn_any));
661 CHECK(mkldnn_memory_desc_init(&conv_diff_weights_md, 4,
662 conv_user_weights_sizes, mkldnn_f32, mkldnn_any));
663 CHECK(mkldnn_memory_desc_init(
664 &conv_diff_bias_md, 1, conv_bias_sizes, mkldnn_f32, mkldnn_x));
665 CHECK(mkldnn_memory_desc_init(
666 &conv_diff_dst_md, 4, conv_user_dst_sizes, mkldnn_f32, mkldnn_any));
668 /* create backward convolution descriptor */
669 mkldnn_convolution_desc_t conv_bwd_weights_desc;
670 CHECK(mkldnn_convolution_backward_weights_desc_init(&conv_bwd_weights_desc,
671 mkldnn_convolution_direct, &conv_diff_src_md, &conv_diff_weights_md,
672 &conv_diff_bias_md, &conv_diff_dst_md, conv_strides, conv_padding,
673 conv_padding, mkldnn_padding_zero));
675 mkldnn_primitive_desc_t conv_bwd_weights_pd;
676 CHECK(mkldnn_primitive_desc_create(
677 &conv_bwd_weights_pd, &conv_bwd_weights_desc, engine, conv_pd));
679 /* for best performance convolution backward might choose
680 * different memory format for src and diff_dst
681 * than the memory formats preferred by forward convolution
682 * for src and dst respectively */
683 /* create reorder primitives for src from forward convolution to the
684 * format chosen by backward convolution */
685 mkldnn_primitive_t conv_bwd_reorder_src, conv_bwd_internal_src_memory;
686 const_mkldnn_primitive_desc_t conv_diff_src_pd
687 = mkldnn_primitive_desc_query_pd(conv_bwd_weights_pd,
688 mkldnn_query_src_pd, 0);
689 size_t conv_diff_src_size
690 = mkldnn_memory_primitive_desc_get_size(conv_diff_src_pd);
691 float *conv_diff_src_buffer
692 = (float *)aligned_malloc(conv_diff_src_size, 64);
693 CHECK(prepare_reorder(&conv_src_memory, &conv_diff_src_pd, 1,
694 &conv_bwd_internal_src_memory, &conv_bwd_reorder_src,
695 conv_diff_src_buffer));
697 mkldnn_primitive_t conv_diff_src_memory
698 = conv_bwd_internal_src_memory ? conv_bwd_internal_src_memory
701 /* create reorder primitives for diff_dst between diff_src from relu_bwd
702 * and format preferred by conv_diff_weights */
703 mkldnn_primitive_t conv_reorder_diff_dst, conv_internal_diff_dst_memory;
704 const_mkldnn_primitive_desc_t conv_diff_dst_pd
705 = mkldnn_primitive_desc_query_pd(conv_bwd_weights_pd,
706 mkldnn_query_diff_dst_pd, 0);
707 size_t conv_diff_dst_size
708 = mkldnn_memory_primitive_desc_get_size(conv_diff_dst_pd);
709 float *conv_diff_dst_buffer
710 = (float *)aligned_malloc(conv_diff_dst_size, 64);
712 CHECK(prepare_reorder(&relu_diff_src_memory, &conv_diff_dst_pd, 1,
713 &conv_internal_diff_dst_memory,
714 &conv_reorder_diff_dst, conv_diff_dst_buffer));
716 mkldnn_primitive_t conv_diff_dst_memory
717 = conv_internal_diff_dst_memory ? conv_internal_diff_dst_memory
718 : relu_diff_src_memory;
720 /* create reorder primitives for conv diff weights memory */
721 mkldnn_primitive_t conv_reorder_diff_weights,
722 conv_internal_diff_weights_memory;
723 const_mkldnn_primitive_desc_t conv_diff_weights_pd
724 = mkldnn_primitive_desc_query_pd(conv_bwd_weights_pd,
725 mkldnn_query_diff_weights_pd, 0);
726 size_t conv_diff_weights_size
727 = mkldnn_memory_primitive_desc_get_size(conv_diff_weights_pd);
728 float *conv_diff_weights_buffer
729 = (float *)aligned_malloc(conv_diff_weights_size, 64);
730 CHECK(prepare_reorder(&conv_user_diff_weights_memory, &conv_diff_weights_pd,
731 0, &conv_internal_diff_weights_memory,
732 &conv_reorder_diff_weights,
733 conv_diff_weights_buffer));
735 mkldnn_primitive_t conv_diff_weights_memory
736 = conv_internal_diff_weights_memory
737 ? conv_internal_diff_weights_memory
738 : conv_user_diff_weights_memory;
740 /* create memory primitive for diff bias memory */
741 mkldnn_primitive_t conv_diff_bias_memory;
742 mkldnn_primitive_desc_t conv_diff_bias_pd;
743 CHECK(mkldnn_memory_primitive_desc_create(&conv_diff_bias_pd,
744 &conv_diff_bias_md, engine));
745 CHECK(mkldnn_primitive_create(&conv_diff_bias_memory, conv_diff_bias_pd,
747 CHECK(mkldnn_memory_set_data_handle(conv_diff_bias_memory,
748 conv_diff_bias_buffer));
750 mkldnn_primitive_at_t conv_diff_dsts[]
751 = { mkldnn_primitive_at(conv_diff_src_memory, 0),
752 mkldnn_primitive_at(conv_diff_dst_memory, 0) };
754 const_mkldnn_primitive_t conv_diff_weights[]
755 = { conv_diff_weights_memory, conv_diff_bias_memory };
757 /* finally create backward convolution weights primitive */
758 mkldnn_primitive_t conv_bwd_weights;
759 CHECK(mkldnn_primitive_create(&conv_bwd_weights, conv_bwd_weights_pd,
760 conv_diff_dsts, conv_diff_weights));
762 /* build backward stream */
764 mkldnn_primitive_t net_bwd[10];
766 if (pool_reorder_diff_dst)
767 net_bwd[n_bwd++] = pool_reorder_diff_dst;
768 net_bwd[n_bwd++] = pool_bwd;
769 net_bwd[n_bwd++] = lrn_bwd;
770 net_bwd[n_bwd++] = relu_bwd;
771 if (conv_bwd_reorder_src)
772 net_bwd[n_bwd++] = conv_bwd_reorder_src;
773 if (conv_reorder_diff_dst)
774 net_bwd[n_bwd++] = conv_reorder_diff_dst;
775 net_bwd[n_bwd++] = conv_bwd_weights;
776 if (conv_reorder_diff_weights)
777 net_bwd[n_bwd++] = conv_reorder_diff_weights;
779 // output from backward stream
780 void *net_diff_weights = NULL;
781 void *net_diff_bias = NULL;
783 int n_iter = 10; //number of iterations for training.
784 /* Execute the net */
785 for (int i = 0; i < n_iter; i++) {
786 mkldnn_stream_t stream_fwd;
787 CHECK(mkldnn_stream_create(&stream_fwd, mkldnn_eager));
788 CHECK(mkldnn_stream_submit(stream_fwd, n_fwd, net_fwd, NULL));
789 CHECK(mkldnn_stream_wait(stream_fwd, n_fwd, NULL));
790 CHECK(mkldnn_stream_destroy(stream_fwd));
792 /* Update net_diff_dst */
793 CHECK(mkldnn_memory_get_data_handle(pool_user_dst_memory, &net_output));
794 /*...user updates net_diff_dst using net_output...*/
795 // some user defined func update_diff_dst(net_diff_dst, net_output)
798 mkldnn_stream_t stream_bwd;
799 CHECK(mkldnn_stream_create(&stream_bwd, mkldnn_eager));
800 CHECK(mkldnn_stream_submit(stream_bwd, n_bwd, net_bwd, NULL));
801 CHECK(mkldnn_stream_wait(stream_bwd, n_bwd, NULL));
802 CHECK(mkldnn_stream_destroy(stream_bwd));
804 /*... update weights ... */
805 CHECK(mkldnn_memory_get_data_handle(conv_user_diff_weights_memory,
807 CHECK(mkldnn_memory_get_data_handle(conv_diff_bias_memory,
809 /* ...user updates weights and bias using diff weights and bias...*/
810 // some user defined func update_weights(conv_user_weights_memory,
812 // net_diff_weights, net_diff_bias);
815 /* Cleanup forward */
816 CHECK(mkldnn_primitive_desc_destroy(pool_pd));
817 CHECK(mkldnn_primitive_desc_destroy(lrn_pd));
818 CHECK(mkldnn_primitive_desc_destroy(relu_pd));
819 CHECK(mkldnn_primitive_desc_destroy(conv_pd));
/* NOTE(review): several handles destroyed below (the *_internal_* memories
 * and the reorder primitives) are tested against NULL earlier in this
 * function, so these calls rely on mkldnn_primitive_destroy accepting a
 * NULL handle - confirm against the library documentation */
824 mkldnn_primitive_destroy(conv_user_src_memory);
825 mkldnn_primitive_destroy(conv_user_weights_memory);
826 mkldnn_primitive_destroy(conv_user_bias_memory);
827 mkldnn_primitive_destroy(conv_internal_src_memory);
828 mkldnn_primitive_destroy(conv_internal_weights_memory);
829 mkldnn_primitive_destroy(conv_internal_dst_memory);
830 mkldnn_primitive_destroy(conv_reorder_src);
831 mkldnn_primitive_destroy(conv_reorder_weights);
832 mkldnn_primitive_destroy(conv);
837 _free(conv_src_buffer);
838 _free(conv_weights_buffer);
839 _free(conv_dst_buffer);
841 mkldnn_primitive_destroy(relu_dst_memory);
842 mkldnn_primitive_destroy(relu);
844 _free(relu_dst_buffer);
846 mkldnn_primitive_destroy(lrn_workspace_memory);
847 mkldnn_primitive_destroy(lrn_dst_memory);
848 mkldnn_primitive_destroy(lrn);
850 _free(lrn_workspace_buffer);
851 _free(lrn_dst_buffer);
853 mkldnn_primitive_destroy(pool_user_dst_memory);
854 mkldnn_primitive_destroy(pool_internal_dst_memory);
855 mkldnn_primitive_destroy(pool_workspace_memory);
856 mkldnn_primitive_destroy(pool_reorder_dst);
857 mkldnn_primitive_destroy(pool);
859 _free(pool_dst_buffer);
860 _free(pool_workspace_buffer);
862 /* Cleanup backward */
863 CHECK(mkldnn_primitive_desc_destroy(pool_bwd_pd));
864 CHECK(mkldnn_primitive_desc_destroy(lrn_bwd_pd));
865 CHECK(mkldnn_primitive_desc_destroy(relu_bwd_pd));
866 CHECK(mkldnn_primitive_desc_destroy(conv_diff_bias_pd));
867 CHECK(mkldnn_primitive_desc_destroy(conv_bwd_weights_pd));
869 mkldnn_primitive_destroy(pool_user_diff_dst_memory);
870 mkldnn_primitive_destroy(pool_diff_src_memory);
871 mkldnn_primitive_destroy(pool_internal_diff_dst_memory);
872 mkldnn_primitive_destroy(pool_reorder_diff_dst);
873 mkldnn_primitive_destroy(pool_bwd);
876 _free(pool_diff_dst_buffer);
877 _free(pool_diff_src_buffer);
879 mkldnn_primitive_destroy(lrn_diff_src_memory);
880 mkldnn_primitive_destroy(lrn_bwd);
882 _free(lrn_diff_src_buffer);
884 mkldnn_primitive_destroy(relu_diff_src_memory);
885 mkldnn_primitive_destroy(relu_bwd);
887 _free(relu_diff_src_buffer);
889 mkldnn_primitive_destroy(conv_user_diff_weights_memory);
890 mkldnn_primitive_destroy(conv_diff_bias_memory);
891 mkldnn_primitive_destroy(conv_bwd_internal_src_memory);
892 mkldnn_primitive_destroy(conv_bwd_reorder_src);
893 mkldnn_primitive_destroy(conv_internal_diff_dst_memory);
894 mkldnn_primitive_destroy(conv_reorder_diff_dst);
895 mkldnn_primitive_destroy(conv_internal_diff_weights_memory);
896 mkldnn_primitive_destroy(conv_reorder_diff_weights);
897 mkldnn_primitive_destroy(conv_bwd_weights);
899 _free(conv_diff_weights_buffer);
900 _free(conv_diff_bias_buffer);
901 _free(conv_user_diff_weights_buffer);
902 _free(conv_diff_src_buffer);
903 _free(conv_diff_dst_buffer);
905 mkldnn_engine_destroy(engine);
907 return mkldnn_success;
/* Program entry point: runs simple_net() and prints "passed" when it
 * returns mkldnn_success and "failed" otherwise. argc and argv are not
 * used in the visible lines. */
910 int main(int argc, char **argv)
912 mkldnn_status_t result = simple_net();
913 printf("%s\n", (result == mkldnn_success) ? "passed" : "failed");