1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef TEST_BINARY_CONVOLUTION_DW_CONV_FORWARD_COMMON_HPP
18 #define TEST_BINARY_CONVOLUTION_DW_CONV_FORWARD_COMMON_HPP
20 #include "mkldnn_test_common.hpp"
21 #include "gtest/gtest.h"
22 #include "math_utils.hpp"
25 using namespace mkldnn::impl::math;
// Reference (scalar) implementation of the binary convolution forward pass.
//
// For every output point (n, g, oc, oh, ow) the visible code walks the input channels
// and the kernel window, extracts one bit from the packed source buffer and one bit
// from the packed weights buffer, and accumulates the XOR of the two into `a`.
// The accumulator is then converted to float as `roi - 2*a`, passed through the
// eltwise and depthwise post-ops selected in `p`, and stored to the destination
// buffer at map_index(dst_d, oidx).
//
// Parameters (as used in the visible lines):
//   p                  - test parameters; supplies the eltwise/depthwise algorithm
//                        selectors and their alpha/beta values.
//   src_d, weights_d,
//   dst_d              - memory descriptors; read for dims, padded dims and map_index.
//   weights            - packed binary weights, read as uint8_t.
//   depthwise_weights,
//   depthwise_bias     - float arrays indexed by output channel in the depthwise post-op.
//
// NOTE(review): this view of the function looks incomplete. The body uses `src`, `dst`,
// `sizes`, `nbits`, `s`, `a` and `roi`, and the switch statements rely on case/break
// lines, none of which are visible here. They are presumably on lines missing from
// this excerpt - confirm against the full file before editing.
29 void compute_ref_bin_conv_fwd(const test_binary_convolution_dw_conv_params_t &p,
30 const memory::desc &src_d,
31 const memory::desc &weights_d,
32 const memory::desc &dst_d,
34 const memory &weights,
36 const memory &depthwise_weights,
37 const memory &depthwise_bias)
39 auto src_dims = src_d.data.dims;
40 auto dst_dims = dst_d.data.dims;
// Convolution geometry: batch and spatial sizes come from the src/dst descriptors,
// kernel/padding/stride come from the conv1_* test sizes. The second initializer is
// the literal 1 (presumably the group count read later as c.ng - TODO confirm the
// field order of test_convolution_sizes_t). No dilation values are listed here.
42 test_convolution_sizes_t c = {(int)src_dims[0], 1, sizes.ic, (int)src_dims[2], (int)src_dims[3],
43 (int)dst_dims[1], (int)dst_dims[2], (int)dst_dims[3],
44 sizes.conv1_kh, sizes.conv1_kw, sizes.conv1_padh, sizes.conv1_padw, sizes.conv1_strh, sizes.conv1_strw};
// Value assumed for positions that fall outside the input (see the bounds check below).
46 float pad_value = -1.f;
// Source and weights are accessed as packed bytes; the destination is float.
48 uint8_t* src_data = (uint8_t*)src.get_data_handle();
49 uint8_t* weights_data = (uint8_t*)weights.get_data_handle();
50 float* dst_data = (float*)dst.get_data_handle();
52 float *d_weights_data = (float *)depthwise_weights.get_data_handle();
53 float *d_bias_data = (float *)depthwise_bias.get_data_handle();
// Padded channel counts taken from the blocking descriptors; they replace the plain
// channel counts in the offset arithmetic below.
57 size_t padded_ic = src_d.data.layout_desc.blocking.padding_dims[1];
58 size_t padded_ic_w = weights_d.data.layout_desc.blocking.padding_dims[1];
59 size_t padded_oc_w = weights_d.data.layout_desc.blocking.padding_dims[0];
// Returns bit number `bit` of `val` as 0 or 1.
61 auto extract_bit = [](uint8_t val, uint8_t bit) -> uint8_t {
62 return (uint8_t) ((val >> bit) & 0x0001);
// One work item per output point.
65 mkldnn::impl::parallel_nd(c.mb, c.ng, c.oc / c.ng, c.oh, c.ow,
66 [&](int n, int g, int oc, int oh, int ow) {
69 for (int ic = 0; ic < c.ic; ic++) {
70 for (int kh = 0; kh < c.kh; kh++) {
71 for (int kw = 0; kw < c.kw; kw++) {
// Input coordinates hit by this kernel tap; they can be negative or past the
// input extent when the tap lands in the padding area.
72 int ih = oh * c.strh - c.padh + kh * (1 + c.dilh);
73 int iw = ow * c.strw - c.padw + kw * (1 + c.dilw);
// Element offset computed with the padded channel count, then passed through
// map_index with the source descriptor.
// NOTE(review): this offset is computed and mapped before the bounds check below,
// so for padded positions it is derived from out-of-range ih/iw - confirm that
// map_index tolerates such values.
75 size_t iidx = n * padded_ic * c.ih * c.iw
76 + g * padded_ic / c.ng * c.ih * c.iw
77 + ic * c.ih * c.iw + ih * c.iw + iw;
78 iidx = map_index(src_d, iidx);
// Out-of-bounds tap: the source bit is derived from pad_value instead of memory.
// The action taken when pad_value == 0.0f is on a line not visible here.
81 if (ih < 0 || ih >= c.ih || iw < 0 || iw >= c.iw) {
82 if (pad_value == 0.0f) {
// pad_value of exactly 1.0f maps to bit 1, anything else (including -1.f) to bit 0.
85 s = pad_value == 1.0f ? (uint8_t)1 : (uint8_t)0;
// In-bounds tap (presumably the else branch): iidx / nbits selects the byte and
// iidx % nbits selects the bit inside it.
88 s = extract_bit(src_data[iidx/nbits], (uint8_t)(iidx % nbits));
// Same addressing scheme for the weights bit, using the padded weight dims.
91 size_t widx = g * padded_oc_w / c.ng * padded_ic_w
93 + oc * padded_ic_w / c.ng * c.kh * c.kw
94 + ic * c.kh * c.kw + kh * c.kw + kw;
95 widx = map_index(weights_d, widx);
97 uint8_t w = extract_bit(weights_data[widx/nbits], (uint8_t)(widx % nbits));
// XOR is 1 when the two bits differ, so `a` counts mismatching taps.
99 a += (int32_t)(s ^ w);
// If `roi` is the number of accumulated taps (its declaration and update are not
// visible here - TODO confirm), roi - 2*a equals matches minus mismatches.
106 float a_fp = (float)(roi - 2*a);
// Logical destination offset; its remaining terms are on lines not visible here.
108 size_t oidx = n * c.oc * c.oh * c.ow +
109 g * c.oc / c.ng * c.oh * c.ow +
// Eltwise post-op chosen by p.eltwise_algorithm. Several case labels and the
// break statements are not visible in this excerpt.
114 switch (p.eltwise_algorithm) {
115 case algorithm_undef:
118 a_fp = relu_fwd(a_fp, p.eltwise_alpha);
121 a_fp = tanh_fwd(a_fp);
124 a_fp = elu_fwd(a_fp, p.eltwise_alpha);
127 a_fp = square_fwd(a_fp);
130 a_fp = abs_fwd(a_fp);
133 a_fp = sqrt_fwd(a_fp);
136 a_fp = linear_fwd(a_fp, p.eltwise_alpha, p.eltwise_beta);
138 case eltwise_bounded_relu:
139 a_fp = bounded_relu_fwd(a_fp, p.eltwise_alpha);
141 case eltwise_soft_relu:
142 a_fp = soft_relu_fwd(a_fp);
144 case eltwise_logistic:
145 a_fp = logistic_fwd(a_fp);
148 a_fp = clamp_fwd(a_fp, p.eltwise_alpha, p.eltwise_beta);
151 assert(!"unknown alg_kind");
// Depthwise post-op: scale_shift reads a weight and a bias at the output-channel
// index, prelu reads only the weight.
154 switch (p.depthwise_algorithm) {
155 case algorithm_undef:
157 case depthwise_scale_shift:
158 a_fp = scale_shift_fwd(a_fp, d_weights_data[g * c.oc / c.ng + oc], d_bias_data[g * c.oc / c.ng + oc]);
160 case depthwise_prelu:
161 a_fp = prelu_fwd(a_fp, d_weights_data[g * c.oc / c.ng + oc]);
163 default: assert(!"unknown alg_kind");
// Store the result at the offset map_index returns for the destination descriptor.
166 dst_data[map_index(dst_d, oidx)] = a_fp;
// Reference (scalar) implementation of the float convolution that consumes the
// binary convolution output in this test (sizes come from the conv2_* parameters).
//
// For every output point (n, g, oc, oh, ow) the visible code accumulates
// src * weights over the kernel window, skipping taps that fall outside the input,
// adds a bias value and the current destination value, applies the eltwise and
// depthwise post-ops selected in `p`, and writes the result back to `dst`.
//
// Parameters (as used in the visible lines):
//   p                  - test parameters: conv2_* sizes, eltwise/depthwise selectors.
//   src, weights, bias - float inputs; descriptors are taken from src and weights.
//   dst                - float output; it is also read before being overwritten.
//   depthwise_weights,
//   depthwise_bias     - float arrays indexed by output channel in the depthwise post-op.
//
// NOTE(review): `G`, `DW`, `DH`, `a`, `ih`, `iw` and several case/break/guard lines are
// used but not visible in this excerpt; they are presumably on lines missing from this
// view - confirm against the full file before editing.
171 void compute_ref_dw_conv_fwd(const test_binary_convolution_dw_conv_params_t &p,
172 const memory &src, const memory &weights, const memory &bias, const memory &dst,
173 const memory &depthwise_weights, const memory &depthwise_bias)
175 const memory::desc src_d = src.get_primitive_desc().desc();
176 const memory::desc weights_d = weights.get_primitive_desc().desc();
177 const memory::desc dst_d = dst.get_primitive_desc().desc();
179 auto src_dims = src_d.data.dims;
180 auto dst_dims = dst_d.data.dims;
// Tensor extents are read from the descriptors: dims[0..3] of src and dims[1..3] of dst.
182 int MB = src_dims[0];
184 int IC = src_dims[1];
185 int IH = src_dims[2];
186 int IW = src_dims[3];
187 int OC = dst_dims[1];
188 int OH = dst_dims[2];
189 int OW = dst_dims[3];
// Kernel, stride and padding come from the second convolution's test sizes.
191 int KH = p.sizes.conv2_kh;
192 int KW = p.sizes.conv2_kw;
193 int SH = p.sizes.conv2_strh;
194 int SW = p.sizes.conv2_strw;
195 int PH = p.sizes.conv2_padh;
196 int PW = p.sizes.conv2_padw;
200 float *src_data = (float *)src.get_data_handle();
201 float *weights_data = (float *)weights.get_data_handle();
202 float *bias_data = (float *)bias.get_data_handle();
203 float *dst_data = (float *)dst.get_data_handle();
205 float *d_weights_data = (float *)depthwise_weights.get_data_handle();
206 float *d_bias_data = (float *)depthwise_bias.get_data_handle();
// One work item per output point.
208 mkldnn::impl::parallel_nd(MB, G, OC / G, OH, OW,
209 [&](int n, int g, int oc, int oh, int ow) {
// Logical destination offset; it is passed through map_index before every access.
210 int oidx = n * OC * OH * OW
211 + g * OC / G * OH * OW
212 + oc * OH * OW + oh * OW + ow;
216 for (int ic = 0; ic < IC / G; ic++) {
217 for (int kh = 0; kh < KH; kh++) {
218 for (int kw = 0; kw < KW; kw++) {
// The next two lines are the tails of the iw / ih expressions; the lines that
// start those expressions are not visible in this excerpt.
220 - PW + kw * (1 + DW);
222 - PH + kh * (1 + DH);
// Taps that land outside the input are skipped and so add nothing to the sum.
223 if (iw < 0 || iw >= IW) continue;
224 if (ih < 0 || ih >= IH) continue;
// Logical source and weights offsets, both translated with map_index before use.
225 int iidx = n * IC * IH * IW
226 + g * IC / G * IH * IW
227 + ic * IH * IW + ih * IW + iw;
228 int widx = g * OC / G * IC
230 + oc * IC / G * KH * KW
231 + ic * KH * KW + kh * KW + kw;
233 iidx = map_index(src_d, iidx);
235 float s = src_data[iidx];
236 float w = weights_data[map_index(weights_d, widx)];
// `a` is the accumulator; its declaration and update are not visible here.
244 float a_fp = (float)a;
// Bias is indexed by the group index when G > 1, otherwise by the output channel.
246 a_fp += bias_data[G > 1 ? g : oc];
// Adds the value already stored in dst (sum post-op behaviour).
// NOTE(review): presumably guarded by a condition on a line not visible here - confirm.
249 a_fp += dst_data[map_index(dst_d, oidx)];
// Eltwise post-op chosen by p.eltwise_algorithm. Several case labels and the
// break statements are not visible in this excerpt.
251 switch (p.eltwise_algorithm) {
252 case algorithm_undef:
255 a_fp = relu_fwd(a_fp, p.eltwise_alpha);
258 a_fp = tanh_fwd(a_fp);
261 a_fp = elu_fwd(a_fp, p.eltwise_alpha);
264 a_fp = square_fwd(a_fp);
267 a_fp = abs_fwd(a_fp);
270 a_fp = sqrt_fwd(a_fp);
273 a_fp = linear_fwd(a_fp, p.eltwise_alpha, p.eltwise_beta);
275 case eltwise_bounded_relu:
276 a_fp = bounded_relu_fwd(a_fp, p.eltwise_alpha);
278 case eltwise_soft_relu:
279 a_fp = soft_relu_fwd(a_fp);
281 case eltwise_logistic:
282 a_fp = logistic_fwd(a_fp);
285 a_fp = clamp_fwd(a_fp, p.eltwise_alpha, p.eltwise_beta);
288 assert(!"unknown alg_kind");
// Depthwise post-op: scale_shift reads a weight and a bias at the output-channel
// index, prelu reads only the weight.
291 switch (p.depthwise_algorithm) {
292 case algorithm_undef:
294 case depthwise_scale_shift:
295 a_fp = scale_shift_fwd(a_fp, d_weights_data[g * OC / G + oc], d_bias_data[g * OC / G + oc]);
297 case depthwise_prelu:
298 a_fp = prelu_fwd(a_fp, d_weights_data[g * OC / G + oc]);
300 default: assert(!"unknown alg_kind");
// Overwrite the destination element with the final value.
303 dst_data[map_index(dst_d, oidx)] = (float)a_fp;
// Reference implementation of the binarization step: compares float source values
// against float threshold values and packs the one-bit results into bytes.
//
// For each (n, cb, h, w) the visible code visits up to `nbits` consecutive channels
// starting at cb * nbits, sets a bit when src > weights, and stores the packed byte
// at dst_data[dst_idx / nbits].
//
// Parameters (as used in the visible lines):
//   src_md  - descriptor the N/C/H/W extents are read from.
//   src     - float input values.
//   weights - float thresholds the input is compared against.
//   dst     - output buffer, written as uint8_t.
//
// NOTE(review): `nbits` and `wei_idx` are used but not declared in this excerpt; they
// are presumably defined on lines missing from this view - confirm against the full file.
308 void compute_ref_binarization_fwd(const test_binary_convolution_dw_conv_params_t &p,
309 const memory::desc &src_md, const memory &src, const memory &weights, const memory &dst) {
310 auto src_data = (float*)src.get_data_handle();
311 auto weights_data = (float*)weights.get_data_handle();
312 auto dst_data = (uint8_t*)dst.get_data_handle();
314 const memory::desc src_d = src.get_primitive_desc().desc();
315 const memory::desc weights_d = weights.get_primitive_desc().desc();
316 const memory::desc dst_d = dst.get_primitive_desc().desc();
// Extents default to 1 for any dimension the descriptor does not have.
318 int N = src_md.data.ndims > 0 ? src_md.data.dims[0] : 1;
319 int C = src_md.data.ndims > 1 ? src_md.data.dims[1] : 1;
320 int H = src_md.data.ndims > 2 ? src_md.data.dims[2] : 1;
321 int W = src_md.data.ndims > 3 ? src_md.data.dims[3] : 1;
// Number of nbits-wide channel groups (div_up presumably rounds the quotient up).
324 int CB = div_up(C, nbits);
// Padded channel counts from the blocking descriptors, used in the offsets below.
326 int padded_ic = src_d.data.layout_desc.blocking.padding_dims[1];
327 int padded_oc = dst_d.data.layout_desc.blocking.padding_dims[1];
329 for (int n = 0; n < N; ++n) {
330 for (int cb = 0; cb < CB; ++cb) {
331 for (int h = 0; h < H; ++h) {
332 for (int w = 0; w < W; ++w) {
// Byte being assembled for this (n, cb, h, w).
334 uint8_t bin_val = 0x00;
// Walk the channels of this group; the std::min bound stops at C so the last
// group may contribute fewer than nbits bits. `shift` is the bit position.
335 for (int c = cb * nbits, shift = 0; c < std::min(C, (cb + 1) * nbits); c++, shift++) {
336 int src_idx = n*padded_ic*H*W + c*H*W + h*W + w;
339 float s_val = src_data[map_index(src_d, src_idx)];
340 float w_val = weights_data[map_index(weights_d, wei_idx)];
// Strictly-greater comparison: equal values produce a 0 bit.
342 auto bit = uint8_t((s_val > w_val) ? 0x01 : 0x00);
343 bin_val |= (bit << shift);
// Offset of the group's first channel, mapped through the destination descriptor
// and divided by nbits to obtain the byte position.
346 int dst_idx = n*padded_oc*H*W + cb*nbits*H*W + h*W + w;
347 dst_idx = map_index(dst_d, dst_idx);
348 dst_data[dst_idx / nbits] = bin_val;
// Value-parameterized gtest fixture that runs a binary convolution primitive with a
// chain of post-ops (eltwise, depthwise, dw_conv, sum, eltwise, depthwise,
// binarization, as appended below) and compares its output against the reference
// functions compute_ref_bin_conv_fwd, compute_ref_dw_conv_fwd and
// compute_ref_binarization_fwd.
//
// NOTE(review): the member-function header, several if/else lines and all closing
// braces are not visible in this excerpt; comments below mark where a guard or
// branch is presumably missing from this view - confirm against the full file.
355 class binary_convolution_forward_test : public ::testing::TestWithParam<test_binary_convolution_dw_conv_params_t>
360 test_binary_convolution_dw_conv_params_t p = ::testing::TestWithParam<test_binary_convolution_dw_conv_params_t>::GetParam();
// Only the CPU engine and the direct binary convolution algorithm are accepted.
362 ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
363 ASSERT_EQ(p.aalgorithm, algorithm::binary_convolution_direct);
365 test_convolution_dw_conv_sizes_t cd = p.sizes;
367 auto eng = engine(p.engine_kind, 0);
368 auto aprop_kind = prop_kind::forward;
369 bool with_binarization = p.binarization_algorithm != algorithm_undef;
// Binary convolution consumes bin source/weights; the final output is bin only
// when a binarization post-op is requested, float otherwise.
372 memory::data_type data_type_bin_conv_src = memory::data_type::bin;
373 memory::data_type data_type_bin_conv_wei = memory::data_type::bin;
374 memory::data_type data_type_bin_conv_bia = data_traits<float>::data_type;
375 memory::data_type data_type_bin_conv_dst = data_traits<float>::data_type;
377 memory::data_type data_type_dw_conv_wei = data_traits<float>::data_type;
378 memory::data_type data_type_dw_conv_bia = data_traits<float>::data_type;
379 memory::data_type data_type_dw_conv_dst = with_binarization ? memory::data_type::bin
380 : data_traits<float>::data_type;
// Output spatial sizes of the first (binary) convolution, then of the second
// convolution applied to that output. No dilation term appears in these formulas.
382 int bin_conv_oh = (cd.ih - ((cd.conv1_kh - 1) + 1) + 2 * cd.conv1_padh) / cd.conv1_strh + 1;
383 int bin_conv_ow = (cd.iw - ((cd.conv1_kw - 1) + 1) + 2 * cd.conv1_padw) / cd.conv1_strw + 1;
385 int dw_conv_oh = (bin_conv_oh - ((cd.conv2_kh - 1) + 1) + 2 * cd.conv2_padh) / cd.conv2_strh + 1;
386 int dw_conv_ow = (bin_conv_ow - ((cd.conv2_kw - 1) + 1) + 2 * cd.conv2_padw) / cd.conv2_strw + 1;
// Right padding starts from the conv1 padding and is adjusted by the difference
// between the second and first convolution output sizes; the primitive's
// destination descriptor below uses the second convolution's output size.
388 std::vector<ptrdiff_t> bin_conv_padR = { cd.conv1_padh, cd.conv1_padw };
389 bin_conv_padR[0] += dw_conv_oh - bin_conv_oh;
390 bin_conv_padR[1] += dw_conv_ow - bin_conv_ow;
392 auto bin_conv_src_desc = create_md({ cd.mb, cd.ic, cd.ih, cd.iw }, data_type_bin_conv_src, p.formats.src_format);
393 auto bin_conv_weights_desc = create_md({ cd.conv1_oc, cd.ic, cd.conv1_kh, cd.conv1_kw }, data_type_bin_conv_wei, p.formats.conv1_weights_format);
394 auto bin_conv_dst_desc = create_md({ cd.mb, cd.conv1_oc, dw_conv_oh, dw_conv_ow }, data_type_bin_conv_dst, p.formats.dst_format);
396 auto bin_conv_src = test_memory(bin_conv_src_desc, eng);
397 auto bin_conv_weights = test_memory(bin_conv_weights_desc, eng);
// Binary inputs are filled byte-wise.
399 fill_data<uint8_t>(bin_conv_src.get_size() / sizeof(uint8_t), (uint8_t*)bin_conv_src.get().get_data_handle());
400 fill_data<uint8_t>(bin_conv_weights.get_size() / sizeof(uint8_t), (uint8_t*)bin_conv_weights.get().get_data_handle());
// Second convolution: weights are described with five dims { oc, 1, 1, kh, kw }.
402 auto dw_conv_weights_desc = create_md({ cd.conv2_oc, 1, 1, cd.conv2_kh, cd.conv2_kw }, data_type_dw_conv_wei, p.formats.conv2_weights_format);
403 auto dw_conv_dst_desc = create_md({ cd.mb, cd.conv2_oc, dw_conv_oh, dw_conv_ow }, data_type_dw_conv_dst, p.formats.dst_format);
404 auto dw_conv_bias_desc = create_md({ cd.conv2_oc }, data_type_dw_conv_bia, p.formats.conv2_bias_format);
406 auto dw_conv_weights = test_memory(dw_conv_weights_desc, eng);
407 auto dw_conv_bias = test_memory(dw_conv_bias_desc, eng);
408 auto dw_conv_dst = test_memory(dw_conv_dst_desc, eng);
// Destination is pre-filled according to its data type.
// NOTE(review): the float fill is presumably the else branch; the `else` line is
// not visible in this excerpt.
410 if (with_binarization)
411 fill_data<uint8_t>(dw_conv_dst.get_size() / sizeof(uint8_t), (uint8_t*)dw_conv_dst.get().get_data_handle());
413 fill_data<float>(dw_conv_dst.get_size() / sizeof(float), (float*)dw_conv_dst.get().get_data_handle());
415 fill_data<float>(dw_conv_weights.get_size() / sizeof(float), (float*)dw_conv_weights.get().get_data_handle());
416 fill_data<float>(dw_conv_bias.get_size() / sizeof(float), (float*)dw_conv_bias.get().get_data_handle());
// Descriptor of the binary convolution: conv1 strides, { 0, 0 } as the next
// argument (presumably dilations), conv1 left padding, the adjusted right padding
// and -1.f as the last argument (presumably the pad value; the reference uses
// pad_value = -1.f as well).
418 auto bin_conv_desc = binary_convolution_forward::desc(aprop_kind, p.aalgorithm,
419 bin_conv_src_desc, bin_conv_weights_desc, bin_conv_dst_desc,
420 { cd.conv1_strh, cd.conv1_strw }, { 0, 0 },
421 { cd.conv1_padh, cd.conv1_padw }, bin_conv_padR, -1.f);
// Post-ops applied to the binary convolution result, in append order.
423 mkldnn::post_ops bin_conv_post_ops;
424 if (p.eltwise_algorithm != algorithm_undef)
425 bin_conv_post_ops.append_eltwise(1.0, p.eltwise_algorithm, p.eltwise_alpha, p.eltwise_beta);
// Per-channel depthwise data for the first convolution. The memories are always
// created but only filled and attached when a depthwise algorithm is selected.
427 auto bin_conv_depthwise_weights_desc = create_md({ cd.conv1_oc }, data_type_bin_conv_bia, memory::x);
428 auto bin_conv_depthwise_bias_desc = create_md({ cd.conv1_oc }, data_type_bin_conv_bia, memory::x);
429 auto bin_conv_depthwise_weights = memory({bin_conv_depthwise_weights_desc, eng});
430 auto bin_conv_depthwise_bias = memory({bin_conv_depthwise_bias_desc, eng});
432 if (p.depthwise_algorithm != algorithm_undef) {
433 fill_data<float>(bin_conv_depthwise_weights.get_primitive_desc().get_size() / sizeof(float),
434 (float *)bin_conv_depthwise_weights.get_data_handle(), 1., true);
435 fill_data<float>(bin_conv_depthwise_bias.get_primitive_desc().get_size() / sizeof(float),
436 (float *)bin_conv_depthwise_bias.get_data_handle(), 1., true);
438 bin_conv_post_ops.append_depthwise(p.depthwise_algorithm, static_cast<const float*>(bin_conv_depthwise_weights.get_data_handle()),
439 static_cast<const float*>(bin_conv_depthwise_bias.get_data_handle()));
// Fused second convolution: takes the first convolution's output size, the conv2
// kernel and strides, and raw pointers to the conv2 weights and bias.
442 bin_conv_post_ops.append_dw_conv(bin_conv_oh, bin_conv_ow, cd.conv2_kh, cd.conv2_kw, cd.conv2_strh, cd.conv2_strw,
443 static_cast<const float*>(dw_conv_weights.get().get_data_handle()),
444 static_cast<const float*>(dw_conv_bias.get().get_data_handle()));
// NOTE(review): presumably guarded by a sum-related condition on a line not
// visible in this excerpt - confirm.
447 bin_conv_post_ops.append_sum();
// Post-ops applied after the second convolution; the same eltwise/depthwise
// selectors from `p` are reused.
449 if (p.eltwise_algorithm != algorithm_undef)
450 bin_conv_post_ops.append_eltwise(1.0, p.eltwise_algorithm, p.eltwise_alpha, p.eltwise_beta);
452 auto dw_conv_depthwise_weights_desc = create_md({ cd.conv2_oc }, data_type_bin_conv_bia, memory::x);
453 auto dw_conv_depthwise_bias_desc = create_md({ cd.conv2_oc }, data_type_bin_conv_bia, memory::x);
454 auto dw_conv_depthwise_weights = memory({dw_conv_depthwise_weights_desc, eng});
455 auto dw_conv_depthwise_bias = memory({dw_conv_depthwise_bias_desc, eng});
457 if (p.depthwise_algorithm != algorithm_undef) {
458 fill_data<float>(dw_conv_depthwise_weights.get_primitive_desc().get_size() / sizeof(float),
459 (float *)dw_conv_depthwise_weights.get_data_handle(), 1., true);
460 fill_data<float>(dw_conv_depthwise_bias.get_primitive_desc().get_size() / sizeof(float),
461 (float *)dw_conv_depthwise_bias.get_data_handle(), 1., true);
463 bin_conv_post_ops.append_depthwise(p.depthwise_algorithm, static_cast<const float*>(dw_conv_depthwise_weights.get_data_handle()),
464 static_cast<const float*>(dw_conv_depthwise_bias.get_data_handle()));
// Per-channel thresholds for the optional binarization post-op.
467 auto dw_conv_binarization_weights_desc = create_md({ cd.conv2_oc }, memory::data_type::f32, memory::x);
468 auto dw_conv_binarization_weights = memory({dw_conv_binarization_weights_desc, eng});
470 if (p.binarization_algorithm != algorithm_undef) {
471 fill_data<float>(dw_conv_binarization_weights.get_primitive_desc().get_size() / sizeof(float),
472 (float *)dw_conv_binarization_weights.get_data_handle(), 0.f, p.sizes.conv2_oc * p.sizes.conv2_kh * p.sizes.conv2_kw);
474 bin_conv_post_ops.append_binarization(p.binarization_algorithm, static_cast<const float*>(dw_conv_binarization_weights.get_data_handle()));
// Attach the post-op chain and build the primitive; it writes straight into
// dw_conv_dst, so the intermediate binary convolution output is never exposed.
477 mkldnn::primitive_attr bin_conv_attr;
478 bin_conv_attr.set_post_ops(bin_conv_post_ops);
480 auto bin_conv_primitive_desc = binary_convolution_forward::primitive_desc(bin_conv_desc, bin_conv_attr, eng);
482 auto bin_conv = binary_convolution_forward(bin_conv_primitive_desc, bin_conv_src.get(), bin_conv_weights.get(), dw_conv_dst.get());
// Reference for the first stage uses its own destination with the unfused
// bin_conv_oh x bin_conv_ow output size.
484 auto bin_conv_dst_desc_ref = create_md({ cd.mb, cd.conv1_oc, bin_conv_oh, bin_conv_ow }, data_type_bin_conv_dst, p.formats.dst_format);
485 auto ref_bin_conv_dst = test_memory(bin_conv_dst_desc_ref, eng);
486 compute_ref_bin_conv_fwd(p, bin_conv_src_desc, bin_conv_weights_desc, bin_conv_dst_desc_ref,
487 bin_conv_src.get(), bin_conv_weights.get(), ref_bin_conv_dst.get(),
488 bin_conv_depthwise_weights, bin_conv_depthwise_bias);
// Binarized path: reference second convolution into a float buffer, binarize it,
// run the primitive and compare the packed bytes.
490 if (with_binarization) {
491 auto ref_dw_conv_dst_desc = create_md({ cd.mb, cd.conv2_oc, dw_conv_oh, dw_conv_ow }, memory::data_type::f32, p.formats.dst_format);
492 auto ref_dw_conv_dst = test_memory(ref_dw_conv_dst_desc, eng);
494 compute_ref_dw_conv_fwd(p, ref_bin_conv_dst.get(), dw_conv_weights.get(), dw_conv_bias.get(),
495 ref_dw_conv_dst.get(),
496 dw_conv_depthwise_weights, dw_conv_depthwise_bias);
498 auto ref_binarization_dst = test_memory(dw_conv_dst_desc, eng);
500 compute_ref_binarization_fwd(p, ref_dw_conv_dst_desc, ref_dw_conv_dst.get(), dw_conv_binarization_weights, ref_binarization_dst.get());
502 std::vector<primitive> pipeline;
503 pipeline.push_back(bin_conv);
504 auto s = stream(stream::kind::lazy);
505 s.submit(pipeline).wait();
// Third argument is 0 here (presumably the comparison tolerance - TODO confirm
// against compare_data).
507 compare_data<uint8_t>(ref_binarization_dst.get(), dw_conv_dst.get(), 0, true);
// Float path (presumably the else branch; the `} else {` line is not visible in
// this excerpt). The reference destination is first seeded with the primitive's
// destination contents, because compute_ref_dw_conv_fwd reads dst before
// overwriting it.
509 auto ref_dw_conv_dst = test_memory(dw_conv_dst_desc, eng);
510 memcpy((float *) ref_dw_conv_dst.get().get_data_handle(), (float *) dw_conv_dst.get().get_data_handle(),
511 ref_dw_conv_dst.get_size());
512 compute_ref_dw_conv_fwd(p, ref_bin_conv_dst.get(), dw_conv_weights.get(), dw_conv_bias.get(),
513 ref_dw_conv_dst.get(),
514 dw_conv_depthwise_weights, dw_conv_depthwise_bias);
516 std::vector<primitive> pipeline;
517 pipeline.push_back(bin_conv);
518 auto s = stream(stream::kind::lazy);
519 s.submit(pipeline).wait();
// Third argument is 1e-3 here (presumably the comparison tolerance - TODO confirm).
521 compare_data<float>(ref_dw_conv_dst.get(), dw_conv_dst.get(), 1e-3);