Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / examples / simple_net.c
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 // Required for posix_memalign
18 #define _POSIX_C_SOURCE 200112L
19
20 #include <string.h>
21 #include <stdio.h>
22 #include <stdlib.h>
23 #include "mkldnn.h"
24 #ifdef _WIN32
25 #include <malloc.h>
26 #endif
27
/* Tensor shapes of the example network (AlexNet's first layers).
 * Kept as integer constants so they can be used directly in the dimension
 * arrays and in buffer-size arithmetic. */

/* Mini-batch size and channel counts of the convolution. */
enum {
    BATCH = 8,
    IC = 3,
    OC = 96
};

/* Convolution geometry. simple_net() uses an 11x11 kernel, hence
 * (CONV_IH - 11 + 2 * CONV_PAD) / CONV_STRIDE + 1 == CONV_OH
 * (and likewise for the width). */
enum {
    CONV_IH = 227,
    CONV_IW = 227,
    CONV_OH = 55,
    CONV_OW = 55,
    CONV_STRIDE = 4,
    CONV_PAD = 0
};

/* Pooling geometry. simple_net() uses a 3x3 kernel, hence
 * (CONV_OH - 3 + 2 * POOL_PAD) / POOL_STRIDE + 1 == POOL_OH
 * (and likewise for the width). */
enum {
    POOL_OH = 27,
    POOL_OW = 27,
    POOL_STRIDE = 2,
    POOL_PAD = 0
};
41
/* Evaluate the mkldnn call `f` exactly once. If it does not return
 * mkldnn_success, print the source location, the text of the call and the
 * numeric status, then terminate the process with exit code 2.
 *
 * The temporary is named s_ (matching e_ in CHECK_TRUE) so that an argument
 * which itself mentions a caller variable called `s` is not captured by the
 * macro's own declaration. The status is cast to int because it is passed
 * through printf's variadic argument list with a %d conversion. */
#define CHECK(f) do { \
    mkldnn_status_t s_ = (f); \
    if (s_ != mkldnn_success) { \
        printf("[%s:%d] error: %s returns %d\n", __FILE__, __LINE__, #f, \
                (int)s_); \
        exit(2); \
    } \
} while(0)
49
/* Evaluate `expr` exactly once. If it is false (zero / null), print the
 * source location and the text of the expression, then terminate the
 * process with exit code 2.
 *
 * The value is normalised with !! before being stored in an int, so that
 * wide integers whose low bits are zero, fractional floating-point values
 * and pointers are all tested for truth rather than converted to int. */
#define CHECK_TRUE(expr) do { \
    int e_ = !!(expr); \
    if (!e_) { \
        printf("[%s:%d] %s failed\n", __FILE__, __LINE__, #expr); \
        exit(2); \
    } \
} while(0)
57
/**
 * Allocate `size` bytes whose address is a multiple of `alignment`.
 *
 * @param size       number of bytes to allocate
 * @param alignment  requested alignment in bytes (a power of two)
 * @return pointer to the new buffer, or NULL if the allocation failed.
 *         Release it with _free().
 *
 * Platform notes:
 *  - _WIN32: forwards to _aligned_malloc().
 *  - _SX:    falls back to plain malloc(); the alignment request is not
 *            applied by this function on that target.
 *  - other:  uses posix_memalign().
 */
void *aligned_malloc(size_t size, size_t alignment) {
#ifdef _WIN32
    return _aligned_malloc(size, alignment);
#elif defined(_SX)
    (void)alignment;
    return malloc(size);
#else
    void *p = NULL;

    /* posix_memalign() requires the alignment to be a multiple of
     * sizeof(void *). A smaller power of two divides sizeof(void *), so
     * widening the request still satisfies what the caller asked for. */
    if (alignment != 0 && (alignment & (alignment - 1)) == 0
            && alignment < sizeof(void *)) {
        alignment = sizeof(void *);
    }

    return posix_memalign(&p, alignment, size) == 0 ? p : NULL;
#endif
}
68
/* Release a buffer: _aligned_free() on Windows (the counterpart of the
 * _aligned_malloc() used by aligned_malloc() there), free() elsewhere. */
void _free(void *ptr) {
#ifdef _WIN32
    _aligned_free(ptr);
#else
    free(ptr);
#endif
}
78
/* Return the product of the first `size` entries of `arr` (1 when `size`
 * is 0). The array is only read, so it is taken as pointer-to-const, like
 * the `dims` argument of init_data_memory(). Entries are converted to
 * size_t explicitly; they are expected to be non-negative dimensions. */
static size_t product(const ptrdiff_t *arr, size_t size) {
    size_t prod = 1;
    for (size_t i = 0; i < size; ++i) {
        prod *= (size_t)arr[i];
    }
    return prod;
}
84
85 static void init_data_memory(uint32_t dim, const ptrdiff_t *dims,
86         mkldnn_memory_format_t user_fmt, mkldnn_data_type_t mkldnn_f32,
87         mkldnn_engine_t engine, float *data, mkldnn_primitive_t *memory)
88 {
89     mkldnn_memory_desc_t prim_md;
90     mkldnn_primitive_desc_t user_pd;
91     CHECK(mkldnn_memory_desc_init(&prim_md, dim, dims, mkldnn_f32, user_fmt));
92     CHECK(mkldnn_memory_primitive_desc_create(&user_pd, &prim_md, engine));
93     CHECK(mkldnn_primitive_create(memory, user_pd, NULL, NULL));
94
95     void *req = NULL;
96     CHECK(mkldnn_memory_get_data_handle(*memory, &req));
97     CHECK_TRUE(req == NULL);
98     CHECK(mkldnn_memory_set_data_handle(*memory, data));
99     CHECK(mkldnn_memory_get_data_handle(*memory, &req));
100     CHECK_TRUE(req == data);
101     CHECK(mkldnn_primitive_desc_destroy(user_pd));
102 }
103
104 mkldnn_status_t prepare_reorder(
105         mkldnn_primitive_t *user_memory, /** in */
106         const_mkldnn_primitive_desc_t *prim_memory_pd, /** in */
107         int dir_is_user_to_prim, /** in: user -> prim or prim -> user */
108         mkldnn_primitive_t *prim_memory, /** out: memory primitive created */
109         mkldnn_primitive_t *reorder, /** out: reorder primitive created */
110         float *buffer)
111 {
112     const_mkldnn_primitive_desc_t user_memory_pd;
113     mkldnn_primitive_get_primitive_desc(*user_memory, &user_memory_pd);
114
115     if (!mkldnn_memory_primitive_desc_equal(user_memory_pd, *prim_memory_pd)) {
116         /* memory_create(&p, m, NULL) means allocate memory */
117         CHECK(mkldnn_primitive_create(prim_memory, *prim_memory_pd,
118                 NULL, NULL));
119         mkldnn_primitive_desc_t reorder_pd;
120         if (dir_is_user_to_prim) {
121             /* reorder primitive descriptor doesn't need engine, because it is
122              * already appeared in in- and out- memory primitive descriptors */
123             CHECK(mkldnn_reorder_primitive_desc_create(&reorder_pd,
124                         user_memory_pd, *prim_memory_pd));
125             mkldnn_primitive_at_t inputs = { *user_memory, 0 };
126             const_mkldnn_primitive_t outputs[] = { *prim_memory };
127             CHECK(mkldnn_primitive_create(reorder, reorder_pd, &inputs,
128                         outputs));
129         } else {
130             CHECK(mkldnn_reorder_primitive_desc_create(&reorder_pd,
131                         *prim_memory_pd, user_memory_pd));
132             mkldnn_primitive_at_t inputs = { *prim_memory, 0 };
133             const_mkldnn_primitive_t outputs[] = { *user_memory };
134             CHECK(mkldnn_primitive_create(reorder, reorder_pd, &inputs,
135                         outputs));
136         }
137         CHECK(mkldnn_memory_set_data_handle(*prim_memory, buffer));
138         CHECK(mkldnn_primitive_desc_destroy(reorder_pd));
139     } else {
140         *prim_memory = NULL;
141         *reorder = NULL;
142     }
143
144     return mkldnn_success;
145 }
146
147 mkldnn_status_t simple_net() {
148
149     mkldnn_engine_t engine;
150     CHECK(mkldnn_engine_create(&engine, mkldnn_cpu, 0 /* idx */));
151
152     float *net_src = (float *)aligned_malloc(
153             BATCH * IC * CONV_IH * CONV_IW * sizeof(float), 64);
154     float *net_dst = (float *)aligned_malloc(
155             BATCH * OC * POOL_OH * POOL_OW * sizeof(float), 64);
156
157     /* AlexNet: conv
158      * {BATCH, IC, CONV_IH, CONV_IW} (x) {OC, IC, CONV_KH, CONV_KW} ->
159      * {BATCH, OC, CONV_OH, CONV_OW}
160      * strides: {CONV_STRIDE, CONV_STRIDE}
161      */
162     ptrdiff_t conv_user_src_sizes[4] = { BATCH, IC, CONV_IH, CONV_IW };
163     ptrdiff_t conv_user_weights_sizes[4] = { OC, IC, 11, 11 };
164     ptrdiff_t conv_bias_sizes[4] = { OC };
165     ptrdiff_t conv_user_dst_sizes[4] = { BATCH, OC, CONV_OH, CONV_OW };
166     ptrdiff_t conv_strides[2] = { CONV_STRIDE, CONV_STRIDE };
167     ptrdiff_t conv_padding[2] = { CONV_PAD, CONV_PAD };
168
169     float *conv_src = net_src;
170     float *conv_weights = (float *)aligned_malloc(
171             product(conv_user_weights_sizes, 4) * sizeof(float), 64);
172     float *conv_bias = (float *)aligned_malloc(
173             product(conv_bias_sizes, 1) * sizeof(float), 64);
174
175     /* create memory for user data */
176     mkldnn_primitive_t conv_user_src_memory, conv_user_weights_memory,
177         conv_user_bias_memory;
178     init_data_memory(4, conv_user_src_sizes, mkldnn_nchw, mkldnn_f32, engine,
179             conv_src, &conv_user_src_memory);
180     init_data_memory(4, conv_user_weights_sizes, mkldnn_oihw, mkldnn_f32,
181             engine, conv_weights, &conv_user_weights_memory);
182     init_data_memory(1, conv_bias_sizes, mkldnn_x, mkldnn_f32, engine,
183             conv_bias, &conv_user_bias_memory);
184
185     /* create data descriptors for convolution w/ no specified format */
186
187     mkldnn_memory_desc_t conv_src_md, conv_weights_md, conv_bias_md,
188         conv_dst_md;
189     CHECK(mkldnn_memory_desc_init(&conv_src_md, 4, conv_user_src_sizes,
190         mkldnn_f32, mkldnn_any));
191     CHECK(mkldnn_memory_desc_init(&conv_weights_md, 4, conv_user_weights_sizes,
192         mkldnn_f32, mkldnn_any));
193     CHECK(mkldnn_memory_desc_init(&conv_bias_md, 1, conv_bias_sizes,
194         mkldnn_f32, mkldnn_x));
195     CHECK(mkldnn_memory_desc_init(&conv_dst_md, 4, conv_user_dst_sizes,
196         mkldnn_f32, mkldnn_any));
197
198     /* create a convolution */
199     mkldnn_convolution_desc_t conv_any_desc;
200     CHECK(mkldnn_convolution_forward_desc_init(&conv_any_desc, mkldnn_forward,
201             mkldnn_convolution_direct, &conv_src_md, &conv_weights_md,
202             &conv_bias_md, &conv_dst_md, conv_strides, conv_padding,
203             conv_padding, mkldnn_padding_zero));
204
205     mkldnn_primitive_desc_t conv_pd;
206     CHECK(mkldnn_primitive_desc_create(&conv_pd, &conv_any_desc,
207             engine, NULL));
208
209     mkldnn_primitive_t conv_internal_src_memory, conv_internal_weights_memory,
210         conv_internal_dst_memory;
211
212     /* create memory for dst data, we don't need reorder it to user data */
213     const_mkldnn_primitive_desc_t dst_pd
214             = mkldnn_primitive_desc_query_pd(conv_pd, mkldnn_query_dst_pd, 0);
215     CHECK(mkldnn_primitive_create(
216             &conv_internal_dst_memory, dst_pd, NULL, NULL));
217     size_t conv_dst_size = mkldnn_memory_primitive_desc_get_size(dst_pd);
218     float *conv_dst_buffer = (float *)aligned_malloc(conv_dst_size, 64);
219     CHECK(mkldnn_memory_set_data_handle(
220             conv_internal_dst_memory, conv_dst_buffer));
221
222     /* create reorder primitives between user data and convolution srcs
223      * if required */
224     mkldnn_primitive_t conv_reorder_src, conv_reorder_weights;
225
226     const_mkldnn_primitive_desc_t src_pd = mkldnn_primitive_desc_query_pd(
227             conv_pd, mkldnn_query_src_pd, 0);
228     size_t conv_src_size = mkldnn_memory_primitive_desc_get_size(src_pd);
229     float *conv_src_buffer = (float *)aligned_malloc(conv_src_size, 64);
230     CHECK(prepare_reorder(&conv_user_src_memory, &src_pd, 1,
231             &conv_internal_src_memory, &conv_reorder_src, conv_src_buffer));
232
233     const_mkldnn_primitive_desc_t weights_pd = mkldnn_primitive_desc_query_pd(
234             conv_pd, mkldnn_query_weights_pd, 0);
235     size_t conv_weights_size
236             = mkldnn_memory_primitive_desc_get_size(weights_pd);
237     float *conv_weights_buffer = (float *)aligned_malloc(conv_weights_size, 64);
238     CHECK(prepare_reorder(&conv_user_weights_memory, &weights_pd, 1,
239             &conv_internal_weights_memory, &conv_reorder_weights,
240             conv_weights_buffer));
241
242     mkldnn_primitive_t conv_src_memory = conv_internal_src_memory ?
243         conv_internal_src_memory : conv_user_src_memory;
244     mkldnn_primitive_t conv_weights_memory = conv_internal_weights_memory ?
245         conv_internal_weights_memory : conv_user_weights_memory;
246
247     mkldnn_primitive_at_t conv_srcs[] = {
248         mkldnn_primitive_at(conv_src_memory, 0),
249         mkldnn_primitive_at(conv_weights_memory, 0),
250         mkldnn_primitive_at(conv_user_bias_memory, 0)
251     };
252
253     const_mkldnn_primitive_t conv_dsts[] = { conv_internal_dst_memory };
254
255     /* finally create a convolution primitive */
256     mkldnn_primitive_t conv;
257     CHECK(mkldnn_primitive_create(&conv, conv_pd, conv_srcs, conv_dsts));
258
259     /* AlexNet: relu
260      * {BATCH, OC, CONV_OH, CONV_OW} -> {BATCH, OC, CONV_OH, CONV_OW}
261      */
262     float negative_slope = 1.0f;
263
264     /* create relu memory descriptor on dst memory descriptor
265      * from previous primitive */
266     const_mkldnn_primitive_desc_t conv_dst_pd = mkldnn_primitive_desc_query_pd(
267             conv_pd, mkldnn_query_dst_pd, 0);
268     const mkldnn_memory_desc_t *relu_src_md =
269         mkldnn_primitive_desc_query_memory_d(conv_dst_pd);
270
271     /* create a relu */
272     mkldnn_eltwise_desc_t relu_desc;
273     CHECK(mkldnn_eltwise_forward_desc_init(&relu_desc, mkldnn_forward,
274                 mkldnn_eltwise_relu, relu_src_md, negative_slope, 0));
275
276     mkldnn_primitive_desc_t relu_pd;
277     CHECK(mkldnn_primitive_desc_create(&relu_pd, &relu_desc, engine, NULL));
278
279     mkldnn_primitive_t relu_dst_memory;
280     const_mkldnn_primitive_desc_t relu_dst_pd = mkldnn_primitive_desc_query_pd(
281             relu_pd, mkldnn_query_dst_pd, 0);
282     CHECK(mkldnn_primitive_create(&relu_dst_memory, relu_dst_pd, NULL, NULL));
283     size_t relu_dst_size = mkldnn_memory_primitive_desc_get_size(relu_dst_pd);
284     float *relu_dst_buffer = (float *)aligned_malloc(relu_dst_size, 64);
285     CHECK(mkldnn_memory_set_data_handle(relu_dst_memory, relu_dst_buffer));
286
287     /* finally create a relu primitive */
288     mkldnn_primitive_t relu;
289     mkldnn_primitive_at_t relu_srcs = { conv_internal_dst_memory, 0 };
290     const_mkldnn_primitive_t relu_dsts[] = { relu_dst_memory };
291
292     CHECK(mkldnn_primitive_create(&relu, relu_pd, &relu_srcs, relu_dsts));
293
294     /* AlexNet: lrn
295      * {BATCH, OC, CONV_OH, CONV_OW} -> {BATCH, OC, CONV_OH, CONV_OW}
296      * local size: 5
297      * alpha: 0.0001
298      * beta: 0.75
299      */
300     uint32_t local_size = 5;
301     float alpha = 0.0001f;
302     float beta = 0.75f;
303     float k = 1.0f;
304
305     /* create lrn memory descriptor on dst memory descriptor
306      *  from previous primitive */
307     const mkldnn_memory_desc_t *lrn_src_md =
308         mkldnn_primitive_desc_query_memory_d(relu_dst_pd);
309
310     /* create a lrn */
311     mkldnn_lrn_desc_t lrn_desc;
312     CHECK(mkldnn_lrn_forward_desc_init(&lrn_desc, mkldnn_forward,
313             mkldnn_lrn_across_channels, lrn_src_md, local_size,
314             alpha, beta, k));
315
316     mkldnn_primitive_desc_t lrn_pd;
317     CHECK(mkldnn_primitive_desc_create(&lrn_pd, &lrn_desc, engine, NULL));
318
319     mkldnn_primitive_t lrn_dst_memory;
320     const_mkldnn_primitive_desc_t lrn_dst_pd = mkldnn_primitive_desc_query_pd(
321             lrn_pd, mkldnn_query_dst_pd, 0);
322     CHECK(mkldnn_primitive_create(&lrn_dst_memory, lrn_dst_pd, NULL, NULL));
323     size_t lrn_dst_size = mkldnn_memory_primitive_desc_get_size(lrn_dst_pd);
324     float *lrn_dst_buffer = (float *)aligned_malloc(lrn_dst_size, 64);
325     CHECK(mkldnn_memory_set_data_handle(lrn_dst_memory, lrn_dst_buffer));
326
327     mkldnn_primitive_t lrn_scratch_memory;
328     const_mkldnn_primitive_desc_t lrn_scratch_pd =
329         mkldnn_primitive_desc_query_pd(lrn_pd, mkldnn_query_workspace_pd, 0);
330     CHECK(mkldnn_primitive_create(&lrn_scratch_memory,
331             lrn_scratch_pd, NULL, NULL));
332     size_t lrn_scratch_size =
333         mkldnn_memory_primitive_desc_get_size(lrn_scratch_pd);
334     float *lrn_scratch_buffer = (float*)aligned_malloc(lrn_scratch_size, 64);
335     CHECK(mkldnn_memory_set_data_handle(lrn_scratch_memory,
336             lrn_scratch_buffer));
337
338     mkldnn_primitive_at_t lrn_srcs = { relu_dst_memory, 0 };
339
340     const_mkldnn_primitive_t lrn_dsts[] = { lrn_dst_memory,
341             lrn_scratch_memory };
342
343     /* finally create a lrn primitive */
344     mkldnn_primitive_t lrn;
345     CHECK(mkldnn_primitive_create(&lrn, lrn_pd, &lrn_srcs, lrn_dsts));
346
347     /* AlexNet: pool
348      * {BATCH, OC, CONV_OH, CONV_OW} -> {BATCH, OC, POOL_OH, POOL_OW}
349      * kernel: {3, 3}
350      * strides: {POOL_STRIDE, POOL_STRIDE}
351      */
352
353     ptrdiff_t pool_dst_sizes[4] = { BATCH, OC, POOL_OH, POOL_OW };
354     ptrdiff_t pool_kernel[2] = { 3, 3 };
355     ptrdiff_t pool_strides[2] = { POOL_STRIDE, POOL_STRIDE };
356     ptrdiff_t pool_padding[2] = { POOL_PAD, POOL_PAD };
357
358     /* create pooling memory descriptor on dst descriptor
359      *  from previous primitive */
360     const mkldnn_memory_desc_t *pool_src_md =
361         mkldnn_primitive_desc_query_memory_d(lrn_dst_pd);
362
363     /* create descriptors for dst pooling data */
364     mkldnn_memory_desc_t pool_dst_md;
365     CHECK(mkldnn_memory_desc_init(
366             &pool_dst_md, 4, pool_dst_sizes, mkldnn_f32, mkldnn_any));
367
368     /* create memory for user data */
369     mkldnn_primitive_t pool_user_dst_memory;
370     init_data_memory(4, pool_dst_sizes, mkldnn_nchw, mkldnn_f32, engine,
371         net_dst, &pool_user_dst_memory);
372
373     /* create a pooling */
374     mkldnn_pooling_desc_t pool_desc;
375     CHECK(mkldnn_pooling_forward_desc_init(&pool_desc, mkldnn_forward,
376             mkldnn_pooling_max, pool_src_md, &pool_dst_md, pool_strides,
377             pool_kernel, pool_padding, pool_padding, mkldnn_padding_zero));
378
379     mkldnn_primitive_desc_t pool_pd;
380     CHECK(mkldnn_primitive_desc_create(&pool_pd, &pool_desc, engine, NULL));
381
382     /* create memory for workspace */
383     mkldnn_primitive_t pool_indices_memory;
384     const_mkldnn_primitive_desc_t pool_indices_pd =
385         mkldnn_primitive_desc_query_pd(pool_pd, mkldnn_query_workspace_pd, 0);
386     CHECK(mkldnn_primitive_create(&pool_indices_memory,
387             pool_indices_pd, NULL, NULL));
388     size_t pool_indices_size =
389         mkldnn_memory_primitive_desc_get_size(pool_indices_pd);
390     float *pool_indices_buffer = (float*)aligned_malloc(pool_indices_size, 64);
391     CHECK(mkldnn_memory_set_data_handle(pool_indices_memory,
392             pool_indices_buffer));
393
394     mkldnn_primitive_t pool_dst_memory;
395
396     /* create reorder primitives between user data and pooling dsts
397      * if required */
398     mkldnn_primitive_t pool_reorder_dst, pool_internal_dst_memory;
399     const_mkldnn_primitive_desc_t pool_dst_pd =
400         mkldnn_primitive_desc_query_pd(pool_pd, mkldnn_query_dst_pd, 0);
401     size_t pool_dst_size = mkldnn_memory_primitive_desc_get_size(pool_dst_pd);
402     float *pool_dst_buffer = (float *)aligned_malloc(pool_dst_size, 64);
403     CHECK(prepare_reorder(&pool_user_dst_memory, &pool_dst_pd, 0,
404             &pool_internal_dst_memory, &pool_reorder_dst, pool_dst_buffer));
405
406     mkldnn_primitive_at_t pool_srcs = { lrn_dst_memory, 0 };
407
408     pool_dst_memory = pool_internal_dst_memory ? pool_internal_dst_memory
409         : pool_user_dst_memory;
410
411     const_mkldnn_primitive_t pool_dsts[] = { pool_dst_memory,
412             pool_indices_memory };
413
414     /* finally create a pooling primitive */
415     mkldnn_primitive_t pool;
416     CHECK(mkldnn_primitive_create(&pool, pool_pd, &pool_srcs, pool_dsts));
417
418     /* build a simple net */
419     uint32_t n = 0;
420     mkldnn_primitive_t net[10];
421
422     if (conv_reorder_src) net[n++] = conv_reorder_src;
423     if (conv_reorder_weights) net[n++] = conv_reorder_weights;
424     net[n++] = conv;
425     net[n++] = relu;
426     net[n++] = lrn;
427     net[n++] = pool;
428     if (pool_reorder_dst) net[n++] = pool_reorder_dst;
429
430     mkldnn_stream_t stream;
431     CHECK(mkldnn_stream_create(&stream, mkldnn_eager));
432     CHECK(mkldnn_stream_submit(stream, n, net, NULL));
433     CHECK(mkldnn_stream_wait(stream, n, NULL));
434
435     /* clean-up */
436     CHECK(mkldnn_primitive_desc_destroy(conv_pd));
437     CHECK(mkldnn_primitive_desc_destroy(relu_pd));
438     CHECK(mkldnn_primitive_desc_destroy(lrn_pd));
439     CHECK(mkldnn_primitive_desc_destroy(pool_pd));
440
441     mkldnn_stream_destroy(stream);
442
443     _free(net_src);
444     _free(net_dst);
445
446     mkldnn_primitive_destroy(conv_user_src_memory);
447     mkldnn_primitive_destroy(conv_user_weights_memory);
448     mkldnn_primitive_destroy(conv_user_bias_memory);
449     mkldnn_primitive_destroy(conv_internal_src_memory);
450     mkldnn_primitive_destroy(conv_internal_weights_memory);
451     mkldnn_primitive_destroy(conv_internal_dst_memory);
452     mkldnn_primitive_destroy(conv_reorder_src);
453     mkldnn_primitive_destroy(conv_reorder_weights);
454     mkldnn_primitive_destroy(conv);
455
456     _free(conv_weights);
457     _free(conv_bias);
458
459     _free(conv_src_buffer);
460     _free(conv_weights_buffer);
461     _free(conv_dst_buffer);
462
463     mkldnn_primitive_destroy(relu_dst_memory);
464     mkldnn_primitive_destroy(relu);
465
466     _free(relu_dst_buffer);
467
468     mkldnn_primitive_destroy(lrn_scratch_memory);
469     mkldnn_primitive_destroy(lrn_dst_memory);
470     mkldnn_primitive_destroy(lrn);
471
472     _free(lrn_scratch_buffer);
473     _free(lrn_dst_buffer);
474
475     mkldnn_primitive_destroy(pool_user_dst_memory);
476     mkldnn_primitive_destroy(pool_internal_dst_memory);
477     mkldnn_primitive_destroy(pool_indices_memory);
478     mkldnn_primitive_destroy(pool_reorder_dst);
479     mkldnn_primitive_destroy(pool);
480
481     _free(pool_dst_buffer);
482     _free(pool_indices_buffer);
483
484     mkldnn_engine_destroy(engine);
485
486     return mkldnn_success;
487 }
488
489 int main(int argc, char **argv) {
490     mkldnn_status_t result = simple_net();
491     printf("%s\n", (result == mkldnn_success) ? "passed" : "failed");
492     return result;
493 }