/*******************************************************************************
* Copyright 2016-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <iostream>
#include <numeric>
#include <math.h>
#include <string>
#include <vector>
#include "mkldnn.hpp"

using namespace mkldnn;

void simple_net()
{
    auto cpu_engine = engine(engine::cpu, 0);

    const int batch = 32;

    std::vector<float> net_src(batch * 3 * 227 * 227);
    std::vector<float> net_dst(batch * 96 * 27 * 27);

    /* initializing non-zero values for src */
    for (size_t i = 0; i < net_src.size(); ++i)
        net_src[i] = sinf((float)i);

    /* AlexNet: conv
     * {batch, 3, 227, 227} (x) {96, 3, 11, 11} -> {batch, 96, 55, 55}
     * strides: {4, 4}
     */
    memory::dims conv_src_tz = { batch, 3, 227, 227 };
    memory::dims conv_weights_tz = { 96, 3, 11, 11 };
    memory::dims conv_bias_tz = { 96 };
    memory::dims conv_dst_tz = { batch, 96, 55, 55 };
    memory::dims conv_strides = { 4, 4 };
    memory::dims conv_padding = { 0, 0 };
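    /* output spatial size: (227 - 11 + 2 * 0) / 4 + 1 = 55, matching
     * conv_dst_tz above */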

    std::vector<float> conv_weights(
            std::accumulate(conv_weights_tz.begin(), conv_weights_tz.end(), 1,
                            std::multiplies<uint32_t>()));
    std::vector<float> conv_bias(std::accumulate(conv_bias_tz.begin(),
                                                 conv_bias_tz.end(), 1,
                                                 std::multiplies<uint32_t>()));

    /* initializing non-zero values for weights and bias */
    for (size_t i = 0; i < conv_weights.size(); ++i)
        conv_weights[i] = sinf((float)i);
    for (size_t i = 0; i < conv_bias.size(); ++i)
        conv_bias[i] = sinf((float)i);

    /* create memory for user data */
    auto conv_user_src_memory = memory(
            { { { conv_src_tz }, memory::data_type::f32, memory::format::nchw },
              cpu_engine },
            net_src.data());
    auto conv_user_weights_memory
            = memory({ { { conv_weights_tz }, memory::data_type::f32,
                         memory::format::oihw },
                       cpu_engine },
                     conv_weights.data());
    auto conv_user_bias_memory = memory(
            { { { conv_bias_tz }, memory::data_type::f32, memory::format::x },
              cpu_engine },
            conv_bias.data());

    /* create memory descriptors for convolution data with no specified
     * format (`any`).
     * format `any` lets a primitive (convolution in this case)
     * choose the memory format preferred for best performance. */
    auto conv_src_md = memory::desc({ conv_src_tz }, memory::data_type::f32,
                                    memory::format::any);
    auto conv_bias_md = memory::desc({ conv_bias_tz }, memory::data_type::f32,
                                     memory::format::any);
    auto conv_weights_md = memory::desc(
            { conv_weights_tz }, memory::data_type::f32, memory::format::any);
    auto conv_dst_md = memory::desc({ conv_dst_tz }, memory::data_type::f32,
                                    memory::format::any);

    /* create a convolution primitive descriptor */
    auto conv_desc = convolution_forward::desc(
            prop_kind::forward, convolution_direct, conv_src_md,
            conv_weights_md, conv_bias_md, conv_dst_md, conv_strides,
            conv_padding, conv_padding, padding_kind::zero);
    auto conv_pd = convolution_forward::primitive_desc(conv_desc, cpu_engine);

    /* create reorder primitives between user input and conv src if needed */
    auto conv_src_memory = conv_user_src_memory;
    bool reorder_conv_src = false;
    primitive conv_reorder_src;
    if (memory::primitive_desc(conv_pd.src_primitive_desc())
        != conv_user_src_memory.get_primitive_desc()) {
        conv_src_memory = memory(conv_pd.src_primitive_desc());
        conv_reorder_src = reorder(conv_user_src_memory, conv_src_memory);
        reorder_conv_src = true;
    }
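
    /* the same reorder-if-needed pattern is repeated below for weights and
     * dst: a reorder is only created when the layout chosen by the primitive
     * (typically a blocked one) differs from the plain user layout */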

    auto conv_weights_memory = conv_user_weights_memory;
    bool reorder_conv_weights = false;
    primitive conv_reorder_weights;
    if (memory::primitive_desc(conv_pd.weights_primitive_desc())
        != conv_user_weights_memory.get_primitive_desc()) {
        conv_weights_memory = memory(conv_pd.weights_primitive_desc());
        conv_reorder_weights
                = reorder(conv_user_weights_memory, conv_weights_memory);
        reorder_conv_weights = true;
    }

    /* create memory primitive for conv dst */
    auto conv_dst_memory = memory(conv_pd.dst_primitive_desc());

    /* finally create a convolution primitive */
    auto conv
            = convolution_forward(conv_pd, conv_src_memory, conv_weights_memory,
                                  conv_user_bias_memory, conv_dst_memory);

    /* AlexNet: relu
     * {batch, 96, 55, 55} -> {batch, 96, 55, 55}
     */
    const float negative_slope = 1.0;
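    /* note: with negative_slope = 1.0 this eltwise_relu is effectively an
     * identity mapping (f(x) = x for all x); a conventional relu would use
     * negative_slope = 0.0 */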

    /* create relu primitive desc */
    /* keep memory format of source same as the format of convolution
     * output in order to avoid reorder */
    auto relu_desc = eltwise_forward::desc(prop_kind::forward,
            algorithm::eltwise_relu, conv_pd.dst_primitive_desc().desc(),
            negative_slope);
    auto relu_pd = eltwise_forward::primitive_desc(relu_desc, cpu_engine);

    /* create relu dst memory primitive */
    auto relu_dst_memory = memory(relu_pd.dst_primitive_desc());

    /* finally create a relu primitive */
    auto relu = eltwise_forward(relu_pd, conv_dst_memory, relu_dst_memory);

    /* AlexNet: lrn
     * {batch, 96, 55, 55} -> {batch, 96, 55, 55}
     * local size: 5
     * alpha: 0.0001
     * beta: 0.75
     * k: 1.0
     */
    const uint32_t local_size = 5;
    const float alpha = 0.0001;
    const float beta = 0.75;
    const float k = 1.0;

    /* create a lrn primitive descriptor */
    auto lrn_desc = lrn_forward::desc(prop_kind::forward, lrn_across_channels,
                                      relu_pd.dst_primitive_desc().desc(),
                                      local_size, alpha, beta, k);
    auto lrn_pd = lrn_forward::primitive_desc(lrn_desc, cpu_engine);

    /* create lrn dst memory */
    auto lrn_dst_memory = memory(lrn_pd.dst_primitive_desc());

    /* create workspace only in training and only for the forward primitive */
    /* query lrn_pd for the workspace; this memory is shared between the
     * forward and backward lrn primitives */
    auto lrn_workspace_memory = memory(lrn_pd.workspace_primitive_desc());

    /* finally create a lrn primitive */
    auto lrn = lrn_forward(lrn_pd, relu_dst_memory, lrn_workspace_memory,
                           lrn_dst_memory);

    /* AlexNet: pool
     * {batch, 96, 55, 55} -> {batch, 96, 27, 27}
     * kernel: {3, 3}
     * strides: {2, 2}
     */
    memory::dims pool_dst_tz = { batch, 96, 27, 27 };
    memory::dims pool_kernel = { 3, 3 };
    memory::dims pool_strides = { 2, 2 };
    memory::dims pool_padding = { 0, 0 };
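    /* output spatial size: (55 - 3 + 2 * 0) / 2 + 1 = 27, matching pool_dst_tz */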

    /* create memory for pool dst data in user format */
    auto pool_user_dst_memory = memory(
            { { { pool_dst_tz }, memory::data_type::f32, memory::format::nchw },
              cpu_engine },
            net_dst.data());

    /* create pool dst memory descriptor in format any */
    auto pool_dst_md = memory::desc({ pool_dst_tz }, memory::data_type::f32,
                                    memory::format::any);

    /* create a pooling primitive descriptor */
    auto pool_desc = pooling_forward::desc(
            prop_kind::forward, pooling_max,
            lrn_dst_memory.get_primitive_desc().desc(), pool_dst_md,
            pool_strides, pool_kernel, pool_padding, pool_padding,
            padding_kind::zero);
    auto pool_pd = pooling_forward::primitive_desc(pool_desc, cpu_engine);

    /* create reorder primitive between pool dst and user dst format
     * if needed */
    auto pool_dst_memory = pool_user_dst_memory;
    bool reorder_pool_dst = false;
    primitive pool_reorder_dst;
    if (memory::primitive_desc(pool_pd.dst_primitive_desc())
        != pool_user_dst_memory.get_primitive_desc()) {
        pool_dst_memory = memory(pool_pd.dst_primitive_desc());
        pool_reorder_dst = reorder(pool_dst_memory, pool_user_dst_memory);
        reorder_pool_dst = true;
    }

    /* create pooling workspace memory if training */
    auto pool_workspace_memory = memory(pool_pd.workspace_primitive_desc());
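    /* for max pooling the workspace keeps the positions of the maxima picked
     * in the forward pass, which pooling_backward uses to route gradients */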

    /* finally create a pooling primitive */
    auto pool = pooling_forward(pool_pd, lrn_dst_memory, pool_dst_memory,
                                pool_workspace_memory);

    /* build forward net */
    std::vector<primitive> net_fwd;
    if (reorder_conv_src)
        net_fwd.push_back(conv_reorder_src);
    if (reorder_conv_weights)
        net_fwd.push_back(conv_reorder_weights);
    net_fwd.push_back(conv);
    net_fwd.push_back(relu);
    net_fwd.push_back(lrn);
    net_fwd.push_back(pool);
    if (reorder_pool_dst)
        net_fwd.push_back(pool_reorder_dst);

    /*----------------------------------------------------------------------*/
    /*----------------- Backward Stream ------------------------------------*/
    /* ... user diff_data ... */
    std::vector<float> net_diff_dst(batch * 96 * 27 * 27);
    for (size_t i = 0; i < net_diff_dst.size(); ++i)
        net_diff_dst[i] = sinf((float)i);

    /* create memory for user diff dst data */
    auto pool_user_diff_dst_memory = memory(
            { { { pool_dst_tz }, memory::data_type::f32, memory::format::nchw },
              cpu_engine },
            net_diff_dst.data());

    /* Backward pooling */
    /* create memory descriptors for pooling */
    auto pool_diff_src_md = lrn_dst_memory.get_primitive_desc().desc();
    auto pool_diff_dst_md = pool_dst_memory.get_primitive_desc().desc();

    /* create backward pooling descriptor */
    auto pool_bwd_desc = pooling_backward::desc(
            pooling_max, pool_diff_src_md, pool_diff_dst_md, pool_strides,
            pool_kernel, pool_padding, pool_padding, padding_kind::zero);
    /* the backward primitive descriptor needs the forward primitive
     * descriptor as a hint */
    auto pool_bwd_pd = pooling_backward::primitive_desc(pool_bwd_desc,
                                                        cpu_engine, pool_pd);

    /* create reorder primitive between user diff dst and pool diff dst
     * if required */
    auto pool_diff_dst_memory = pool_user_diff_dst_memory;
    primitive pool_reorder_diff_dst;
    bool reorder_pool_diff_dst = false;
    if (memory::primitive_desc(pool_dst_memory.get_primitive_desc())
        != pool_user_diff_dst_memory.get_primitive_desc()) {
        pool_diff_dst_memory = memory(pool_dst_memory.get_primitive_desc());
        pool_reorder_diff_dst
                = reorder(pool_user_diff_dst_memory, pool_diff_dst_memory);
        reorder_pool_diff_dst = true;
    }

    /* create memory primitive for pool diff src */
    auto pool_diff_src_memory = memory(pool_bwd_pd.diff_src_primitive_desc());

    /* finally create backward pooling primitive */
    auto pool_bwd
            = pooling_backward(pool_bwd_pd, pool_diff_dst_memory,
                               pool_workspace_memory, pool_diff_src_memory);

    /* Backward lrn */
    auto lrn_diff_dst_md = lrn_dst_memory.get_primitive_desc().desc();

    /* create backward lrn primitive descriptor */
    auto lrn_bwd_desc = lrn_backward::desc(
            lrn_across_channels, lrn_pd.src_primitive_desc().desc(),
            lrn_diff_dst_md, local_size, alpha, beta, k);
    auto lrn_bwd_pd
            = lrn_backward::primitive_desc(lrn_bwd_desc, cpu_engine, lrn_pd);

    /* create memory for lrn diff src */
    auto lrn_diff_src_memory = memory(lrn_bwd_pd.diff_src_primitive_desc());

    /* finally create a lrn backward primitive */
    // backward lrn needs src: relu dst in this topology
    auto lrn_bwd
            = lrn_backward(lrn_bwd_pd, relu_dst_memory, pool_diff_src_memory,
                           lrn_workspace_memory, lrn_diff_src_memory);

    /* Backward relu */
    auto relu_diff_dst_md = lrn_diff_src_memory.get_primitive_desc().desc();
    auto relu_src_md = conv_pd.dst_primitive_desc().desc();

    /* create backward relu primitive descriptor */
    auto relu_bwd_desc = eltwise_backward::desc(algorithm::eltwise_relu,
            relu_diff_dst_md, relu_src_md, negative_slope);
    auto relu_bwd_pd = eltwise_backward::primitive_desc(relu_bwd_desc,
                                                        cpu_engine, relu_pd);

    /* create memory for relu diff src */
    auto relu_diff_src_memory = memory(relu_bwd_pd.diff_src_primitive_desc());

    /* finally create a backward relu primitive */
    auto relu_bwd = eltwise_backward(relu_bwd_pd, conv_dst_memory,
                                     lrn_diff_src_memory, relu_diff_src_memory);

    /* Backward convolution with respect to weights */
    /* create user format diff weights and diff bias memory */
    std::vector<float> conv_user_diff_weights_buffer(
            std::accumulate(conv_weights_tz.begin(), conv_weights_tz.end(), 1,
                            std::multiplies<uint32_t>()));
    std::vector<float> conv_diff_bias_buffer(
            std::accumulate(conv_bias_tz.begin(), conv_bias_tz.end(), 1,
                            std::multiplies<uint32_t>()));

    auto conv_user_diff_weights_memory
            = memory({ { { conv_weights_tz }, memory::data_type::f32,
                         memory::format::oihw },
                       cpu_engine },
                     conv_user_diff_weights_buffer.data());
    auto conv_diff_bias_memory = memory(
            { { { conv_bias_tz }, memory::data_type::f32, memory::format::x },
              cpu_engine },
            conv_diff_bias_buffer.data());

    /* create memory descriptors */

    auto conv_bwd_src_md = memory::desc({ conv_src_tz }, memory::data_type::f32,
                                        memory::format::any);
    auto conv_diff_bias_md = memory::desc(
            { conv_bias_tz }, memory::data_type::f32, memory::format::any);
    auto conv_diff_weights_md = memory::desc(
            { conv_weights_tz }, memory::data_type::f32, memory::format::any);
    auto conv_diff_dst_md = memory::desc(
            { conv_dst_tz }, memory::data_type::f32, memory::format::any);

    /* create backward convolution primitive descriptor */
    auto conv_bwd_weights_desc = convolution_backward_weights::desc(
            convolution_direct, conv_bwd_src_md, conv_diff_weights_md,
            conv_diff_bias_md, conv_diff_dst_md, conv_strides, conv_padding,
            conv_padding, padding_kind::zero);
    auto conv_bwd_weights_pd = convolution_backward_weights::primitive_desc(
            conv_bwd_weights_desc, cpu_engine, conv_pd);

    /* for best performance convolution backward might choose
     * different memory formats for src and diff_dst
     * than the memory formats preferred by forward convolution
     * for src and dst respectively */
    /* create reorder primitives for src from forward convolution to the
     * format chosen by backward convolution */
    auto conv_bwd_src_memory = conv_src_memory;
    primitive conv_bwd_reorder_src;
    bool reorder_conv_bwd_src = false;
    if (memory::primitive_desc(conv_bwd_weights_pd.src_primitive_desc())
        != conv_src_memory.get_primitive_desc()) {
        conv_bwd_src_memory = memory(conv_bwd_weights_pd.src_primitive_desc());
        conv_bwd_reorder_src = reorder(conv_src_memory, conv_bwd_src_memory);
        reorder_conv_bwd_src = true;
    }

    /* create reorder primitive for diff_dst to convert the diff_src produced
     * by relu_bwd to the format preferred by backward convolution */
    auto conv_diff_dst_memory = relu_diff_src_memory;
    primitive conv_reorder_diff_dst;
    bool reorder_conv_diff_dst = false;
    if (memory::primitive_desc(conv_bwd_weights_pd.diff_dst_primitive_desc())
        != relu_diff_src_memory.get_primitive_desc()) {
        conv_diff_dst_memory
                = memory(conv_bwd_weights_pd.diff_dst_primitive_desc());
        conv_reorder_diff_dst
                = reorder(relu_diff_src_memory, conv_diff_dst_memory);
        reorder_conv_diff_dst = true;
    }

    /* create reorder primitive between conv diff weights and user diff weights
     * if needed */
    auto conv_diff_weights_memory = conv_user_diff_weights_memory;
    primitive conv_reorder_diff_weights;
    bool reorder_conv_diff_weights = false;
    if (memory::primitive_desc(
                conv_bwd_weights_pd.diff_weights_primitive_desc())
        != conv_user_diff_weights_memory.get_primitive_desc()) {
        conv_diff_weights_memory
                = memory(conv_bwd_weights_pd.diff_weights_primitive_desc());
        conv_reorder_diff_weights = reorder(conv_diff_weights_memory,
                                            conv_user_diff_weights_memory);
        reorder_conv_diff_weights = true;
    }

    /* finally create backward convolution primitive */
    auto conv_bwd_weights = convolution_backward_weights(
            conv_bwd_weights_pd, conv_bwd_src_memory, conv_diff_dst_memory,
            conv_diff_weights_memory, conv_diff_bias_memory);

    /* build backward propagation net */
    std::vector<primitive> net_bwd;
    if (reorder_pool_diff_dst)
        net_bwd.push_back(pool_reorder_diff_dst);
    net_bwd.push_back(pool_bwd);
    net_bwd.push_back(lrn_bwd);
    net_bwd.push_back(relu_bwd);
    if (reorder_conv_bwd_src)
        net_bwd.push_back(conv_bwd_reorder_src);
    if (reorder_conv_diff_dst)
        net_bwd.push_back(conv_reorder_diff_dst);
    net_bwd.push_back(conv_bwd_weights);
    if (reorder_conv_diff_weights)
        net_bwd.push_back(conv_reorder_diff_weights);

    int n_iter = 1; // number of iterations for training
    /* execute */
    while (n_iter) {
        /* forward */
        stream(stream::kind::eager).submit(net_fwd).wait();

        /* update net_diff_dst */
        // auto net_output = pool_user_dst_memory.get_data_handle();
        /* ...user updates net_diff_dst using net_output... */
        // some user defined func update_diff_dst(net_diff_dst.data(),
        // net_output)
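        // a minimal sketch of such an update, assuming a squared loss against
        // a hypothetical target buffer `net_target` (not part of this example):
        //
        //     const float *out = static_cast<const float *>(
        //             pool_user_dst_memory.get_data_handle());
        //     for (size_t i = 0; i < net_diff_dst.size(); ++i)
        //         net_diff_dst[i] = out[i] - net_target[i];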

        stream(stream::kind::eager).submit(net_bwd).wait();
        /* update weights and bias using diff weights and bias */
        // auto net_diff_weights
        //     = conv_user_diff_weights_memory.get_data_handle();
        // auto net_diff_bias = conv_diff_bias_memory.get_data_handle();
        /* ...user updates weights and bias using diff weights and bias... */
        // some user defined func update_weights(conv_weights.data(),
        // conv_bias.data(), net_diff_weights, net_diff_bias);
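        // a minimal sketch of such an update, assuming plain SGD with a
        // hypothetical learning rate `lr` (not part of this example):
        //
        //     const float lr = 0.01f;
        //     const float *dw = static_cast<const float *>(
        //             conv_user_diff_weights_memory.get_data_handle());
        //     const float *db = static_cast<const float *>(
        //             conv_diff_bias_memory.get_data_handle());
        //     for (size_t i = 0; i < conv_weights.size(); ++i)
        //         conv_weights[i] -= lr * dw[i];
        //     for (size_t i = 0; i < conv_bias.size(); ++i)
        //         conv_bias[i] -= lr * db[i];
        //
        // the forward net re-reads the user weights buffer (directly or via
        // the weights reorder) on the next iteration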

        --n_iter;
    }
}

int main(int argc, char **argv)
{
    try
    {
        simple_net();
        std::cout << "passed" << std::endl;
    }
    catch (error &e)
    {
        std::cerr << "status: " << e.status << std::endl;
        std::cerr << "message: " << e.message << std::endl;
    }
    return 0;
}