Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / gtests / test_binary_convolution_forward_common.hpp
1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef TEST_BINARY_CONVOLUTION_FORWARD_COMMON_HPP
18 #define TEST_BINARY_CONVOLUTION_FORWARD_COMMON_HPP
19
20 #include "mkldnn_test_common.hpp"
21 #include "gtest/gtest.h"
22 #include "math_utils.hpp"
23 #include "mkldnn.hpp"
24
25 using namespace mkldnn::impl::math;
26
27 namespace {
28
29 }
30
31 namespace mkldnn {
32
33 void compute_ref_bin_conv_fwd(const test_binary_convolution_params_t &p,
34         const memory::desc &src_d,
35         const memory::desc &weights_d,
36         const memory::desc &dst_d,
37         const memory &src,
38         const memory &weights,
39         const memory &dst,
40         const memory &depthwise_weights,
41         const memory &depthwise_bias)
42 {
43     auto c = p.sizes;
44
45     uint8_t* src_data = (uint8_t*)src.get_data_handle();
46     uint8_t* weights_data = (uint8_t*)weights.get_data_handle();
47     float* dst_data = (float*)dst.get_data_handle();
48
49     float *d_weights_data = (float *)depthwise_weights.get_data_handle();
50     float *d_bias_data = (float *)depthwise_bias.get_data_handle();
51
52     int nbits = 8;
53
54     size_t padded_ic = src_d.data.layout_desc.blocking.padding_dims[1];
55     size_t padded_ic_w = weights_d.data.layout_desc.blocking.padding_dims[1];
56     size_t padded_oc_w = weights_d.data.layout_desc.blocking.padding_dims[0];
57
58     auto extract_bit = [](uint8_t val, uint8_t bit) -> uint8_t {
59         return (uint8_t) ((val >> bit) & 0x0001);
60     };
61
62     mkldnn::impl::parallel_nd(c.mb, c.ng, c.oc / c.ng, c.oh, c.ow,
63         [&](int n, int g, int oc, int oh, int ow) {
64             int32_t a = 0;
65             int roi = 0;
66             for (int ic = 0; ic < c.ic; ic++) {
67                 for (int kh = 0; kh < c.kh; kh++) {
68                     for (int kw = 0; kw < c.kw; kw++) {
69                         int ih = oh * c.strh - c.padh + kh * (1 + c.dilh);
70                         int iw = ow * c.strw - c.padw + kw * (1 + c.dilw);
71
72                         size_t iidx = n * padded_ic * c.ih * c.iw
73                                       + g * padded_ic / c.ng * c.ih * c.iw
74                                       + ic * c.ih * c.iw + ih * c.iw + iw;
75                         iidx = map_index(src_d, iidx);
76
77                         uint8_t s;
78                         if (ih < 0 || ih >= c.ih || iw < 0 || iw >= c.iw) {
79                             if (p.pad_value == 0.0f) {
80                                 continue;
81                             } else {
82                                 s = p.pad_value == 1.0f ? (uint8_t)1 : (uint8_t)0;
83                             }
84                         } else {
85                              s = extract_bit(src_data[iidx/nbits], (uint8_t)(iidx % nbits));
86                         }
87
88                         size_t widx = g * padded_oc_w / c.ng * padded_ic_w
89                                       / c.ng * c.kh * c.kw
90                                       + oc * padded_ic_w / c.ng * c.kh * c.kw
91                                       + ic * c.kh * c.kw + kh * c.kw + kw;
92                         widx = map_index(weights_d, widx);
93
94                         uint8_t w = extract_bit(weights_data[widx/nbits], (uint8_t)(widx % nbits));
95
96                         a += (int32_t)(s ^ w);
97
98                         roi++;
99                     }
100                 }
101             }
102
103             float a_fp = (float)(roi - 2*a);
104
105             size_t oidx = n * c.oc * c.oh * c.ow +
106                           g * c.oc / c.ng * c.oh * c.ow +
107                           oc * c.oh * c.ow +
108                           oh * c.ow +
109                           ow;
110
111             if (p.with_sum)
112                 a_fp += dst_data[map_index(dst_d, oidx)];
113
114             switch (p.eltwise_algorithm) {
115                 case algorithm_undef:
116                     break;
117                 case eltwise_relu:
118                     a_fp = relu_fwd(a_fp, p.eltwise_alpha);
119                     break;
120                 case eltwise_tanh:
121                     a_fp = tanh_fwd(a_fp);
122                     break;
123                 case eltwise_elu:
124                     a_fp = elu_fwd(a_fp, p.eltwise_alpha);
125                     break;
126                 case eltwise_square:
127                     a_fp = square_fwd(a_fp);
128                     break;
129                 case eltwise_abs:
130                     a_fp = abs_fwd(a_fp);
131                     break;
132                 case eltwise_sqrt:
133                     a_fp = sqrt_fwd(a_fp);
134                     break;
135                 case eltwise_linear:
136                     a_fp = linear_fwd(a_fp, p.eltwise_alpha, p.eltwise_beta);
137                     break;
138                 case eltwise_bounded_relu:
139                     a_fp = bounded_relu_fwd(a_fp, p.eltwise_alpha);
140                     break;
141                 case eltwise_soft_relu:
142                     a_fp = soft_relu_fwd(a_fp);
143                     break;
144                 case eltwise_logistic:
145                     a_fp = logistic_fwd(a_fp);
146                     break;
147                 case eltwise_clamp:
148                     a_fp = clamp_fwd(a_fp, p.eltwise_alpha, p.eltwise_beta);
149                     break;
150                 default:
151                     assert(!"unknown alg_kind");
152             }
153
154             switch (p.depthwise_algorithm) {
155                 case algorithm_undef:
156                     break;
157                 case depthwise_scale_shift:
158                     a_fp = scale_shift_fwd(a_fp, d_weights_data[g * c.oc / c.ng + oc], d_bias_data[g * c.oc / c.ng + oc]);
159                     break;
160                 case depthwise_prelu:
161                     a_fp = prelu_fwd(a_fp, d_weights_data[g * c.oc / c.ng + oc]);
162                     break;
163                 default: assert(!"unknown alg_kind");
164             }
165
166             dst_data[map_index(dst_d, oidx)] = a_fp;
167         }
168     );
169 }
170
171 void compute_ref_binarization_fwd(const test_binary_convolution_params_t &p,
172     const memory::desc &src_md, const memory &src, const memory &weights, const memory &dst) {
173     auto src_data = (float*)src.get_data_handle();
174     auto weights_data = (float*)weights.get_data_handle();
175     auto dst_data = (uint8_t*)dst.get_data_handle();
176
177     const memory::desc src_d = src.get_primitive_desc().desc();
178     const memory::desc weights_d = weights.get_primitive_desc().desc();
179     const memory::desc dst_d = dst.get_primitive_desc().desc();
180
181     int N = src_md.data.ndims > 0 ? src_md.data.dims[0] : 1;
182     int C = src_md.data.ndims > 1 ? src_md.data.dims[1] : 1;
183     int H = src_md.data.ndims > 2 ? src_md.data.dims[2] : 1;
184     int W = src_md.data.ndims > 3 ? src_md.data.dims[3] : 1;
185
186     int nbits = 8;
187     int CB = div_up(C, nbits);
188
189     int padded_ic = src_d.data.layout_desc.blocking.padding_dims[1];
190     int padded_oc = dst_d.data.layout_desc.blocking.padding_dims[1];
191
192     for (int n = 0; n < N; ++n) {
193         for (int cb = 0; cb < CB; ++cb) {
194             for (int h = 0; h < H; ++h) {
195                 for (int w = 0; w < W; ++w) {
196
197                     uint8_t bin_val = 0x00;
198                     for (int c = cb * nbits, shift = 0; c < std::min(C, (cb + 1) * nbits); c++, shift++) {
199                         int src_idx = n*padded_ic*H*W + c*H*W + h*W + w;
200                         int wei_idx = c;
201
202                         float s_val = src_data[map_index(src_d, src_idx)];
203                         float w_val = weights_data[map_index(weights_d, wei_idx)];
204
205                         auto bit = uint8_t((s_val > w_val) ? 0x01 : 0x00);
206                         bin_val |= (bit << shift);
207                     }
208
209                     int dst_idx = n*padded_oc*H*W + cb*nbits*H*W + h*W + w;
210                     dst_idx = map_index(dst_d, dst_idx);
211                     dst_data[dst_idx / nbits] = bin_val;
212                 }
213             }
214         }
215     }
216 }
217
218 class binary_convolution_forward_test : public ::testing::TestWithParam<test_binary_convolution_params_t>
219 {
220 protected:
221     virtual void SetUp()
222     {
223         test_binary_convolution_params_t p = ::testing::TestWithParam<test_binary_convolution_params_t>::GetParam();
224
225         ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
226         ASSERT_EQ(p.aalgorithm, algorithm::binary_convolution_direct);
227
228         test_convolution_sizes_t cd = p.sizes;
229
230         auto eng = engine(p.engine_kind, 0);
231         auto aprop_kind = prop_kind::forward;
232         bool with_binarization = p.binarization_algorithm != algorithm_undef;
233
234         memory::data_type data_type_src = memory::data_type::bin;
235         memory::data_type data_type_wei = memory::data_type::bin;
236         memory::data_type data_type_bia = memory::data_type::f32;
237         memory::data_type data_type_dst = with_binarization ? memory::data_type::bin
238                                                             : data_traits<float>::data_type;
239
240         auto c_src_desc = create_md({ cd.mb, cd.ic, cd.ih, cd.iw }, data_type_src, p.formats.src_format);
241         auto c_weights_desc = cd.ng > 1
242                 ? create_md({ cd.ng, cd.oc / cd.ng, cd.ic / cd.ng, cd.kh, cd.kw }, data_type_wei, p.formats.weights_format)
243                 : create_md({ cd.oc, cd.ic, cd.kh, cd.kw }, data_type_wei, p.formats.weights_format);
244         auto c_dst_desc = create_md({ cd.mb, cd.oc, cd.oh, cd.ow }, data_type_dst, p.formats.dst_format);
245
246         auto c_src = test_memory(c_src_desc, eng);
247         auto c_weights = test_memory(c_weights_desc, eng);
248         auto c_dst = test_memory(c_dst_desc, eng);
249
250         // Only true for dense format
251         if (with_binarization)
252             fill_data<uint8_t>(c_dst.get_size() / sizeof(uint8_t), (uint8_t*)c_dst.get().get_data_handle());
253         else
254             fill_data<float>(c_dst.get_size() / sizeof(float), (float*)c_dst.get().get_data_handle());
255         fill_data<uint8_t>(c_src.get_size() / sizeof(uint8_t), (uint8_t*)c_src.get().get_data_handle());
256         fill_data<uint8_t>(c_weights.get_size() / sizeof(uint8_t), (uint8_t*)c_weights.get().get_data_handle());
257
258         std::vector<ptrdiff_t> padR = {
259             right_padding(cd.ih, cd.oh, cd.kh, cd.padh, cd.strh, cd.dilh),
260             right_padding(cd.iw, cd.ow, cd.kw, cd.padw, cd.strw, cd.dilw)
261         };
262
263         auto bin_conv_desc = binary_convolution_forward::desc(aprop_kind, p.aalgorithm,
264                     c_src_desc, c_weights_desc, c_dst_desc,
265                     { cd.strh, cd.strw }, { cd.dilh, cd.dilw },
266                     { cd.padh, cd.padw }, padR, p.pad_value);
267
268         mkldnn::post_ops ops;
269
270         if (p.with_sum)
271             ops.append_sum();
272
273         if (p.eltwise_algorithm != algorithm_undef)
274             ops.append_eltwise(1.0, p.eltwise_algorithm, p.eltwise_alpha, p.eltwise_beta);
275
276         auto c_depthwise_weights_desc = create_md({ cd.oc }, data_type_bia, memory::x);
277         auto c_depthwise_bias_desc = create_md({ cd.oc }, data_type_bia, memory::x);
278
279         auto c_depthwise_weights = memory({c_depthwise_weights_desc, eng});
280         auto c_depthwise_bias = memory({c_depthwise_bias_desc, eng});
281
282         if (p.depthwise_algorithm != algorithm_undef) {
283             fill_data<float>(c_depthwise_weights.get_primitive_desc().get_size() / sizeof(float),
284                              (float *)c_depthwise_weights.get_data_handle(), 1., true);
285             fill_data<float>(c_depthwise_bias.get_primitive_desc().get_size() / sizeof(float),
286                              (float *)c_depthwise_bias.get_data_handle(), 1., true);
287
288             ops.append_depthwise(p.depthwise_algorithm, static_cast<const float*>(c_depthwise_weights.get_data_handle()),
289                                                         static_cast<const float*>(c_depthwise_bias.get_data_handle()));
290         }
291
292         auto c_binarization_weights_desc = create_md({ cd.oc }, memory::data_type::f32, memory::x);
293         auto c_binarization_weights = memory({c_binarization_weights_desc, eng});
294
295         if (p.binarization_algorithm != algorithm_undef) {
296             fill_data<float>(c_binarization_weights.get_primitive_desc().get_size() / sizeof(float),
297                              (float *)c_binarization_weights.get_data_handle(), 1., true);
298
299             ops.append_binarization(p.binarization_algorithm, static_cast<const float*>(c_binarization_weights.get_data_handle()));
300         }
301
302         mkldnn::primitive_attr attr;
303         attr.set_post_ops(ops);
304
305         auto bin_conv_primitive_desc = binary_convolution_forward::primitive_desc(bin_conv_desc, attr, eng);
306
307         auto bin_conv = binary_convolution_forward(bin_conv_primitive_desc, c_src.get(), c_weights.get(), c_dst.get());
308
309         if (with_binarization) {
310             auto c_dst_desc_ref = create_md({ cd.mb, cd.oc, cd.oh, cd.ow }, memory::data_type::f32, p.formats.dst_format);
311             auto c_dst_ref = test_memory(c_dst_desc_ref, eng);
312
313             std::vector<float> ref_dst_conv_data(c_dst_ref.get_size() / sizeof(float));
314             auto ref_conv_memory = memory(memory::primitive_desc(c_dst_desc_ref, eng), &ref_dst_conv_data[0]);
315
316             std::vector<uint8_t > ref_dst_data(c_dst.get_size() / sizeof(uint8_t));
317             auto ref_memory = memory(memory::primitive_desc(c_dst_desc, eng), &ref_dst_data[0]);
318
319             compute_ref_bin_conv_fwd(p, c_src_desc, c_weights_desc, c_dst_desc_ref,
320                                      c_src.get(), c_weights.get(), ref_conv_memory,
321                                      c_depthwise_weights, c_depthwise_bias);
322
323             compute_ref_binarization_fwd(p, c_dst_desc_ref, ref_conv_memory, c_binarization_weights, ref_memory);
324
325             std::vector<primitive> pipeline;
326             pipeline.push_back(bin_conv);
327             auto s = stream(stream::kind::lazy);
328             s.submit(pipeline).wait();
329
330             compare_data<uint8_t>(ref_memory, c_dst.get(), 0, true);
331         } else {
332             std::vector<float> ref_dst_data(c_dst.get_size() / sizeof(float));
333             memcpy(&ref_dst_data[0], (float*)c_dst.get().get_data_handle(), ref_dst_data.size() * sizeof(float));
334             auto ref_memory = memory(memory::primitive_desc(c_dst_desc, eng), &ref_dst_data[0]);
335
336             compute_ref_bin_conv_fwd(p, c_src_desc, c_weights_desc, c_dst_desc,
337                                      c_src.get(), c_weights.get(), ref_memory,
338                                      c_depthwise_weights, c_depthwise_bias);
339
340             std::vector<primitive> pipeline;
341             pipeline.push_back(bin_conv);
342             auto s = stream(stream::kind::lazy);
343             s.submit(pipeline).wait();
344
345             compare_data<float>(ref_memory, c_dst.get(), 1e-3);
346         }
347     }
348 };
349
350 }
351
352 #endif