Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / examples / simple_training_net.c
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 // Required for posix_memalign
18 #define _POSIX_C_SOURCE 200112L
19
20 #include <string.h>
21 #include <stdio.h>
22 #include <stdlib.h>
23 #include <math.h>
24 #include "mkldnn.h"
25 #ifdef _WIN32
26 #include <malloc.h>
27 #endif
28
/* Tensor geometry of the example network (AlexNet-like first layer). */

/* Mini-batch size and channel counts. */
#define BATCH       32
#define IC          3
#define OC          96

/* Convolution: input and output spatial extents, stride and padding.
 * With the 11x11 kernel used in simple_net(): (227 - 11) / 4 + 1 = 55. */
#define CONV_IH     227
#define CONV_IW     227
#define CONV_OH     55
#define CONV_OW     55
#define CONV_STRIDE 4
#define CONV_PAD    0

/* Pooling: output spatial extents, stride and padding.
 * With the 3x3 kernel used in simple_net(): (55 - 3) / 2 + 1 = 27. */
#define POOL_OH     27
#define POOL_OW     27
#define POOL_STRIDE 2
#define POOL_PAD    0
42
/* Evaluate the MKL-DNN call `f` exactly once. If the returned status is not
 * mkldnn_success, print the source location, the stringified call and the
 * numeric status, then terminate the process with exit code 2.
 *
 * The argument is parenthesized and the temporary has a macro-private name,
 * so an expression that mentions a caller variable called `s` is no longer
 * captured by the macro's own local. */
#define CHECK(f)                                                               \
    do {                                                                       \
        mkldnn_status_t check_status_ = (f);                                   \
        if (check_status_ != mkldnn_success) {                                 \
            printf("[%s:%d] error: %s returns %d\n", __FILE__, __LINE__, #f,   \
                   (int)check_status_);                                        \
            exit(2);                                                           \
        }                                                                      \
    } while (0)
52
/* Evaluate `expr` exactly once. If it is false (zero / NULL), print the
 * source location and the stringified expression, then terminate the process
 * with exit code 2.
 *
 * The expression is tested directly rather than being stored in an `int`
 * first, so pointer operands and values wider than `int` are judged by their
 * full value instead of by their truncated low bits. */
#define CHECK_TRUE(expr)                                                       \
    do {                                                                       \
        if (!(expr)) {                                                         \
            printf("[%s:%d] %s failed\n", __FILE__, __LINE__, #expr);          \
            exit(2);                                                           \
        }                                                                      \
    } while (0)
61
/**
 * Allocate `size` bytes whose address is a multiple of `alignment`.
 *
 * @param size       number of bytes to allocate
 * @param alignment  requested alignment in bytes (a power of two)
 * @return pointer to the allocation, or NULL if the allocation fails or the
 *         alignment is rejected by the platform allocator
 *
 * The result must be released with _free(): on Windows it comes from
 * _aligned_malloc(), elsewhere from posix_memalign().
 */
void *aligned_malloc(size_t size, size_t alignment) {
#ifdef _WIN32
    return _aligned_malloc(size, alignment);
#else
    void *p = NULL;

    /* posix_memalign() additionally requires the alignment to be a multiple
     * of sizeof(void *). Round small power-of-two requests up so that they
     * succeed here just as they do with _aligned_malloc(); a pointer aligned
     * to the larger power of two also satisfies the smaller one. */
    if (alignment != 0 && (alignment & (alignment - 1)) == 0
            && alignment < sizeof(void *)) {
        alignment = sizeof(void *);
    }

    return posix_memalign(&p, alignment, size) == 0 ? p : NULL;
#endif
}
70
/* Release a buffer using the deallocator that matches aligned_malloc() on
 * the current platform: _aligned_free() on Windows, free() elsewhere. */
void _free(void *ptr) {
#ifdef _WIN32
    _aligned_free(ptr);
#else
    free(ptr);
#endif
}
80
/**
 * Multiply together the first `size` entries of `arr`.
 *
 * @param arr   array of extents; only read, never modified
 * @param size  number of entries to multiply
 * @return the product as size_t; 1 when `size` is 0
 *
 * Each entry is converted to size_t before multiplying, so callers are
 * expected to pass non-negative extents.
 */
static size_t product(const ptrdiff_t *arr, size_t size)
{
    size_t prod = 1;
    for (size_t i = 0; i < size; ++i) {
        prod *= (size_t)arr[i];
    }
    return prod;
}
88
/**
 * Fill a dense tensor with a deterministic pattern: the element at row-major
 * linear index i receives the value (float)(i % 1637).
 *
 * @param data  destination buffer holding at least product(dims) floats
 * @param dim   number of dimensions; only 1 and 4 are handled, any other
 *              value leaves `data` untouched
 * @param dims  extents of each dimension (`dim` entries)
 *
 * If any extent is zero or negative nothing is written. Indices are computed
 * in size_t, so tensors with more than INT_MAX elements are filled correctly.
 */
static void init_net_data(float *data, uint32_t dim, const ptrdiff_t *dims)
{
    if (dim != 1 && dim != 4) {
        return;
    }

    /* The tensor is contiguous and row-major, so walking every
     * (n, c, h, w) coordinate is the same as walking 0 .. total-1. */
    size_t total = 1;
    for (uint32_t d = 0; d < dim; ++d) {
        if (dims[d] <= 0) {
            return;
        }
        total *= (size_t)dims[d];
    }

    for (size_t idx = 0; idx < total; ++idx) {
        data[idx] = (float)(idx % 1637);
    }
}
109
110 static void init_data_memory(uint32_t dim, const ptrdiff_t *dims,
111                              mkldnn_memory_format_t user_fmt,
112                              mkldnn_data_type_t data_type,
113                              mkldnn_engine_t engine, float *data,
114                              mkldnn_primitive_t *memory)
115 {
116     mkldnn_memory_desc_t prim_md;
117     mkldnn_primitive_desc_t user_pd;
118     CHECK(mkldnn_memory_desc_init(&prim_md, dim, dims, data_type, user_fmt));
119     CHECK(mkldnn_memory_primitive_desc_create(&user_pd, &prim_md, engine));
120     CHECK(mkldnn_primitive_create(memory, user_pd, NULL, NULL));
121
122     void *req = NULL;
123     CHECK(mkldnn_memory_get_data_handle(*memory, &req));
124     CHECK_TRUE(req == NULL);
125     CHECK(mkldnn_memory_set_data_handle(*memory, data));
126     CHECK(mkldnn_memory_get_data_handle(*memory, &req));
127     CHECK_TRUE(req == data);
128     CHECK(mkldnn_primitive_desc_destroy(user_pd));
129 }
130
131 mkldnn_status_t
132 prepare_reorder(mkldnn_primitive_t *user_memory,               /** in */
133                 const_mkldnn_primitive_desc_t *prim_memory_pd, /** in */
134                 int dir_is_user_to_prim, /** in: user -> prim or prim -> user */
135                 mkldnn_primitive_t *prim_memory, mkldnn_primitive_t
136                 *reorder, /** out: reorder primitive created */
137                 float *buffer)
138 {
139     const_mkldnn_primitive_desc_t user_memory_pd;
140     mkldnn_primitive_get_primitive_desc(*user_memory, &user_memory_pd);
141
142     if (!mkldnn_memory_primitive_desc_equal(user_memory_pd, *prim_memory_pd)) {
143         CHECK(mkldnn_primitive_create(prim_memory, *prim_memory_pd, NULL,
144                                       NULL));
145         CHECK(mkldnn_memory_set_data_handle(*prim_memory, buffer));
146
147         mkldnn_primitive_desc_t reorder_pd;
148         if (dir_is_user_to_prim) {
149             /* reorder primitive descriptor doesn't need engine, because it is
150              * already appeared in in- and out- memory primitive descriptors */
151             CHECK(mkldnn_reorder_primitive_desc_create(
152                     &reorder_pd, user_memory_pd, *prim_memory_pd));
153             mkldnn_primitive_at_t inputs = { *user_memory, 0 };
154             const_mkldnn_primitive_t outputs[] = { *prim_memory };
155             CHECK(mkldnn_primitive_create(reorder, reorder_pd, &inputs,
156                                           outputs));
157         } else {
158             CHECK(mkldnn_reorder_primitive_desc_create(
159                     &reorder_pd, *prim_memory_pd, user_memory_pd));
160             mkldnn_primitive_at_t inputs = { *prim_memory, 0 };
161             const_mkldnn_primitive_t outputs[] = { *user_memory };
162             CHECK(mkldnn_primitive_create(reorder, reorder_pd, &inputs,
163                                           outputs));
164         }
165         CHECK(mkldnn_primitive_desc_destroy(reorder_pd));
166     } else {
167         *prim_memory = NULL;
168         *reorder = NULL;
169     }
170
171     return mkldnn_success;
172 }
173
174 mkldnn_status_t simple_net()
175 {
176
177     mkldnn_engine_t engine;
178     CHECK(mkldnn_engine_create(&engine, mkldnn_cpu, 0 /* idx */));
179
180     ptrdiff_t net_src_sizes[4] = { BATCH, IC, CONV_IH, CONV_IW };
181     ptrdiff_t net_dst_sizes[4] = { BATCH, OC, POOL_OH, POOL_OW };
182
183     float *net_src =
184         (float *)aligned_malloc(product(net_src_sizes,4)*sizeof(float), 64);
185     float *net_dst =
186         (float *)aligned_malloc(product(net_dst_sizes, 4)*sizeof(float), 64);
187
188     init_net_data(net_src, 4, net_src_sizes);
189     memset(net_dst, 0, product(net_dst_sizes, 4)*sizeof(float));
190
191     /*----------------------------------------------------------------------*/
192     /*----------------- Forward Stream -------------------------------------*/
193     /* AlexNet: conv
194      * {BATCH, IC, CONV_IH, CONV_IW} (x) {OC, IC, 11, 11} ->
195      * {BATCH, OC, CONV_OH, CONV_OW}
196      * strides: {CONV_STRIDE, CONV_STRIDE}
197      */
198     ptrdiff_t *conv_user_src_sizes = net_src_sizes;
199     ptrdiff_t conv_user_weights_sizes[4] = { OC, IC, 11, 11 };
200     ptrdiff_t conv_bias_sizes[4] = { OC };
201     ptrdiff_t conv_user_dst_sizes[4] = { BATCH, OC, CONV_OH, CONV_OW };
202     ptrdiff_t conv_strides[2] = { CONV_STRIDE, CONV_STRIDE };
203     ptrdiff_t conv_padding[2] = { CONV_PAD, CONV_PAD };
204
205     float *conv_src = net_src;
206     float *conv_weights = (float *)aligned_malloc(
207             product(conv_user_weights_sizes, 4) * sizeof(float), 64);
208     float *conv_bias = (float *)aligned_malloc(
209             product(conv_bias_sizes, 1) * sizeof(float), 64);
210
211     init_net_data(conv_weights, 4, conv_user_weights_sizes);
212     init_net_data(conv_bias, 1, conv_bias_sizes);
213
214     /* create memory for user data */
215     mkldnn_primitive_t conv_user_src_memory, conv_user_weights_memory,
216             conv_user_bias_memory;
217     init_data_memory(4, conv_user_src_sizes, mkldnn_nchw, mkldnn_f32, engine,
218                      conv_src, &conv_user_src_memory);
219     init_data_memory(4, conv_user_weights_sizes, mkldnn_oihw, mkldnn_f32,
220             engine, conv_weights, &conv_user_weights_memory);
221     init_data_memory(1, conv_bias_sizes, mkldnn_x, mkldnn_f32, engine,
222             conv_bias, &conv_user_bias_memory);
223
224     /* create data descriptors for convolution w/ no specified format */
225     mkldnn_memory_desc_t conv_src_md, conv_weights_md, conv_bias_md,
226             conv_dst_md;
227     CHECK(mkldnn_memory_desc_init(
228             &conv_src_md, 4, conv_user_src_sizes, mkldnn_f32, mkldnn_any));
229     CHECK(mkldnn_memory_desc_init(&conv_weights_md, 4, conv_user_weights_sizes,
230             mkldnn_f32, mkldnn_any));
231     CHECK(mkldnn_memory_desc_init(
232             &conv_bias_md, 1, conv_bias_sizes, mkldnn_f32, mkldnn_x));
233     CHECK(mkldnn_memory_desc_init(
234             &conv_dst_md, 4, conv_user_dst_sizes, mkldnn_f32, mkldnn_any));
235
236     /* create a convolution */
237     mkldnn_convolution_desc_t conv_any_desc;
238     CHECK(mkldnn_convolution_forward_desc_init(
239             &conv_any_desc, mkldnn_forward, mkldnn_convolution_direct,
240             &conv_src_md, &conv_weights_md, &conv_bias_md, &conv_dst_md,
241             conv_strides, conv_padding, conv_padding, mkldnn_padding_zero));
242
243     mkldnn_primitive_desc_t conv_pd;
244     CHECK(mkldnn_primitive_desc_create(&conv_pd, &conv_any_desc, engine, NULL));
245
246     mkldnn_primitive_t conv_internal_src_memory, conv_internal_weights_memory,
247             conv_internal_dst_memory;
248
249     /* create memory for dst data, we don't need to reorder it to user data */
250     const_mkldnn_primitive_desc_t conv_dst_pd
251             = mkldnn_primitive_desc_query_pd(conv_pd, mkldnn_query_dst_pd, 0);
252     CHECK(mkldnn_primitive_create(
253             &conv_internal_dst_memory, conv_dst_pd, NULL, NULL));
254     size_t conv_dst_size = mkldnn_memory_primitive_desc_get_size(conv_dst_pd);
255     float *conv_dst_buffer = (float *)aligned_malloc(conv_dst_size, 64);
256     CHECK(mkldnn_memory_set_data_handle(
257             conv_internal_dst_memory, conv_dst_buffer));
258
259     /* create reorder primitives between user data and convolution srcs
260      * if required */
261     mkldnn_primitive_t conv_reorder_src, conv_reorder_weights;
262
263     const_mkldnn_primitive_desc_t conv_src_pd
264             = mkldnn_primitive_desc_query_pd(conv_pd, mkldnn_query_src_pd, 0);
265     size_t conv_src_size = mkldnn_memory_primitive_desc_get_size(conv_src_pd);
266     float *conv_src_buffer = (float *)aligned_malloc(conv_src_size, 64);
267     CHECK(prepare_reorder(&conv_user_src_memory, &conv_src_pd, 1,
268         &conv_internal_src_memory, &conv_reorder_src, conv_src_buffer));
269
270     const_mkldnn_primitive_desc_t conv_weights_pd
271             = mkldnn_primitive_desc_query_pd(
272                     conv_pd, mkldnn_query_weights_pd, 0);
273     size_t conv_weights_size
274             = mkldnn_memory_primitive_desc_get_size(conv_weights_pd);
275     float *conv_weights_buffer = (float *)aligned_malloc(conv_weights_size, 64);
276     CHECK(prepare_reorder(&conv_user_weights_memory, &conv_weights_pd, 1,
277             &conv_internal_weights_memory, &conv_reorder_weights,
278             conv_weights_buffer));
279
280     mkldnn_primitive_t conv_src_memory = conv_internal_src_memory
281                                                 ? conv_internal_src_memory
282                                                 : conv_user_src_memory;
283     mkldnn_primitive_t conv_weights_memory = conv_internal_weights_memory
284                                                 ? conv_internal_weights_memory
285                                                 : conv_user_weights_memory;
286
287     mkldnn_primitive_at_t conv_srcs[]
288             = { mkldnn_primitive_at(conv_src_memory, 0),
289                 mkldnn_primitive_at(conv_weights_memory, 0),
290                 mkldnn_primitive_at(conv_user_bias_memory, 0) };
291
292     const_mkldnn_primitive_t conv_dsts[] = { conv_internal_dst_memory };
293
294     /* finally create a convolution primitive */
295     mkldnn_primitive_t conv;
296     CHECK(mkldnn_primitive_create(&conv, conv_pd, conv_srcs, conv_dsts));
297
298     /* AlexNet: relu
299      * {BATCH, OC, CONV_OH, CONV_OW} -> {BATCH, OC, CONV_OH, CONV_OW}
300      */
301     float negative_slope = 1.0f;
302
303     /* keep memory format of source same as the format of convolution
304        * output in order to avoid reorder */
305     const mkldnn_memory_desc_t *relu_src_md
306             = mkldnn_primitive_desc_query_memory_d(conv_dst_pd);
307
308     /* create a relu primitive descriptor */
309     mkldnn_eltwise_desc_t relu_desc;
310     CHECK(mkldnn_eltwise_forward_desc_init(&relu_desc, mkldnn_forward,
311                 mkldnn_eltwise_relu, relu_src_md, negative_slope, 0));
312
313     mkldnn_primitive_desc_t relu_pd;
314     CHECK(mkldnn_primitive_desc_create(&relu_pd, &relu_desc, engine, NULL));
315
316     /* create relu dst memory primitive */
317     mkldnn_primitive_t relu_dst_memory;
318     const_mkldnn_primitive_desc_t relu_dst_pd
319             = mkldnn_primitive_desc_query_pd(relu_pd, mkldnn_query_dst_pd, 0);
320     CHECK(mkldnn_primitive_create(&relu_dst_memory, relu_dst_pd, NULL, NULL));
321     size_t relu_dst_size = mkldnn_memory_primitive_desc_get_size(relu_dst_pd);
322     float *relu_dst_buffer = (float *)aligned_malloc(relu_dst_size, 64);
323     CHECK(mkldnn_memory_set_data_handle(relu_dst_memory, relu_dst_buffer));
324
325     /* finally create a relu primitive */
326     mkldnn_primitive_t relu;
327     mkldnn_primitive_at_t relu_srcs = { conv_internal_dst_memory, 0 };
328     const_mkldnn_primitive_t relu_dsts[] = { relu_dst_memory };
329
330     CHECK(mkldnn_primitive_create(&relu, relu_pd, &relu_srcs, relu_dsts));
331
332     /* AlexNet: lrn
333      * {BATCH, OC, CONV_OH, CONV_OW} -> {BATCH, OC, CONV_OH, CONV_OW}
334      * local size: 5
335      * alpha: 0.0001
336      * beta: 0.75
337      * k: 1.0
338      */
339     uint32_t local_size = 5;
340     float alpha = 0.0001f;
341     float beta = 0.75f;
342     float k = 1.0f;
343
344     /* create lrn src memory descriptor using dst memory descriptor
345      *  from previous primitive */
346     const mkldnn_memory_desc_t *lrn_src_md
347             = mkldnn_primitive_desc_query_memory_d(relu_dst_pd);
348
349     /* create a lrn primitive descriptor */
350     mkldnn_lrn_desc_t lrn_desc;
351     CHECK(mkldnn_lrn_forward_desc_init(&lrn_desc, mkldnn_forward,
352                                        mkldnn_lrn_across_channels, lrn_src_md,
353                                        local_size, alpha, beta, k));
354
355     mkldnn_primitive_desc_t lrn_pd;
356     CHECK(mkldnn_primitive_desc_create(&lrn_pd, &lrn_desc, engine, NULL));
357
358     /* create primitives for lrn dst and workspace memory */
359     mkldnn_primitive_t lrn_dst_memory, lrn_workspace_memory;
360
361     const_mkldnn_primitive_desc_t lrn_dst_pd
362             = mkldnn_primitive_desc_query_pd(lrn_pd, mkldnn_query_dst_pd, 0);
363     CHECK(mkldnn_primitive_create(&lrn_dst_memory, lrn_dst_pd, NULL, NULL));
364     size_t lrn_dst_size = mkldnn_memory_primitive_desc_get_size(lrn_dst_pd);
365     float *lrn_dst_buffer = (float *)aligned_malloc(lrn_dst_size, 64);
366     CHECK(mkldnn_memory_set_data_handle(lrn_dst_memory, lrn_dst_buffer));
367
368     /* create workspace only in training and only for forward primitive*/
369     /* query lrn_pd for workspace, this memory will be shared with forward lrn*/
370     const_mkldnn_primitive_desc_t lrn_workspace_pd
371             = mkldnn_primitive_desc_query_pd(lrn_pd, mkldnn_query_workspace_pd,
372                                              0);
373     CHECK(mkldnn_primitive_create(&lrn_workspace_memory, lrn_workspace_pd, NULL,
374                                   NULL));
375     size_t lrn_workspace_size =
376         mkldnn_memory_primitive_desc_get_size(lrn_workspace_pd);
377     float *lrn_workspace_buffer =
378         (float*)aligned_malloc(lrn_workspace_size, 64);
379     memset(lrn_workspace_buffer, 0, lrn_workspace_size);
380     CHECK(mkldnn_memory_set_data_handle(lrn_workspace_memory,
381                                         lrn_workspace_buffer));
382
383     mkldnn_primitive_at_t lrn_srcs = { relu_dst_memory, 0 };
384
385     const_mkldnn_primitive_t lrn_dsts[]
386             = { lrn_dst_memory, lrn_workspace_memory };
387
388     /* finally create a lrn primitive */
389     mkldnn_primitive_t lrn;
390     CHECK(mkldnn_primitive_create(&lrn, lrn_pd, &lrn_srcs, lrn_dsts));
391
392     /* AlexNet: pool
393      * {BATCH, OC, CONV_OH, CONV_OW} -> {BATCH, OC, POOL_OH, POOL_OW}
394      * kernel: {3, 3}
395      * strides: {POOL_STRIDE, POOL_STRIDE}
396      */
397     ptrdiff_t *pool_dst_sizes = net_dst_sizes;
398     ptrdiff_t pool_kernel[2] = { 3, 3 };
399     ptrdiff_t pool_strides[2] = { POOL_STRIDE, POOL_STRIDE };
400     ptrdiff_t pool_padding[2] = { POOL_PAD, POOL_PAD };
401
402     /* create pooling src memory descriptor using dst descriptor
403      *  from previous primitive */
404     const mkldnn_memory_desc_t *pool_src_md
405             = mkldnn_primitive_desc_query_memory_d(lrn_dst_pd);
406
407     /* create descriptors for dst pooling data */
408     mkldnn_memory_desc_t pool_dst_md;
409     CHECK(mkldnn_memory_desc_init(&pool_dst_md, 4, pool_dst_sizes, mkldnn_f32,
410                                   mkldnn_any));
411
412     /* create memory for user dst data */
413     mkldnn_primitive_t pool_user_dst_memory;
414     init_data_memory(4, pool_dst_sizes, mkldnn_nchw, mkldnn_f32, engine,
415                      net_dst, &pool_user_dst_memory);
416
417     /* create a pooling primitive descriptor */
418     mkldnn_pooling_desc_t pool_desc;
419     CHECK(mkldnn_pooling_forward_desc_init(
420             &pool_desc, mkldnn_forward, mkldnn_pooling_max, pool_src_md,
421             &pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding,
422             mkldnn_padding_zero));
423
424     mkldnn_primitive_desc_t pool_pd;
425     CHECK(mkldnn_primitive_desc_create(&pool_pd, &pool_desc, engine, NULL));
426
427     /* create memory for workspace */
428     mkldnn_primitive_t pool_workspace_memory;
429     const_mkldnn_primitive_desc_t pool_workspace_pd
430             = mkldnn_primitive_desc_query_pd(pool_pd, mkldnn_query_workspace_pd,
431                                              0);
432     CHECK(mkldnn_primitive_create(&pool_workspace_memory, pool_workspace_pd,
433                                   NULL, NULL));
434     size_t pool_workspace_size =
435         mkldnn_memory_primitive_desc_get_size(pool_workspace_pd);
436     float *pool_workspace_buffer =
437         (float*)aligned_malloc(pool_workspace_size, 64);
438     memset(pool_workspace_buffer, 0, pool_workspace_size);
439     CHECK(mkldnn_memory_set_data_handle(pool_workspace_memory,
440                                         pool_workspace_buffer));
441
442     mkldnn_primitive_t pool_dst_memory;
443
444     /* create reorder primitives between pooling dsts and user format dst
445      * if required */
446     mkldnn_primitive_t pool_reorder_dst, pool_internal_dst_memory;
447     const_mkldnn_primitive_desc_t pool_dst_pd
448             = mkldnn_primitive_desc_query_pd(pool_pd, mkldnn_query_dst_pd, 0);
449     size_t pool_dst_size = mkldnn_memory_primitive_desc_get_size(pool_dst_pd);
450     float *pool_dst_buffer = (float *)aligned_malloc(pool_dst_size, 64);
451     CHECK(prepare_reorder(&pool_user_dst_memory, &pool_dst_pd, 0,
452                           &pool_internal_dst_memory, &pool_reorder_dst,
453                           pool_dst_buffer));
454
455     mkldnn_primitive_at_t pool_srcs = { lrn_dst_memory, 0 };
456
457     pool_dst_memory = pool_internal_dst_memory ? pool_internal_dst_memory
458                                                : pool_user_dst_memory;
459
460     const_mkldnn_primitive_t pool_dsts[]
461             = { pool_dst_memory, pool_workspace_memory };
462
463     /* finally create a pooling primitive */
464     mkldnn_primitive_t pool;
465     CHECK(mkldnn_primitive_create(&pool, pool_pd, &pool_srcs, pool_dsts));
466
467     /* build a simple net */
468     uint32_t n_fwd = 0;
469     mkldnn_primitive_t net_fwd[10];
470
471     if (conv_reorder_src)
472         net_fwd[n_fwd++] = conv_reorder_src;
473     if (conv_reorder_weights)
474         net_fwd[n_fwd++] = conv_reorder_weights;
475     net_fwd[n_fwd++] = conv;
476     net_fwd[n_fwd++] = relu;
477     net_fwd[n_fwd++] = lrn;
478     net_fwd[n_fwd++] = pool;
479     if (pool_reorder_dst)
480         net_fwd[n_fwd++] = pool_reorder_dst;
481
482     void *net_output = NULL; // output from forward stream:
483
484     /*----------------------------------------------------------------------*/
485     /*----------------- Backward Stream -------------------------------------*/
486     /* ... user diff_data ...*/
487     float *net_diff_dst = (float *)aligned_malloc(
488         product(pool_dst_sizes, 4) * sizeof(float), 64);
489
490     init_net_data(net_diff_dst, 4, pool_dst_sizes);
491
492     /* create memory primitives for user diff dst data*/
493     mkldnn_primitive_t pool_user_diff_dst_memory;
494     init_data_memory(4, pool_dst_sizes, mkldnn_nchw, mkldnn_f32, engine,
495                      net_diff_dst, &pool_user_diff_dst_memory);
496
497     /* Pooling Backward */
498     /* pooling diff src memory descriptor */
499     const mkldnn_memory_desc_t *pool_diff_src_md
500             = mkldnn_primitive_desc_query_memory_d(lrn_dst_pd);
501
502     /* pooling diff dst memory descriptor */
503     const mkldnn_memory_desc_t *pool_diff_dst_md
504             = mkldnn_primitive_desc_query_memory_d(pool_dst_pd);
505
506     /* create backward pooling descriptor */
507     mkldnn_pooling_desc_t pool_bwd_desc;
508     CHECK(mkldnn_pooling_backward_desc_init(
509             &pool_bwd_desc, mkldnn_pooling_max, pool_diff_src_md,
510             pool_diff_dst_md, pool_strides, pool_kernel, pool_padding,
511             pool_padding, mkldnn_padding_zero));
512
513     /* backward primitive descriptor needs to hint forward descriptor*/
514     mkldnn_primitive_desc_t pool_bwd_pd;
515     CHECK(mkldnn_primitive_desc_create(&pool_bwd_pd, &pool_bwd_desc, engine,
516                                        pool_pd));
517
518     /* create reorder primitive between user diff dst and pool diff dst
519      * if required*/
520     mkldnn_primitive_t pool_diff_dst_memory;
521     mkldnn_primitive_t pool_reorder_diff_dst, pool_internal_diff_dst_memory;
522     const_mkldnn_primitive_desc_t pool_diff_dst_pd
523             = mkldnn_primitive_desc_query_pd(pool_bwd_pd,
524                                              mkldnn_query_diff_dst_pd, 0);
525     size_t pool_diff_dst_size
526         = mkldnn_memory_primitive_desc_get_size(pool_diff_dst_pd);
527     float *pool_diff_dst_buffer
528         = (float *)aligned_malloc(pool_diff_dst_size, 64);
529     CHECK(prepare_reorder(&pool_user_diff_dst_memory, &pool_diff_dst_pd, 1,
530                           &pool_internal_diff_dst_memory,
531                           &pool_reorder_diff_dst, pool_diff_dst_buffer));
532
533     pool_diff_dst_memory = pool_internal_diff_dst_memory
534                                    ? pool_internal_diff_dst_memory
535                                    : pool_user_diff_dst_memory;
536
537     /* create memory primitive for pool diff src data */
538     mkldnn_primitive_t pool_diff_src_memory;
539     const_mkldnn_primitive_desc_t pool_diff_src_pd
540             = mkldnn_primitive_desc_query_pd(pool_bwd_pd,
541                                              mkldnn_query_diff_src_pd, 0);
542     size_t pool_diff_src_size
543             = mkldnn_memory_primitive_desc_get_size(pool_diff_src_pd);
544     float *pool_diff_src_buffer
545             = (float *)aligned_malloc(pool_diff_src_size, 64);
546     CHECK(mkldnn_primitive_create(
547             &pool_diff_src_memory, pool_diff_src_pd, NULL, NULL));
548     CHECK(mkldnn_memory_set_data_handle(pool_diff_src_memory,
549                                         pool_diff_src_buffer));
550
551     mkldnn_primitive_at_t pool_diff_dsts[]
552             = { mkldnn_primitive_at(pool_diff_dst_memory, 0),
553                 mkldnn_primitive_at(pool_workspace_memory, 0) };
554
555     const_mkldnn_primitive_t pool_diff_srcs[] = { pool_diff_src_memory };
556
557     /* finally create backward pooling primitive */
558     mkldnn_primitive_t pool_bwd;
559     CHECK(mkldnn_primitive_create(&pool_bwd, pool_bwd_pd, pool_diff_dsts,
560                                   pool_diff_srcs));
561
562     /* Backward lrn */
563     const mkldnn_memory_desc_t *lrn_diff_dst_md
564             = mkldnn_primitive_desc_query_memory_d(pool_diff_src_pd);
565
566     /* create backward lrn descriptor */
567     mkldnn_lrn_desc_t lrn_bwd_desc;
568     CHECK(mkldnn_lrn_backward_desc_init(
569             &lrn_bwd_desc, mkldnn_lrn_across_channels, lrn_src_md,
570             lrn_diff_dst_md, local_size, alpha, beta, k));
571
572     mkldnn_primitive_desc_t lrn_bwd_pd;
573     CHECK(mkldnn_primitive_desc_create(&lrn_bwd_pd, &lrn_bwd_desc, engine,
574                                        lrn_pd));
575
576     /* create memory primitives for lrn diff src */
577     mkldnn_primitive_t lrn_diff_src_memory;
578     const_mkldnn_primitive_desc_t lrn_diff_src_pd
579             = mkldnn_primitive_desc_query_pd(lrn_bwd_pd,
580                                              mkldnn_query_diff_src_pd, 0);
581     size_t lrn_diff_src_size
582             = mkldnn_memory_primitive_desc_get_size(lrn_diff_src_pd);
583     float *lrn_diff_src_buffer = (float *)aligned_malloc(lrn_diff_src_size, 64);
584     CHECK(mkldnn_primitive_create(&lrn_diff_src_memory, lrn_diff_src_pd, NULL,
585                                   NULL));
586     CHECK(mkldnn_memory_set_data_handle(lrn_diff_src_memory,
587                                         lrn_diff_src_buffer));
588
589     mkldnn_primitive_at_t lrn_diff_dsts[]
590             = { mkldnn_primitive_at(relu_dst_memory,
591                                     0), // lrn_bwd requires src as first input
592                 mkldnn_primitive_at(pool_diff_src_memory, 0),
593                 mkldnn_primitive_at(lrn_workspace_memory, 0) };
594
595     const_mkldnn_primitive_t lrn_diff_srcs[] = { lrn_diff_src_memory };
596
597     /* finally create backward lrn primitive */
598     mkldnn_primitive_t lrn_bwd;
599     CHECK(mkldnn_primitive_create(&lrn_bwd, lrn_bwd_pd, lrn_diff_dsts,
600              lrn_diff_srcs));
601
602     /* Backward relu */
603     const mkldnn_memory_desc_t *relu_diff_dst_md
604             = mkldnn_primitive_desc_query_memory_d(lrn_diff_src_pd);
605
606     /* create backward relu descriptor */
607     mkldnn_eltwise_desc_t relu_bwd_desc;
608     CHECK(mkldnn_eltwise_backward_desc_init(&relu_bwd_desc,
609                 mkldnn_eltwise_relu, relu_diff_dst_md, relu_src_md,
610                 negative_slope, 0));
611
612     mkldnn_primitive_desc_t relu_bwd_pd;
613     CHECK(mkldnn_primitive_desc_create(&relu_bwd_pd, &relu_bwd_desc, engine,
614                                        relu_pd));
615
616     /* create memory primitives for relu diff src */
617     mkldnn_primitive_t relu_diff_src_memory;
618     const_mkldnn_primitive_desc_t relu_diff_src_pd
619             = mkldnn_primitive_desc_query_pd(relu_bwd_pd,
620                                              mkldnn_query_diff_src_pd, 0);
621     size_t relu_diff_src_size
622             = mkldnn_memory_primitive_desc_get_size(relu_diff_src_pd);
623     float *relu_diff_src_buffer
624             = (float *)aligned_malloc(relu_diff_src_size, 64);
625
626     CHECK(mkldnn_primitive_create(&relu_diff_src_memory, relu_diff_src_pd, NULL,
627                                   NULL));
628     CHECK(mkldnn_memory_set_data_handle(relu_diff_src_memory,
629                                         relu_diff_src_buffer));
630
631     mkldnn_primitive_at_t relu_diff_dsts[]
632             = { mkldnn_primitive_at(conv_internal_dst_memory, 0),
633                 mkldnn_primitive_at(lrn_diff_src_memory, 0) };
634
635     const_mkldnn_primitive_t relu_diff_srcs[] = { relu_diff_src_memory };
636
637     /* finally create backward relu primitive */
638     mkldnn_primitive_t relu_bwd;
639     CHECK(mkldnn_primitive_create(&relu_bwd, relu_pd, relu_diff_dsts,
640                                   relu_diff_srcs));
641
642     /* Backward convolution with respect to weights */
643     float *conv_diff_bias_buffer = (float *)aligned_malloc(
644             product(conv_bias_sizes, 1) * sizeof(float), 64);
645     float *conv_user_diff_weights_buffer = (float *)aligned_malloc(
646             product(conv_user_weights_sizes, 4) * sizeof(float), 64);
647
648     /* initialize memory for diff weights in user format */
649     mkldnn_primitive_t conv_user_diff_weights_memory;
650     init_data_memory(4, conv_user_weights_sizes, mkldnn_nchw, mkldnn_f32,
651             engine, conv_user_diff_weights_buffer,
652             &conv_user_diff_weights_memory);
653
654     /* memory descriptors should be in format `any` to allow backward
655      * convolution for
656      * weights to chose the format it prefers for best performance */
657     mkldnn_memory_desc_t conv_diff_src_md, conv_diff_weights_md,
658             conv_diff_bias_md, conv_diff_dst_md;
659     CHECK(mkldnn_memory_desc_init(
660             &conv_diff_src_md, 4, conv_user_src_sizes, mkldnn_f32, mkldnn_any));
661     CHECK(mkldnn_memory_desc_init(&conv_diff_weights_md, 4,
662             conv_user_weights_sizes, mkldnn_f32, mkldnn_any));
663     CHECK(mkldnn_memory_desc_init(
664             &conv_diff_bias_md, 1, conv_bias_sizes, mkldnn_f32, mkldnn_x));
665     CHECK(mkldnn_memory_desc_init(
666             &conv_diff_dst_md, 4, conv_user_dst_sizes, mkldnn_f32, mkldnn_any));
667
668     /* create backward convolution descriptor */
669     mkldnn_convolution_desc_t conv_bwd_weights_desc;
670     CHECK(mkldnn_convolution_backward_weights_desc_init(&conv_bwd_weights_desc,
671             mkldnn_convolution_direct, &conv_diff_src_md, &conv_diff_weights_md,
672             &conv_diff_bias_md, &conv_diff_dst_md, conv_strides, conv_padding,
673             conv_padding, mkldnn_padding_zero));
674
675     mkldnn_primitive_desc_t conv_bwd_weights_pd;
676     CHECK(mkldnn_primitive_desc_create(
677             &conv_bwd_weights_pd, &conv_bwd_weights_desc, engine, conv_pd));
678
679     /* for best performance convolution backward might choose
680      * different memory format for src and diff_dst
681      * than the memory formats preferred by forward convolution
682      * for src and dst respectively */
683     /* create reorder primitives for src from forward convolution to the
684      * format chosen by backward convolution */
685     mkldnn_primitive_t conv_bwd_reorder_src, conv_bwd_internal_src_memory;
686     const_mkldnn_primitive_desc_t conv_diff_src_pd
687             = mkldnn_primitive_desc_query_pd(conv_bwd_weights_pd,
688                                              mkldnn_query_src_pd, 0);
689     size_t conv_diff_src_size
690             = mkldnn_memory_primitive_desc_get_size(conv_diff_src_pd);
691     float *conv_diff_src_buffer
692             = (float *)aligned_malloc(conv_diff_src_size, 64);
693     CHECK(prepare_reorder(&conv_src_memory, &conv_diff_src_pd, 1,
694             &conv_bwd_internal_src_memory, &conv_bwd_reorder_src,
695             conv_diff_src_buffer));
696
697     mkldnn_primitive_t conv_diff_src_memory
698             = conv_bwd_internal_src_memory ? conv_bwd_internal_src_memory
699                                            : conv_src_memory;
700
701     /* create reorder primitives for diff_dst between diff_src from relu_bwd
702      * and format preferred by conv_diff_weights */
703     mkldnn_primitive_t conv_reorder_diff_dst, conv_internal_diff_dst_memory;
704     const_mkldnn_primitive_desc_t conv_diff_dst_pd
705             = mkldnn_primitive_desc_query_pd(conv_bwd_weights_pd,
706                                              mkldnn_query_diff_dst_pd, 0);
707     size_t conv_diff_dst_size
708             = mkldnn_memory_primitive_desc_get_size(conv_diff_dst_pd);
709     float *conv_diff_dst_buffer
710             = (float *)aligned_malloc(conv_diff_dst_size, 64);
711
712     CHECK(prepare_reorder(&relu_diff_src_memory, &conv_diff_dst_pd, 1,
713                           &conv_internal_diff_dst_memory,
714                           &conv_reorder_diff_dst, conv_diff_dst_buffer));
715
716     mkldnn_primitive_t conv_diff_dst_memory
717             = conv_internal_diff_dst_memory ? conv_internal_diff_dst_memory
718                                             : relu_diff_src_memory;
719
720     /* create reorder primitives for conv diff weights memory */
721     mkldnn_primitive_t conv_reorder_diff_weights,
722             conv_internal_diff_weights_memory;
723     const_mkldnn_primitive_desc_t conv_diff_weights_pd
724             = mkldnn_primitive_desc_query_pd(conv_bwd_weights_pd,
725                                              mkldnn_query_diff_weights_pd, 0);
726     size_t conv_diff_weights_size
727             = mkldnn_memory_primitive_desc_get_size(conv_diff_weights_pd);
728     float *conv_diff_weights_buffer
729             = (float *)aligned_malloc(conv_diff_weights_size, 64);
730     CHECK(prepare_reorder(&conv_user_diff_weights_memory, &conv_diff_weights_pd,
731                           0, &conv_internal_diff_weights_memory,
732                           &conv_reorder_diff_weights,
733                           conv_diff_weights_buffer));
734
735     mkldnn_primitive_t conv_diff_weights_memory
736             = conv_internal_diff_weights_memory
737                       ? conv_internal_diff_weights_memory
738                       : conv_user_diff_weights_memory;
739
740     /* create memory primitive for diff bias memory */
741     mkldnn_primitive_t conv_diff_bias_memory;
742     mkldnn_primitive_desc_t conv_diff_bias_pd;
743     CHECK(mkldnn_memory_primitive_desc_create(&conv_diff_bias_pd,
744                                               &conv_diff_bias_md, engine));
745     CHECK(mkldnn_primitive_create(&conv_diff_bias_memory, conv_diff_bias_pd,
746                                   NULL, NULL));
747     CHECK(mkldnn_memory_set_data_handle(conv_diff_bias_memory,
748                                         conv_diff_bias_buffer));
749
750     mkldnn_primitive_at_t conv_diff_dsts[]
751             = { mkldnn_primitive_at(conv_diff_src_memory, 0),
752                 mkldnn_primitive_at(conv_diff_dst_memory, 0) };
753
754     const_mkldnn_primitive_t conv_diff_weights[]
755             = { conv_diff_weights_memory, conv_diff_bias_memory };
756
757     /* finally create the backward convolution weights primitive */
758     mkldnn_primitive_t conv_bwd_weights;
759     CHECK(mkldnn_primitive_create(&conv_bwd_weights, conv_bwd_weights_pd,
760                                   conv_diff_dsts, conv_diff_weights));
761
762     /* build backward stream */
763     uint32_t n_bwd = 0;
764     mkldnn_primitive_t net_bwd[10];
765
766     if (pool_reorder_diff_dst)
767         net_bwd[n_bwd++] = pool_reorder_diff_dst;
768     net_bwd[n_bwd++] = pool_bwd;
769     net_bwd[n_bwd++] = lrn_bwd;
770     net_bwd[n_bwd++] = relu_bwd;
771     if (conv_bwd_reorder_src)
772         net_bwd[n_bwd++] = conv_bwd_reorder_src;
773     if (conv_reorder_diff_dst)
774         net_bwd[n_bwd++] = conv_reorder_diff_dst;
775     net_bwd[n_bwd++] = conv_bwd_weights;
776     if (conv_reorder_diff_weights)
777         net_bwd[n_bwd++] = conv_reorder_diff_weights;
778
779     // output from backward stream
780     void *net_diff_weights = NULL;
781     void *net_diff_bias = NULL;
782
783     int n_iter = 10; //number of iterations for training.
784     /* Execute the net */
785     for (int i = 0; i < n_iter; i++) {
786         mkldnn_stream_t stream_fwd;
787         CHECK(mkldnn_stream_create(&stream_fwd, mkldnn_eager));
788         CHECK(mkldnn_stream_submit(stream_fwd, n_fwd, net_fwd, NULL));
789         CHECK(mkldnn_stream_wait(stream_fwd, n_fwd, NULL));
790         CHECK(mkldnn_stream_destroy(stream_fwd));
791
792         /* Update net_diff_dst */
793         CHECK(mkldnn_memory_get_data_handle(pool_user_dst_memory, &net_output));
794         /*...user updates net_diff_dst using net_output...*/
795         // some user defined func update_diff_dst(net_diff_dst, net_output)
796
797         /* Backward pass */
798         mkldnn_stream_t stream_bwd;
799         CHECK(mkldnn_stream_create(&stream_bwd, mkldnn_eager));
800         CHECK(mkldnn_stream_submit(stream_bwd, n_bwd, net_bwd, NULL));
801         CHECK(mkldnn_stream_wait(stream_bwd, n_bwd, NULL));
802         CHECK(mkldnn_stream_destroy(stream_bwd));
803
804         /*... update weights ... */
805         CHECK(mkldnn_memory_get_data_handle(conv_user_diff_weights_memory,
806                                             &net_diff_weights));
807         CHECK(mkldnn_memory_get_data_handle(conv_diff_bias_memory,
808                                             &net_diff_bias));
809         /* ...user updates weights and bias using diff weights and bias...*/
810         // some user defined func update_weights(conv_user_weights_memory,
811         // conv_bias_memory,
812         //      net_diff_weights, net_diff_bias);
813     }
814
815     /* Cleanup forward */
816     CHECK(mkldnn_primitive_desc_destroy(pool_pd));
817     CHECK(mkldnn_primitive_desc_destroy(lrn_pd));
818     CHECK(mkldnn_primitive_desc_destroy(relu_pd));
819     CHECK(mkldnn_primitive_desc_destroy(conv_pd));
820
821     _free(net_src);
822     _free(net_dst);
823
824     mkldnn_primitive_destroy(conv_user_src_memory);
825     mkldnn_primitive_destroy(conv_user_weights_memory);
826     mkldnn_primitive_destroy(conv_user_bias_memory);
827     mkldnn_primitive_destroy(conv_internal_src_memory);
828     mkldnn_primitive_destroy(conv_internal_weights_memory);
829     mkldnn_primitive_destroy(conv_internal_dst_memory);
830     mkldnn_primitive_destroy(conv_reorder_src);
831     mkldnn_primitive_destroy(conv_reorder_weights);
832     mkldnn_primitive_destroy(conv);
833
834     _free(conv_weights);
835     _free(conv_bias);
836
837     _free(conv_src_buffer);
838     _free(conv_weights_buffer);
839     _free(conv_dst_buffer);
840
841     mkldnn_primitive_destroy(relu_dst_memory);
842     mkldnn_primitive_destroy(relu);
843
844     _free(relu_dst_buffer);
845
846     mkldnn_primitive_destroy(lrn_workspace_memory);
847     mkldnn_primitive_destroy(lrn_dst_memory);
848     mkldnn_primitive_destroy(lrn);
849
850     _free(lrn_workspace_buffer);
851     _free(lrn_dst_buffer);
852
853     mkldnn_primitive_destroy(pool_user_dst_memory);
854     mkldnn_primitive_destroy(pool_internal_dst_memory);
855     mkldnn_primitive_destroy(pool_workspace_memory);
856     mkldnn_primitive_destroy(pool_reorder_dst);
857     mkldnn_primitive_destroy(pool);
858
859     _free(pool_dst_buffer);
860     _free(pool_workspace_buffer);
861
862     /* Cleanup backward */
863     CHECK(mkldnn_primitive_desc_destroy(pool_bwd_pd));
864     CHECK(mkldnn_primitive_desc_destroy(lrn_bwd_pd));
865     CHECK(mkldnn_primitive_desc_destroy(relu_bwd_pd));
866     CHECK(mkldnn_primitive_desc_destroy(conv_diff_bias_pd));
867     CHECK(mkldnn_primitive_desc_destroy(conv_bwd_weights_pd));
868
869     mkldnn_primitive_destroy(pool_user_diff_dst_memory);
870     mkldnn_primitive_destroy(pool_diff_src_memory);
871     mkldnn_primitive_destroy(pool_internal_diff_dst_memory);
872     mkldnn_primitive_destroy(pool_reorder_diff_dst);
873     mkldnn_primitive_destroy(pool_bwd);
874
875     _free(net_diff_dst);
876     _free(pool_diff_dst_buffer);
877     _free(pool_diff_src_buffer);
878
879     mkldnn_primitive_destroy(lrn_diff_src_memory);
880     mkldnn_primitive_destroy(lrn_bwd);
881
882     _free(lrn_diff_src_buffer);
883
884     mkldnn_primitive_destroy(relu_diff_src_memory);
885     mkldnn_primitive_destroy(relu_bwd);
886
887     _free(relu_diff_src_buffer);
888
889     mkldnn_primitive_destroy(conv_user_diff_weights_memory);
890     mkldnn_primitive_destroy(conv_diff_bias_memory);
891     mkldnn_primitive_destroy(conv_bwd_internal_src_memory);
892     mkldnn_primitive_destroy(conv_bwd_reorder_src);
893     mkldnn_primitive_destroy(conv_internal_diff_dst_memory);
894     mkldnn_primitive_destroy(conv_reorder_diff_dst);
895     mkldnn_primitive_destroy(conv_internal_diff_weights_memory);
896     mkldnn_primitive_destroy(conv_reorder_diff_weights);
897     mkldnn_primitive_destroy(conv_bwd_weights);
898
899     _free(conv_diff_weights_buffer);
900     _free(conv_diff_bias_buffer);
901     _free(conv_user_diff_weights_buffer);
902     _free(conv_diff_src_buffer);
903     _free(conv_diff_dst_buffer);
904
905     mkldnn_engine_destroy(engine);
906
907     return mkldnn_success;
908 }
909
910 int main(int argc, char **argv)
911 {
912     mkldnn_status_t result = simple_net();
913     printf("%s\n", (result == mkldnn_success) ? "passed" : "failed");
914     return result;
915 }