1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #ifndef TEST_BINARY_CONVOLUTION_FORWARD_COMMON_HPP
18 #define TEST_BINARY_CONVOLUTION_FORWARD_COMMON_HPP
20 #include "mkldnn_test_common.hpp"
21 #include "gtest/gtest.h"
22 #include "math_utils.hpp"
25 using namespace mkldnn::impl::math;
// Naive reference for the binary convolution forward pass. The test fixture
// calls it to fill a reference output buffer that is then compared against
// the result of the optimized binary_convolution_forward primitive.
//   src / weights : packed 1-bit data; logical bit k is read from byte
//                   k / nbits at bit position k % nbits (see extract_bit).
//                   nbits is declared outside this listing (presumably 8,
//                   one uint8_t per byte — TODO confirm).
//   dst           : f32 output, indexed through map_index(dst_d, ...).
//   depthwise_*   : f32 per-output-channel weights/bias for the depthwise
//                   post-op (scale_shift / prelu).
// NOTE(review): this listing omits several source lines (e.g. the `src` and
// `dst` parameters and the declarations of `c`, `nbits`, `a`, `roi`, `s`);
// comments below that depend on them are hedged.
33 void compute_ref_bin_conv_fwd(const test_binary_convolution_params_t &p,
34 const memory::desc &src_d,
35 const memory::desc &weights_d,
36 const memory::desc &dst_d,
38 const memory &weights,
40 const memory &depthwise_weights,
41 const memory &depthwise_bias)
45 uint8_t* src_data = (uint8_t*)src.get_data_handle();
46 uint8_t* weights_data = (uint8_t*)weights.get_data_handle();
47 float* dst_data = (float*)dst.get_data_handle();
49 float *d_weights_data = (float *)depthwise_weights.get_data_handle();
50 float *d_bias_data = (float *)depthwise_bias.get_data_handle();
// Channel extents padded by the memory format's blocking. They are used as
// strides when building logical (pre-map_index) offsets.
54 size_t padded_ic = src_d.data.layout_desc.blocking.padding_dims[1];
55 size_t padded_ic_w = weights_d.data.layout_desc.blocking.padding_dims[1];
56 size_t padded_oc_w = weights_d.data.layout_desc.blocking.padding_dims[0];
// Returns bit `bit` (0 = least significant) of `val` as 0 or 1.
58 auto extract_bit = [](uint8_t val, uint8_t bit) -> uint8_t {
59 return (uint8_t) ((val >> bit) & 0x0001);
// One task per (minibatch, group, output channel within group, oh, ow).
62 mkldnn::impl::parallel_nd(c.mb, c.ng, c.oc / c.ng, c.oh, c.ow,
63 [&](int n, int g, int oc, int oh, int ow) {
66 for (int ic = 0; ic < c.ic; ic++) {
67 for (int kh = 0; kh < c.kh; kh++) {
68 for (int kw = 0; kw < c.kw; kw++) {
// Input coordinate of this kernel tap. Dilation is applied as (1 + dil),
// so dil == 0 means adjacent taps.
69 int ih = oh * c.strh - c.padh + kh * (1 + c.dilh);
70 int iw = ow * c.strw - c.padw + kw * (1 + c.dilw);
// Logical bit offset of the input element (channels padded to padded_ic),
// translated by map_index to src_d's physical layout.
// NOTE(review): this is computed and mapped even when (ih, iw) lies outside
// the image; negative ih/iw wrap in size_t arithmetic. It is only used in
// the in-image branch, so this is harmless as long as map_index does not
// dereference memory — confirm.
72 size_t iidx = n * padded_ic * c.ih * c.iw
73 + g * padded_ic / c.ng * c.ih * c.iw
74 + ic * c.ih * c.iw + ih * c.iw + iw;
75 iidx = map_index(src_d, iidx);
// Out-of-image tap. The body of the pad_value == 0 branch is not visible in
// this listing. Otherwise the padded input bit is 1 when pad_value == 1.0f,
// else 0.
78 if (ih < 0 || ih >= c.ih || iw < 0 || iw >= c.iw) {
79 if (p.pad_value == 0.0f) {
82 s = p.pad_value == 1.0f ? (uint8_t)1 : (uint8_t)0;
// In-image tap: read the packed input bit.
85 s = extract_bit(src_data[iidx/nbits], (uint8_t)(iidx % nbits));
// Logical bit offset of the weight element (one line of this expression is
// not visible in this listing), mapped to weights_d's physical layout.
88 size_t widx = g * padded_oc_w / c.ng * padded_ic_w
90 + oc * padded_ic_w / c.ng * c.kh * c.kw
91 + ic * c.kh * c.kw + kh * c.kw + kw;
92 widx = map_index(weights_d, widx);
94 uint8_t w = extract_bit(weights_data[widx/nbits], (uint8_t)(widx % nbits));
// s ^ w is 1 exactly when the input and weight bits differ, so `a`
// accumulates the number of mismatching bits.
96 a += (int32_t)(s ^ w);
// Presumably roi counts the contributing taps (its update is not visible
// here). With `a` mismatches, roi - 2*a == matches - mismatches, i.e. the
// dot product of the operands decoded as +/-1 — TODO confirm.
103 float a_fp = (float)(roi - 2*a);
// Logical output offset in (n, channel, oh, ow) order, as far as visible;
// part of this expression is not shown in this listing.
105 size_t oidx = n * c.oc * c.oh * c.ow +
106 g * c.oc / c.ng * c.oh * c.ow +
// Adds the value already stored in dst. The condition guarding this
// (presumably a sum post-op) is not visible in this listing.
112 a_fp += dst_data[map_index(dst_d, oidx)];
// Eltwise post-op on the accumulator; an unrecognised algorithm trips the
// assert. Some case labels/breaks are not visible in this listing.
114 switch (p.eltwise_algorithm) {
115 case algorithm_undef:
118 a_fp = relu_fwd(a_fp, p.eltwise_alpha);
121 a_fp = tanh_fwd(a_fp);
124 a_fp = elu_fwd(a_fp, p.eltwise_alpha);
127 a_fp = square_fwd(a_fp);
130 a_fp = abs_fwd(a_fp);
133 a_fp = sqrt_fwd(a_fp);
136 a_fp = linear_fwd(a_fp, p.eltwise_alpha, p.eltwise_beta);
138 case eltwise_bounded_relu:
139 a_fp = bounded_relu_fwd(a_fp, p.eltwise_alpha);
141 case eltwise_soft_relu:
142 a_fp = soft_relu_fwd(a_fp);
144 case eltwise_logistic:
145 a_fp = logistic_fwd(a_fp);
148 a_fp = clamp_fwd(a_fp, p.eltwise_alpha, p.eltwise_beta);
151 assert(!"unknown alg_kind");
// Depthwise post-op. Parameters are indexed by g * c.oc / c.ng + oc, the
// absolute output channel of this point.
154 switch (p.depthwise_algorithm) {
155 case algorithm_undef:
157 case depthwise_scale_shift:
158 a_fp = scale_shift_fwd(a_fp, d_weights_data[g * c.oc / c.ng + oc], d_bias_data[g * c.oc / c.ng + oc]);
160 case depthwise_prelu:
161 a_fp = prelu_fwd(a_fp, d_weights_data[g * c.oc / c.ng + oc]);
163 default: assert(!"unknown alg_kind");
// Store the post-processed result at the physical dst location.
166 dst_data[map_index(dst_d, oidx)] = a_fp;
// Naive reference for the binarization post-op. Each f32 source value is
// compared against an f32 threshold read from `weights`: the bit is 1 when
// src > threshold and 0 otherwise. Groups of nbits consecutive channels are
// packed into one uint8_t of `dst`.
//   src_md  : descriptor whose dims give N, C, H, W; missing trailing dims
//             default to 1.
//   src     : f32 input (in the test, the reference convolution output).
//   weights : f32 thresholds, read via map_index(weights_d, wei_idx).
//   dst     : packed 1-bit output.
// `p` is not read in the visible lines.
// NOTE(review): this listing omits some lines, e.g. the declarations of
// `nbits` and `wei_idx`. wei_idx is presumably the channel index c
// (per-channel thresholds) — confirm.
171 void compute_ref_binarization_fwd(const test_binary_convolution_params_t &p,
172 const memory::desc &src_md, const memory &src, const memory &weights, const memory &dst) {
173 auto src_data = (float*)src.get_data_handle();
174 auto weights_data = (float*)weights.get_data_handle();
175 auto dst_data = (uint8_t*)dst.get_data_handle();
// Physical layouts used by map_index. padded_ic below comes from src's own
// descriptor, while N/C/H/W come from src_md.
177 const memory::desc src_d = src.get_primitive_desc().desc();
178 const memory::desc weights_d = weights.get_primitive_desc().desc();
179 const memory::desc dst_d = dst.get_primitive_desc().desc();
181 int N = src_md.data.ndims > 0 ? src_md.data.dims[0] : 1;
182 int C = src_md.data.ndims > 1 ? src_md.data.dims[1] : 1;
183 int H = src_md.data.ndims > 2 ? src_md.data.dims[2] : 1;
184 int W = src_md.data.ndims > 3 ? src_md.data.dims[3] : 1;
// Number of packed channel blocks; the last one may be partial.
187 int CB = div_up(C, nbits);
189 int padded_ic = src_d.data.layout_desc.blocking.padding_dims[1];
190 int padded_oc = dst_d.data.layout_desc.blocking.padding_dims[1];
192 for (int n = 0; n < N; ++n) {
193 for (int cb = 0; cb < CB; ++cb) {
194 for (int h = 0; h < H; ++h) {
195 for (int w = 0; w < W; ++w) {
// Pack channels [cb*nbits, min(C, (cb+1)*nbits)) LSB-first (shift 0 for the
// first channel). Bits past C stay 0 because bin_val starts at 0.
197 uint8_t bin_val = 0x00;
198 for (int c = cb * nbits, shift = 0; c < std::min(C, (cb + 1) * nbits); c++, shift++) {
// NOTE(review): offsets are int, which could overflow for large tensors —
// confirm the test shapes stay small.
199 int src_idx = n*padded_ic*H*W + c*H*W + h*W + w;
202 float s_val = src_data[map_index(src_d, src_idx)];
203 float w_val = weights_data[map_index(weights_d, wei_idx)];
205 auto bit = uint8_t((s_val > w_val) ? 0x01 : 0x00);
206 bin_val |= (bit << shift);
// dst offsets are in bits (the block's first logical channel is cb*nbits),
// so the byte holding this block is dst_idx / nbits. This presumably relies
// on the dst layout keeping each block's nbits channels byte-aligned —
// confirm.
209 int dst_idx = n*padded_oc*H*W + cb*nbits*H*W + h*W + w;
210 dst_idx = map_index(dst_d, dst_idx);
211 dst_data[dst_idx / nbits] = bin_val;
218 class binary_convolution_forward_test : public ::testing::TestWithParam<test_binary_convolution_params_t>
223 test_binary_convolution_params_t p = ::testing::TestWithParam<test_binary_convolution_params_t>::GetParam();
225 ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
226 ASSERT_EQ(p.aalgorithm, algorithm::binary_convolution_direct);
228 test_convolution_sizes_t cd = p.sizes;
230 auto eng = engine(p.engine_kind, 0);
231 auto aprop_kind = prop_kind::forward;
232 bool with_binarization = p.binarization_algorithm != algorithm_undef;
234 memory::data_type data_type_src = memory::data_type::bin;
235 memory::data_type data_type_wei = memory::data_type::bin;
236 memory::data_type data_type_bia = memory::data_type::f32;
237 memory::data_type data_type_dst = with_binarization ? memory::data_type::bin
238 : data_traits<float>::data_type;
240 auto c_src_desc = create_md({ cd.mb, cd.ic, cd.ih, cd.iw }, data_type_src, p.formats.src_format);
241 auto c_weights_desc = cd.ng > 1
242 ? create_md({ cd.ng, cd.oc / cd.ng, cd.ic / cd.ng, cd.kh, cd.kw }, data_type_wei, p.formats.weights_format)
243 : create_md({ cd.oc, cd.ic, cd.kh, cd.kw }, data_type_wei, p.formats.weights_format);
244 auto c_dst_desc = create_md({ cd.mb, cd.oc, cd.oh, cd.ow }, data_type_dst, p.formats.dst_format);
246 auto c_src = test_memory(c_src_desc, eng);
247 auto c_weights = test_memory(c_weights_desc, eng);
248 auto c_dst = test_memory(c_dst_desc, eng);
250 // Only true for dense format
251 if (with_binarization)
252 fill_data<uint8_t>(c_dst.get_size() / sizeof(uint8_t), (uint8_t*)c_dst.get().get_data_handle());
254 fill_data<float>(c_dst.get_size() / sizeof(float), (float*)c_dst.get().get_data_handle());
255 fill_data<uint8_t>(c_src.get_size() / sizeof(uint8_t), (uint8_t*)c_src.get().get_data_handle());
256 fill_data<uint8_t>(c_weights.get_size() / sizeof(uint8_t), (uint8_t*)c_weights.get().get_data_handle());
258 std::vector<ptrdiff_t> padR = {
259 right_padding(cd.ih, cd.oh, cd.kh, cd.padh, cd.strh, cd.dilh),
260 right_padding(cd.iw, cd.ow, cd.kw, cd.padw, cd.strw, cd.dilw)
263 auto bin_conv_desc = binary_convolution_forward::desc(aprop_kind, p.aalgorithm,
264 c_src_desc, c_weights_desc, c_dst_desc,
265 { cd.strh, cd.strw }, { cd.dilh, cd.dilw },
266 { cd.padh, cd.padw }, padR, p.pad_value);
268 mkldnn::post_ops ops;
273 if (p.eltwise_algorithm != algorithm_undef)
274 ops.append_eltwise(1.0, p.eltwise_algorithm, p.eltwise_alpha, p.eltwise_beta);
276 auto c_depthwise_weights_desc = create_md({ cd.oc }, data_type_bia, memory::x);
277 auto c_depthwise_bias_desc = create_md({ cd.oc }, data_type_bia, memory::x);
279 auto c_depthwise_weights = memory({c_depthwise_weights_desc, eng});
280 auto c_depthwise_bias = memory({c_depthwise_bias_desc, eng});
282 if (p.depthwise_algorithm != algorithm_undef) {
283 fill_data<float>(c_depthwise_weights.get_primitive_desc().get_size() / sizeof(float),
284 (float *)c_depthwise_weights.get_data_handle(), 1., true);
285 fill_data<float>(c_depthwise_bias.get_primitive_desc().get_size() / sizeof(float),
286 (float *)c_depthwise_bias.get_data_handle(), 1., true);
288 ops.append_depthwise(p.depthwise_algorithm, static_cast<const float*>(c_depthwise_weights.get_data_handle()),
289 static_cast<const float*>(c_depthwise_bias.get_data_handle()));
292 auto c_binarization_weights_desc = create_md({ cd.oc }, memory::data_type::f32, memory::x);
293 auto c_binarization_weights = memory({c_binarization_weights_desc, eng});
295 if (p.binarization_algorithm != algorithm_undef) {
296 fill_data<float>(c_binarization_weights.get_primitive_desc().get_size() / sizeof(float),
297 (float *)c_binarization_weights.get_data_handle(), 1., true);
299 ops.append_binarization(p.binarization_algorithm, static_cast<const float*>(c_binarization_weights.get_data_handle()));
302 mkldnn::primitive_attr attr;
303 attr.set_post_ops(ops);
305 auto bin_conv_primitive_desc = binary_convolution_forward::primitive_desc(bin_conv_desc, attr, eng);
307 auto bin_conv = binary_convolution_forward(bin_conv_primitive_desc, c_src.get(), c_weights.get(), c_dst.get());
309 if (with_binarization) {
310 auto c_dst_desc_ref = create_md({ cd.mb, cd.oc, cd.oh, cd.ow }, memory::data_type::f32, p.formats.dst_format);
311 auto c_dst_ref = test_memory(c_dst_desc_ref, eng);
313 std::vector<float> ref_dst_conv_data(c_dst_ref.get_size() / sizeof(float));
314 auto ref_conv_memory = memory(memory::primitive_desc(c_dst_desc_ref, eng), &ref_dst_conv_data[0]);
316 std::vector<uint8_t > ref_dst_data(c_dst.get_size() / sizeof(uint8_t));
317 auto ref_memory = memory(memory::primitive_desc(c_dst_desc, eng), &ref_dst_data[0]);
319 compute_ref_bin_conv_fwd(p, c_src_desc, c_weights_desc, c_dst_desc_ref,
320 c_src.get(), c_weights.get(), ref_conv_memory,
321 c_depthwise_weights, c_depthwise_bias);
323 compute_ref_binarization_fwd(p, c_dst_desc_ref, ref_conv_memory, c_binarization_weights, ref_memory);
325 std::vector<primitive> pipeline;
326 pipeline.push_back(bin_conv);
327 auto s = stream(stream::kind::lazy);
328 s.submit(pipeline).wait();
330 compare_data<uint8_t>(ref_memory, c_dst.get(), 0, true);
332 std::vector<float> ref_dst_data(c_dst.get_size() / sizeof(float));
333 memcpy(&ref_dst_data[0], (float*)c_dst.get().get_data_handle(), ref_dst_data.size() * sizeof(float));
334 auto ref_memory = memory(memory::primitive_desc(c_dst_desc, eng), &ref_dst_data[0]);
336 compute_ref_bin_conv_fwd(p, c_src_desc, c_weights_desc, c_dst_desc,
337 c_src.get(), c_weights.get(), ref_memory,
338 c_depthwise_weights, c_depthwise_bias);
340 std::vector<primitive> pipeline;
341 pipeline.push_back(bin_conv);
342 auto s = stream(stream::kind::lazy);
343 s.submit(pipeline).wait();
345 compare_data<float>(ref_memory, c_dst.get(), 1e-3);