1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
23 using namespace mkldnn;
27 auto cpu_engine = engine(engine::cpu, 0);
31 std::vector<float> net_src(batch * 3 * 227 * 227);
32 std::vector<float> net_dst(batch * 96 * 27 * 27);
34 /* initializing non-zero values for src */
35 for (size_t i = 0; i < net_src.size(); ++i)
36 net_src[i] = sinf((float)i);
39 * {batch, 3, 227, 227} (x) {96, 3, 11, 11} -> {batch, 96, 55, 55}
42 memory::dims conv_src_tz = { batch, 3, 227, 227 };
43 memory::dims conv_weights_tz = { 96, 3, 11, 11 };
44 memory::dims conv_bias_tz = { 96 };
45 memory::dims conv_dst_tz = { batch, 96, 55, 55 };
46 memory::dims conv_strides = { 4, 4 };
47 auto conv_padding = { 0, 0 };
49 std::vector<float> conv_weights(
50 std::accumulate(conv_weights_tz.begin(), conv_weights_tz.end(), 1,
51 std::multiplies<uint32_t>()));
52 std::vector<float> conv_bias(std::accumulate(conv_bias_tz.begin(),
53 conv_bias_tz.end(), 1,
54 std::multiplies<uint32_t>()));
56 /* initializing non-zero values for weights and bias */
57 for (size_t i = 0; i < conv_weights.size(); ++i)
58 conv_weights[i] = sinf((float)i);
59 for (size_t i = 0; i < conv_bias.size(); ++i)
60 conv_bias[i] = sinf((float)i);
62 /* create memory for user data */
63 auto conv_user_src_memory = memory(
64 { { { conv_src_tz }, memory::data_type::f32, memory::format::nchw },
67 auto conv_user_weights_memory
68 = memory({ { { conv_weights_tz }, memory::data_type::f32,
69 memory::format::oihw },
72 auto conv_user_bias_memory = memory(
73 { { { conv_bias_tz }, memory::data_type::f32, memory::format::x },
77 /* create memory descriptors for convolution data w/ no specified
79 * format `any` lets a primitive (convolution in this case)
80 * choose the memory format preferred for best performance. */
81 auto conv_src_md = memory::desc({ conv_src_tz }, memory::data_type::f32,
83 auto conv_bias_md = memory::desc({ conv_bias_tz }, memory::data_type::f32,
85 auto conv_weights_md = memory::desc(
86 { conv_weights_tz }, memory::data_type::f32, memory::format::any);
87 auto conv_dst_md = memory::desc({ conv_dst_tz }, memory::data_type::f32,
90 /* create a convolution primitive descriptor */
91 auto conv_desc = convolution_forward::desc(
92 prop_kind::forward, convolution_direct, conv_src_md,
93 conv_weights_md, conv_bias_md, conv_dst_md, conv_strides,
94 conv_padding, conv_padding, padding_kind::zero);
95 auto conv_pd = convolution_forward::primitive_desc(conv_desc, cpu_engine);
97 /* create reorder primitives between user input and conv src if needed */
98 auto conv_src_memory = conv_user_src_memory;
99 bool reorder_conv_src = false;
100 primitive conv_reorder_src;
101 if (memory::primitive_desc(conv_pd.src_primitive_desc())
102 != conv_user_src_memory.get_primitive_desc()) {
103 conv_src_memory = memory(conv_pd.src_primitive_desc());
104 conv_reorder_src = reorder(conv_user_src_memory, conv_src_memory);
105 reorder_conv_src = true;
108 auto conv_weights_memory = conv_user_weights_memory;
109 bool reorder_conv_weights = false;
110 primitive conv_reorder_weights;
111 if (memory::primitive_desc(conv_pd.weights_primitive_desc())
112 != conv_user_weights_memory.get_primitive_desc()) {
113 conv_weights_memory = memory(conv_pd.weights_primitive_desc());
115 = reorder(conv_user_weights_memory, conv_weights_memory);
116 reorder_conv_weights = true;
119 /* create memory primitive for conv dst */
120 auto conv_dst_memory = memory(conv_pd.dst_primitive_desc());
122 /* finally create a convolution primitive */
124 = convolution_forward(conv_pd, conv_src_memory, conv_weights_memory,
125 conv_user_bias_memory, conv_dst_memory);
128 * {batch, 96, 55, 55} -> {batch, 96, 55, 55}
130 const float negative_slope = 1.0;
132 /* create relu primitive desc */
133 /* keep memory format of source same as the format of convolution
134 * output in order to avoid reorder */
135 auto relu_desc = eltwise_forward::desc(prop_kind::forward,
136 algorithm::eltwise_relu, conv_pd.dst_primitive_desc().desc(),
138 auto relu_pd = eltwise_forward::primitive_desc(relu_desc, cpu_engine);
140 /* create relu dst memory primitive */
141 auto relu_dst_memory = memory(relu_pd.dst_primitive_desc());
143 /* finally create a relu primitive */
144 auto relu = eltwise_forward(relu_pd, conv_dst_memory, relu_dst_memory);
147 * {batch, 96, 55, 55} -> {batch, 96, 55, 55}
153 const uint32_t local_size = 5;
154 const float alpha = 0.0001;
155 const float beta = 0.75;
158 /* create a lrn primitive descriptor */
159 auto lrn_desc = lrn_forward::desc(prop_kind::forward, lrn_across_channels,
160 relu_pd.dst_primitive_desc().desc(),
161 local_size, alpha, beta, k);
162 auto lrn_pd = lrn_forward::primitive_desc(lrn_desc, cpu_engine);
164 /* create lrn dst memory */
165 auto lrn_dst_memory = memory(lrn_pd.dst_primitive_desc());
167 /* create workspace only in training and only for forward primitive*/
168 /* query lrn_pd for workspace, this memory will be shared with forward lrn*/
169 auto lrn_workspace_memory = memory(lrn_pd.workspace_primitive_desc());
171 /* finally create a lrn primitive */
172 auto lrn = lrn_forward(lrn_pd, relu_dst_memory, lrn_workspace_memory,
176 * {batch, 96, 55, 55} -> {batch, 96, 27, 27}
180 memory::dims pool_dst_tz = { batch, 96, 27, 27 };
181 memory::dims pool_kernel = { 3, 3 };
182 memory::dims pool_strides = { 2, 2 };
183 auto pool_padding = { 0, 0 };
185 /* create memory for pool dst data in user format */
186 auto pool_user_dst_memory = memory(
187 { { { pool_dst_tz }, memory::data_type::f32, memory::format::nchw },
191 /* create pool dst memory descriptor in format any */
192 auto pool_dst_md = memory::desc({ pool_dst_tz }, memory::data_type::f32,
193 memory::format::any);
195 /* create a pooling primitive descriptor */
196 auto pool_desc = pooling_forward::desc(
197 prop_kind::forward, pooling_max,
198 lrn_dst_memory.get_primitive_desc().desc(), pool_dst_md,
199 pool_strides, pool_kernel, pool_padding, pool_padding,
201 auto pool_pd = pooling_forward::primitive_desc(pool_desc, cpu_engine);
203 /* create reorder primitive between pool dst and user dst format
205 auto pool_dst_memory = pool_user_dst_memory;
206 bool reorder_pool_dst = false;
207 primitive pool_reorder_dst;
208 if (memory::primitive_desc(pool_pd.dst_primitive_desc())
209 != pool_user_dst_memory.get_primitive_desc()) {
210 pool_dst_memory = memory(pool_pd.dst_primitive_desc());
211 pool_reorder_dst = reorder(pool_dst_memory, pool_user_dst_memory);
212 reorder_pool_dst = true;
215 /* create pooling workspace memory if training */
216 auto pool_workspace_memory = memory(pool_pd.workspace_primitive_desc());
218 /* finally create a pooling primitive */
219 auto pool = pooling_forward(pool_pd, lrn_dst_memory, pool_dst_memory,
220 pool_workspace_memory);
222 /* build forward net */
223 std::vector<primitive> net_fwd;
224 if (reorder_conv_src)
225 net_fwd.push_back(conv_reorder_src);
226 if (reorder_conv_weights)
227 net_fwd.push_back(conv_reorder_weights);
228 net_fwd.push_back(conv);
229 net_fwd.push_back(relu);
230 net_fwd.push_back(lrn);
231 net_fwd.push_back(pool);
232 if (reorder_pool_dst)
233 net_fwd.push_back(pool_reorder_dst);
235 /*----------------------------------------------------------------------*/
236 /*----------------- Backward Stream -------------------------------------*/
237 /* ... user diff_data ...*/
238 std::vector<float> net_diff_dst(batch * 96 * 27 * 27);
239 for (size_t i = 0; i < net_diff_dst.size(); ++i)
240 net_diff_dst[i] = sinf((float)i);
242 /* create memory for user diff dst data */
243 auto pool_user_diff_dst_memory = memory(
244 { { { pool_dst_tz }, memory::data_type::f32, memory::format::nchw },
246 net_diff_dst.data());
248 /* Backward pooling */
249 /* create memory descriptors for pooling */
250 auto pool_diff_src_md = lrn_dst_memory.get_primitive_desc().desc();
251 auto pool_diff_dst_md = pool_dst_memory.get_primitive_desc().desc();
253 /* create backward pooling descriptor*/
254 auto pool_bwd_desc = pooling_backward::desc(
255 pooling_max, pool_diff_src_md, pool_diff_dst_md, pool_strides,
256 pool_kernel, pool_padding, pool_padding, padding_kind::zero);
257 /* backward primitive descriptor needs to hint forward descriptor */
258 auto pool_bwd_pd = pooling_backward::primitive_desc(pool_bwd_desc,
259 cpu_engine, pool_pd);
261 /* create reorder primitive between user diff dst and pool diff dst
263 auto pool_diff_dst_memory = pool_user_diff_dst_memory;
264 primitive pool_reorder_diff_dst;
265 bool reorder_pool_diff_dst = false;
266 if (memory::primitive_desc(pool_dst_memory.get_primitive_desc())
267 != pool_user_diff_dst_memory.get_primitive_desc()) {
268 pool_diff_dst_memory = memory(pool_dst_memory.get_primitive_desc());
269 pool_reorder_diff_dst
270 = reorder(pool_user_diff_dst_memory, pool_diff_dst_memory);
271 reorder_pool_diff_dst = true;
274 /* create memory primitive for pool diff src */
275 auto pool_diff_src_memory = memory(pool_bwd_pd.diff_src_primitive_desc());
277 /* finally create backward pooling primitive */
279 = pooling_backward(pool_bwd_pd, pool_diff_dst_memory,
280 pool_workspace_memory, pool_diff_src_memory);
283 auto lrn_diff_dst_md = lrn_dst_memory.get_primitive_desc().desc();
285 /* create backward lrn primitive descriptor */
286 auto lrn_bwd_desc = lrn_backward::desc(
287 lrn_across_channels, lrn_pd.src_primitive_desc().desc(),
288 lrn_diff_dst_md, local_size, alpha, beta, k);
290 = lrn_backward::primitive_desc(lrn_bwd_desc, cpu_engine, lrn_pd);
292 /* create memory for lrn diff src */
293 auto lrn_diff_src_memory = memory(lrn_bwd_pd.diff_src_primitive_desc());
295 /* finally create a lrn backward primitive */
296 // backward lrn needs src: relu dst in this topology
298 = lrn_backward(lrn_bwd_pd, relu_dst_memory, pool_diff_src_memory,
299 lrn_workspace_memory, lrn_diff_src_memory);
302 auto relu_diff_dst_md = lrn_diff_src_memory.get_primitive_desc().desc();
303 auto relu_src_md = conv_pd.dst_primitive_desc().desc();
305 /* create backward relu primitive_descriptor */
306 auto relu_bwd_desc = eltwise_backward::desc(algorithm::eltwise_relu,
307 relu_diff_dst_md, relu_src_md, negative_slope);
309 = eltwise_backward::primitive_desc(relu_bwd_desc, cpu_engine, relu_pd);
311 /* create memory for relu diff src */
312 auto relu_diff_src_memory = memory(relu_bwd_pd.diff_src_primitive_desc());
314 /* finally create a backward relu primitive */
315 auto relu_bwd = eltwise_backward(relu_bwd_pd, conv_dst_memory,
316 lrn_diff_src_memory, relu_diff_src_memory);
318 /* Backward convolution with respect to weights */
319 /* create user format diff weights and diff bias memory */
320 std::vector<float> conv_user_diff_weights_buffer(
321 std::accumulate(conv_weights_tz.begin(), conv_weights_tz.end(), 1,
322 std::multiplies<uint32_t>()));
323 std::vector<float> conv_diff_bias_buffer(
324 std::accumulate(conv_bias_tz.begin(), conv_bias_tz.end(), 1,
325 std::multiplies<uint32_t>()));
327 auto conv_user_diff_weights_memory
328 = memory({ { { conv_weights_tz }, memory::data_type::f32,
329 memory::format::nchw },
331 conv_user_diff_weights_buffer.data());
332 auto conv_diff_bias_memory = memory(
333 { { { conv_bias_tz }, memory::data_type::f32, memory::format::x },
335 conv_diff_bias_buffer.data());
337 /* create memory primitives descriptors */
339 auto conv_bwd_src_md = memory::desc({ conv_src_tz }, memory::data_type::f32,
340 memory::format::any);
341 auto conv_diff_bias_md = memory::desc(
342 { conv_bias_tz }, memory::data_type::f32, memory::format::any);
343 auto conv_diff_weights_md = memory::desc(
344 { conv_weights_tz }, memory::data_type::f32, memory::format::any);
345 auto conv_diff_dst_md = memory::desc(
346 { conv_dst_tz }, memory::data_type::f32, memory::format::any);
348 /* create backward convolution primitive descriptor */
349 auto conv_bwd_weights_desc = convolution_backward_weights::desc(
350 convolution_direct, conv_bwd_src_md, conv_diff_weights_md,
351 conv_diff_bias_md, conv_diff_dst_md, conv_strides, conv_padding,
352 conv_padding, padding_kind::zero);
353 auto conv_bwd_weights_pd = convolution_backward_weights::primitive_desc(
354 conv_bwd_weights_desc, cpu_engine, conv_pd);
356 /* for best performance, backward convolution might choose
357 * different memory formats for src and diff_dst
358 * than the memory formats preferred by forward convolution
359 * for src and dst respectively */
360 /* create reorder primitives for src from forward convolution to the
361 * format chosen by backward convolution */
362 auto conv_bwd_src_memory = conv_src_memory;
363 primitive conv_bwd_reorder_src;
364 auto reorder_conv_bwd_src = false;
365 if (memory::primitive_desc(conv_bwd_weights_pd.src_primitive_desc())
366 != conv_src_memory.get_primitive_desc())
368 conv_bwd_src_memory = memory(conv_bwd_weights_pd.src_primitive_desc());
369 conv_bwd_reorder_src = reorder(conv_src_memory, conv_bwd_src_memory);
370 reorder_conv_bwd_src = true;
373 /* create reorder primitives for diff_dst between diff_src from relu_bwd
374 * and format preferred by conv_diff_weights */
375 auto conv_diff_dst_memory = relu_diff_src_memory;
376 primitive conv_reorder_diff_dst;
377 auto reorder_conv_diff_dst = false;
378 if (memory::primitive_desc(conv_bwd_weights_pd.diff_dst_primitive_desc())
379 != relu_diff_src_memory.get_primitive_desc())
382 = memory(conv_bwd_weights_pd.diff_dst_primitive_desc());
383 conv_reorder_diff_dst
384 = reorder(relu_diff_src_memory, conv_diff_dst_memory);
385 reorder_conv_diff_dst = true;
388 /* create reorder primitives between conv diff weights and user diff weights
390 auto conv_diff_weights_memory = conv_user_diff_weights_memory;
391 primitive conv_reorder_diff_weights;
392 bool reorder_conv_diff_weights = false;
393 if (memory::primitive_desc(
394 conv_bwd_weights_pd.diff_weights_primitive_desc())
395 != conv_user_diff_weights_memory.get_primitive_desc()) {
396 conv_diff_weights_memory
397 = memory(conv_bwd_weights_pd.diff_weights_primitive_desc());
398 conv_reorder_diff_weights = reorder(conv_diff_weights_memory,
399 conv_user_diff_weights_memory);
400 reorder_conv_diff_weights = true;
403 /* finally create backward convolution primitive */
404 auto conv_bwd_weights = convolution_backward_weights(
405 conv_bwd_weights_pd, conv_bwd_src_memory, conv_diff_dst_memory,
406 conv_diff_weights_memory, conv_diff_bias_memory);
408 /* build backward propagation net */
409 std::vector<primitive> net_bwd;
410 if (reorder_pool_diff_dst)
411 net_bwd.push_back(pool_reorder_diff_dst);
412 net_bwd.push_back(pool_bwd);
413 net_bwd.push_back(lrn_bwd);
414 net_bwd.push_back(relu_bwd);
415 if (reorder_conv_bwd_src)
416 net_bwd.push_back(conv_bwd_reorder_src);
417 if (reorder_conv_diff_dst)
418 net_bwd.push_back(conv_reorder_diff_dst);
419 net_bwd.push_back(conv_bwd_weights);
420 if (reorder_conv_diff_weights)
421 net_bwd.push_back(conv_reorder_diff_weights);
423 int n_iter = 1; //number of iterations for training
427 stream(stream::kind::eager).submit(net_fwd).wait();
429 /* update net_diff_dst */
430 // auto net_output = pool_user_dst_memory.get_data_handle();
431 /*..user updates net_diff_dst using net_output...*/
432 // some user defined func update_diff_dst(net_diff_dst.data(),
435 stream(stream::kind::eager).submit(net_bwd).wait();
436 /* update weights and bias using diff weights and bias*/
437 // auto net_diff_weights
438 // = conv_user_diff_weights_memory.get_data_handle();
439 // auto net_diff_bias = conv_diff_bias_memory.get_data_handle();
440 /* ...user updates weights and bias using diff weights and bias...*/
441 // some user defined func update_weights(conv_weights.data(),
442 // conv_bias.data(), net_diff_weights, net_diff_bias);
448 int main(int argc, char **argv)
453 std::cout << "passed" << std::endl;
457 std::cerr << "status: " << e.status << std::endl;
458 std::cerr << "message: " << e.message << std::endl;