Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / gtests / test_binary_convolution_dw_conv_forward_common.hpp
1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef TEST_BINARY_CONVOLUTION_DW_CONV_FORWARD_COMMON_HPP
18 #define TEST_BINARY_CONVOLUTION_DW_CONV_FORWARD_COMMON_HPP
19
20 #include "mkldnn_test_common.hpp"
21 #include "gtest/gtest.h"
22 #include "math_utils.hpp"
23 #include "mkldnn.hpp"
24
25 using namespace mkldnn::impl::math;
26
27 namespace mkldnn {
28
29 void compute_ref_bin_conv_fwd(const test_binary_convolution_dw_conv_params_t &p,
30         const memory::desc &src_d,
31         const memory::desc &weights_d,
32         const memory::desc &dst_d,
33         const memory &src,
34         const memory &weights,
35         const memory &dst,
36         const memory &depthwise_weights,
37         const memory &depthwise_bias)
38 {
39     auto src_dims = src_d.data.dims;
40     auto dst_dims = dst_d.data.dims;
41     auto sizes = p.sizes;
42     test_convolution_sizes_t c = {(int)src_dims[0], 1, sizes.ic, (int)src_dims[2], (int)src_dims[3],
43                                   (int)dst_dims[1], (int)dst_dims[2], (int)dst_dims[3],
44                                   sizes.conv1_kh, sizes.conv1_kw, sizes.conv1_padh, sizes.conv1_padw, sizes.conv1_strh, sizes.conv1_strw};
45
46     float pad_value = -1.f;
47
48     uint8_t* src_data = (uint8_t*)src.get_data_handle();
49     uint8_t* weights_data = (uint8_t*)weights.get_data_handle();
50     float* dst_data = (float*)dst.get_data_handle();
51
52     float *d_weights_data = (float *)depthwise_weights.get_data_handle();
53     float *d_bias_data = (float *)depthwise_bias.get_data_handle();
54
55     int nbits = 8;
56
57     size_t padded_ic = src_d.data.layout_desc.blocking.padding_dims[1];
58     size_t padded_ic_w = weights_d.data.layout_desc.blocking.padding_dims[1];
59     size_t padded_oc_w = weights_d.data.layout_desc.blocking.padding_dims[0];
60
61     auto extract_bit = [](uint8_t val, uint8_t bit) -> uint8_t {
62         return (uint8_t) ((val >> bit) & 0x0001);
63     };
64
65     mkldnn::impl::parallel_nd(c.mb, c.ng, c.oc / c.ng, c.oh, c.ow,
66         [&](int n, int g, int oc, int oh, int ow) {
67             int32_t a = 0;
68             int roi = 0;
69             for (int ic = 0; ic < c.ic; ic++) {
70                 for (int kh = 0; kh < c.kh; kh++) {
71                     for (int kw = 0; kw < c.kw; kw++) {
72                         int ih = oh * c.strh - c.padh + kh * (1 + c.dilh);
73                         int iw = ow * c.strw - c.padw + kw * (1 + c.dilw);
74
75                         size_t iidx = n * padded_ic * c.ih * c.iw
76                                       + g * padded_ic / c.ng * c.ih * c.iw
77                                       + ic * c.ih * c.iw + ih * c.iw + iw;
78                         iidx = map_index(src_d, iidx);
79
80                         uint8_t s;
81                         if (ih < 0 || ih >= c.ih || iw < 0 || iw >= c.iw) {
82                             if (pad_value == 0.0f) {
83                                 continue;
84                             } else {
85                                 s = pad_value == 1.0f ? (uint8_t)1 : (uint8_t)0;
86                             }
87                         } else {
88                              s = extract_bit(src_data[iidx/nbits], (uint8_t)(iidx % nbits));
89                         }
90
91                         size_t widx = g * padded_oc_w / c.ng * padded_ic_w
92                                       / c.ng * c.kh * c.kw
93                                       + oc * padded_ic_w / c.ng * c.kh * c.kw
94                                       + ic * c.kh * c.kw + kh * c.kw + kw;
95                         widx = map_index(weights_d, widx);
96
97                         uint8_t w = extract_bit(weights_data[widx/nbits], (uint8_t)(widx % nbits));
98
99                         a += (int32_t)(s ^ w);
100
101                         roi++;
102                     }
103                 }
104             }
105
106             float a_fp = (float)(roi - 2*a);
107
108             size_t oidx = n * c.oc * c.oh * c.ow +
109                           g * c.oc / c.ng * c.oh * c.ow +
110                           oc * c.oh * c.ow +
111                           oh * c.ow +
112                           ow;
113
114             switch (p.eltwise_algorithm) {
115                 case algorithm_undef:
116                     break;
117                 case eltwise_relu:
118                     a_fp = relu_fwd(a_fp, p.eltwise_alpha);
119                     break;
120                 case eltwise_tanh:
121                     a_fp = tanh_fwd(a_fp);
122                     break;
123                 case eltwise_elu:
124                     a_fp = elu_fwd(a_fp, p.eltwise_alpha);
125                     break;
126                 case eltwise_square:
127                     a_fp = square_fwd(a_fp);
128                     break;
129                 case eltwise_abs:
130                     a_fp = abs_fwd(a_fp);
131                     break;
132                 case eltwise_sqrt:
133                     a_fp = sqrt_fwd(a_fp);
134                     break;
135                 case eltwise_linear:
136                     a_fp = linear_fwd(a_fp, p.eltwise_alpha, p.eltwise_beta);
137                     break;
138                 case eltwise_bounded_relu:
139                     a_fp = bounded_relu_fwd(a_fp, p.eltwise_alpha);
140                     break;
141                 case eltwise_soft_relu:
142                     a_fp = soft_relu_fwd(a_fp);
143                     break;
144                 case eltwise_logistic:
145                     a_fp = logistic_fwd(a_fp);
146                     break;
147                 case eltwise_clamp:
148                     a_fp = clamp_fwd(a_fp, p.eltwise_alpha, p.eltwise_beta);
149                     break;
150                 default:
151                     assert(!"unknown alg_kind");
152             }
153
154             switch (p.depthwise_algorithm) {
155                 case algorithm_undef:
156                     break;
157                 case depthwise_scale_shift:
158                     a_fp = scale_shift_fwd(a_fp, d_weights_data[g * c.oc / c.ng + oc], d_bias_data[g * c.oc / c.ng + oc]);
159                     break;
160                 case depthwise_prelu:
161                     a_fp = prelu_fwd(a_fp, d_weights_data[g * c.oc / c.ng + oc]);
162                     break;
163                 default: assert(!"unknown alg_kind");
164             }
165
166             dst_data[map_index(dst_d, oidx)] = a_fp;
167         }
168     );
169 }
170
171 void compute_ref_dw_conv_fwd(const test_binary_convolution_dw_conv_params_t &p,
172         const memory &src, const memory &weights, const memory &bias, const memory &dst,
173         const memory &depthwise_weights, const memory &depthwise_bias)
174 {
175     const memory::desc src_d = src.get_primitive_desc().desc();
176     const memory::desc weights_d = weights.get_primitive_desc().desc();
177     const memory::desc dst_d = dst.get_primitive_desc().desc();
178
179     auto src_dims = src_d.data.dims;
180     auto dst_dims = dst_d.data.dims;
181
182     int MB = src_dims[0];
183     int G = src_dims[1];
184     int IC = src_dims[1];
185     int IH = src_dims[2];
186     int IW = src_dims[3];
187     int OC = dst_dims[1];
188     int OH = dst_dims[2];
189     int OW = dst_dims[3];
190
191     int KH = p.sizes.conv2_kh;
192     int KW = p.sizes.conv2_kw;
193     int SH = p.sizes.conv2_strh;
194     int SW = p.sizes.conv2_strw;
195     int PH = p.sizes.conv2_padh;
196     int PW = p.sizes.conv2_padw;
197     int DH = 0;
198     int DW = 0;
199
200     float *src_data = (float *)src.get_data_handle();
201     float *weights_data = (float *)weights.get_data_handle();
202     float *bias_data = (float *)bias.get_data_handle();
203     float *dst_data = (float *)dst.get_data_handle();
204
205     float *d_weights_data = (float *)depthwise_weights.get_data_handle();
206     float *d_bias_data = (float *)depthwise_bias.get_data_handle();
207
208     mkldnn::impl::parallel_nd(MB, G, OC / G, OH, OW,
209         [&](int n, int g, int oc, int oh, int ow) {
210             int oidx = n * OC * OH * OW
211                        + g * OC / G * OH * OW
212                        + oc * OH * OW + oh * OW + ow;
213
214             float a = (float)0;
215
216             for (int ic = 0; ic < IC / G; ic++) {
217                 for (int kh = 0; kh < KH; kh++) {
218                     for (int kw = 0; kw < KW; kw++) {
219                         int iw = ow * SW
220                                  - PW + kw * (1 + DW);
221                         int ih = oh * SH
222                                  - PH + kh * (1 + DH);
223                         if (iw < 0 || iw >= IW) continue;
224                         if (ih < 0 || ih >= IH) continue;
225                         int iidx = n * IC * IH * IW
226                                    + g * IC / G * IH * IW
227                                    + ic * IH * IW + ih * IW + iw;
228                         int widx = g * OC / G * IC
229                                    / G * KH * KW
230                                    + oc * IC / G * KH * KW
231                                    + ic * KH * KW + kh * KW + kw;
232
233                         iidx = map_index(src_d, iidx);
234
235                         float s = src_data[iidx];
236                         float w = weights_data[map_index(weights_d, widx)];
237
238                         a += s * w;
239
240                     }
241                 }
242             }
243
244             float a_fp = (float)a;
245
246             a_fp += bias_data[G > 1 ? g : oc];
247
248             if (p.with_sum)
249                 a_fp += dst_data[map_index(dst_d, oidx)];
250
251             switch (p.eltwise_algorithm) {
252                 case algorithm_undef:
253                     break;
254                 case eltwise_relu:
255                     a_fp = relu_fwd(a_fp, p.eltwise_alpha);
256                     break;
257                 case eltwise_tanh:
258                     a_fp = tanh_fwd(a_fp);
259                     break;
260                 case eltwise_elu:
261                     a_fp = elu_fwd(a_fp, p.eltwise_alpha);
262                     break;
263                 case eltwise_square:
264                     a_fp = square_fwd(a_fp);
265                     break;
266                 case eltwise_abs:
267                     a_fp = abs_fwd(a_fp);
268                     break;
269                 case eltwise_sqrt:
270                     a_fp = sqrt_fwd(a_fp);
271                     break;
272                 case eltwise_linear:
273                     a_fp = linear_fwd(a_fp, p.eltwise_alpha, p.eltwise_beta);
274                     break;
275                 case eltwise_bounded_relu:
276                     a_fp = bounded_relu_fwd(a_fp, p.eltwise_alpha);
277                     break;
278                 case eltwise_soft_relu:
279                     a_fp = soft_relu_fwd(a_fp);
280                     break;
281                 case eltwise_logistic:
282                     a_fp = logistic_fwd(a_fp);
283                     break;
284                 case eltwise_clamp:
285                     a_fp = clamp_fwd(a_fp, p.eltwise_alpha, p.eltwise_beta);
286                     break;
287                 default:
288                     assert(!"unknown alg_kind");
289             }
290
291             switch (p.depthwise_algorithm) {
292                 case algorithm_undef:
293                     break;
294                 case depthwise_scale_shift:
295                     a_fp = scale_shift_fwd(a_fp, d_weights_data[g * OC / G + oc], d_bias_data[g * OC / G + oc]);
296                     break;
297                 case depthwise_prelu:
298                     a_fp = prelu_fwd(a_fp, d_weights_data[g * OC / G + oc]);
299                     break;
300                 default: assert(!"unknown alg_kind");
301             }
302
303             dst_data[map_index(dst_d, oidx)] = (float)a_fp;
304         }
305     );
306 }
307
308 void compute_ref_binarization_fwd(const test_binary_convolution_dw_conv_params_t &p,
309     const memory::desc &src_md, const memory &src, const memory &weights, const memory &dst) {
310     auto src_data = (float*)src.get_data_handle();
311     auto weights_data = (float*)weights.get_data_handle();
312     auto dst_data = (uint8_t*)dst.get_data_handle();
313
314     const memory::desc src_d = src.get_primitive_desc().desc();
315     const memory::desc weights_d = weights.get_primitive_desc().desc();
316     const memory::desc dst_d = dst.get_primitive_desc().desc();
317
318     int N = src_md.data.ndims > 0 ? src_md.data.dims[0] : 1;
319     int C = src_md.data.ndims > 1 ? src_md.data.dims[1] : 1;
320     int H = src_md.data.ndims > 2 ? src_md.data.dims[2] : 1;
321     int W = src_md.data.ndims > 3 ? src_md.data.dims[3] : 1;
322
323     int nbits = 8;
324     int CB = div_up(C, nbits);
325
326     int padded_ic = src_d.data.layout_desc.blocking.padding_dims[1];
327     int padded_oc = dst_d.data.layout_desc.blocking.padding_dims[1];
328
329     for (int n = 0; n < N; ++n) {
330         for (int cb = 0; cb < CB; ++cb) {
331             for (int h = 0; h < H; ++h) {
332                 for (int w = 0; w < W; ++w) {
333
334                     uint8_t bin_val = 0x00;
335                     for (int c = cb * nbits, shift = 0; c < std::min(C, (cb + 1) * nbits); c++, shift++) {
336                         int src_idx = n*padded_ic*H*W + c*H*W + h*W + w;
337                         int wei_idx = c;
338
339                         float s_val = src_data[map_index(src_d, src_idx)];
340                         float w_val = weights_data[map_index(weights_d, wei_idx)];
341
342                         auto bit = uint8_t((s_val > w_val) ? 0x01 : 0x00);
343                         bin_val |= (bit << shift);
344                     }
345
346                     int dst_idx = n*padded_oc*H*W + cb*nbits*H*W + h*W + w;
347                     dst_idx = map_index(dst_d, dst_idx);
348                     dst_data[dst_idx / nbits] = bin_val;
349                 }
350             }
351         }
352     }
353 }
354
355 class binary_convolution_forward_test : public ::testing::TestWithParam<test_binary_convolution_dw_conv_params_t>
356 {
357 protected:
358     virtual void SetUp()
359     {
360         test_binary_convolution_dw_conv_params_t p = ::testing::TestWithParam<test_binary_convolution_dw_conv_params_t>::GetParam();
361
362         ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
363         ASSERT_EQ(p.aalgorithm, algorithm::binary_convolution_direct);
364
365         test_convolution_dw_conv_sizes_t cd = p.sizes;
366
367         auto eng = engine(p.engine_kind, 0);
368         auto aprop_kind = prop_kind::forward;
369         bool with_binarization = p.binarization_algorithm != algorithm_undef;
370 //        int nbits = 8;
371
372         memory::data_type data_type_bin_conv_src = memory::data_type::bin;
373         memory::data_type data_type_bin_conv_wei = memory::data_type::bin;
374         memory::data_type data_type_bin_conv_bia = data_traits<float>::data_type;
375         memory::data_type data_type_bin_conv_dst = data_traits<float>::data_type;
376
377         memory::data_type data_type_dw_conv_wei = data_traits<float>::data_type;
378         memory::data_type data_type_dw_conv_bia = data_traits<float>::data_type;
379         memory::data_type data_type_dw_conv_dst = with_binarization ? memory::data_type::bin
380                                                                     : data_traits<float>::data_type;
381
382         int bin_conv_oh = (cd.ih - ((cd.conv1_kh - 1) + 1) + 2 * cd.conv1_padh) / cd.conv1_strh + 1;
383         int bin_conv_ow = (cd.iw - ((cd.conv1_kw - 1) + 1) + 2 * cd.conv1_padw) / cd.conv1_strw + 1;
384
385         int dw_conv_oh = (bin_conv_oh - ((cd.conv2_kh - 1) + 1) + 2 * cd.conv2_padh) / cd.conv2_strh + 1;
386         int dw_conv_ow = (bin_conv_ow - ((cd.conv2_kw - 1) + 1) + 2 * cd.conv2_padw) / cd.conv2_strw + 1;
387
388         std::vector<ptrdiff_t> bin_conv_padR = { cd.conv1_padh, cd.conv1_padw };
389         bin_conv_padR[0] += dw_conv_oh - bin_conv_oh;
390         bin_conv_padR[1] += dw_conv_ow - bin_conv_ow;
391
392         auto bin_conv_src_desc = create_md({ cd.mb, cd.ic, cd.ih, cd.iw }, data_type_bin_conv_src, p.formats.src_format);
393         auto bin_conv_weights_desc = create_md({ cd.conv1_oc, cd.ic, cd.conv1_kh, cd.conv1_kw }, data_type_bin_conv_wei, p.formats.conv1_weights_format);
394         auto bin_conv_dst_desc = create_md({ cd.mb, cd.conv1_oc, dw_conv_oh, dw_conv_ow }, data_type_bin_conv_dst, p.formats.dst_format);
395
396         auto bin_conv_src = test_memory(bin_conv_src_desc, eng);
397         auto bin_conv_weights = test_memory(bin_conv_weights_desc, eng);
398
399         fill_data<uint8_t>(bin_conv_src.get_size() / sizeof(uint8_t), (uint8_t*)bin_conv_src.get().get_data_handle());
400         fill_data<uint8_t>(bin_conv_weights.get_size() / sizeof(uint8_t), (uint8_t*)bin_conv_weights.get().get_data_handle());
401
402         auto dw_conv_weights_desc = create_md({ cd.conv2_oc, 1, 1, cd.conv2_kh, cd.conv2_kw }, data_type_dw_conv_wei, p.formats.conv2_weights_format);
403         auto dw_conv_dst_desc = create_md({ cd.mb, cd.conv2_oc, dw_conv_oh, dw_conv_ow }, data_type_dw_conv_dst, p.formats.dst_format);
404         auto dw_conv_bias_desc = create_md({ cd.conv2_oc }, data_type_dw_conv_bia, p.formats.conv2_bias_format);
405
406         auto dw_conv_weights = test_memory(dw_conv_weights_desc, eng);
407         auto dw_conv_bias = test_memory(dw_conv_bias_desc, eng);
408         auto dw_conv_dst = test_memory(dw_conv_dst_desc, eng);
409
410         if (with_binarization)
411             fill_data<uint8_t>(dw_conv_dst.get_size() / sizeof(uint8_t), (uint8_t*)dw_conv_dst.get().get_data_handle());
412         else
413             fill_data<float>(dw_conv_dst.get_size() / sizeof(float), (float*)dw_conv_dst.get().get_data_handle());
414
415         fill_data<float>(dw_conv_weights.get_size() / sizeof(float), (float*)dw_conv_weights.get().get_data_handle());
416         fill_data<float>(dw_conv_bias.get_size() / sizeof(float), (float*)dw_conv_bias.get().get_data_handle());
417
418         auto bin_conv_desc = binary_convolution_forward::desc(aprop_kind, p.aalgorithm,
419                                                               bin_conv_src_desc, bin_conv_weights_desc, bin_conv_dst_desc,
420                                                               { cd.conv1_strh, cd.conv1_strw }, { 0, 0 },
421                                                               { cd.conv1_padh, cd.conv1_padw }, bin_conv_padR, -1.f);
422
423         mkldnn::post_ops bin_conv_post_ops;
424         if (p.eltwise_algorithm != algorithm_undef)
425             bin_conv_post_ops.append_eltwise(1.0, p.eltwise_algorithm, p.eltwise_alpha, p.eltwise_beta);
426
427         auto bin_conv_depthwise_weights_desc = create_md({ cd.conv1_oc }, data_type_bin_conv_bia, memory::x);
428         auto bin_conv_depthwise_bias_desc = create_md({ cd.conv1_oc }, data_type_bin_conv_bia, memory::x);
429         auto bin_conv_depthwise_weights = memory({bin_conv_depthwise_weights_desc, eng});
430         auto bin_conv_depthwise_bias = memory({bin_conv_depthwise_bias_desc, eng});
431
432         if (p.depthwise_algorithm != algorithm_undef) {
433             fill_data<float>(bin_conv_depthwise_weights.get_primitive_desc().get_size() / sizeof(float),
434                              (float *)bin_conv_depthwise_weights.get_data_handle(), 1., true);
435             fill_data<float>(bin_conv_depthwise_bias.get_primitive_desc().get_size() / sizeof(float),
436                              (float *)bin_conv_depthwise_bias.get_data_handle(), 1., true);
437
438             bin_conv_post_ops.append_depthwise(p.depthwise_algorithm, static_cast<const float*>(bin_conv_depthwise_weights.get_data_handle()),
439                                                static_cast<const float*>(bin_conv_depthwise_bias.get_data_handle()));
440         }
441
442         bin_conv_post_ops.append_dw_conv(bin_conv_oh, bin_conv_ow, cd.conv2_kh, cd.conv2_kw, cd.conv2_strh, cd.conv2_strw,
443                                          static_cast<const float*>(dw_conv_weights.get().get_data_handle()),
444                                          static_cast<const float*>(dw_conv_bias.get().get_data_handle()));
445
446         if (p.with_sum)
447             bin_conv_post_ops.append_sum();
448
449         if (p.eltwise_algorithm != algorithm_undef)
450             bin_conv_post_ops.append_eltwise(1.0, p.eltwise_algorithm, p.eltwise_alpha, p.eltwise_beta);
451
452         auto dw_conv_depthwise_weights_desc = create_md({ cd.conv2_oc }, data_type_bin_conv_bia, memory::x);
453         auto dw_conv_depthwise_bias_desc = create_md({ cd.conv2_oc }, data_type_bin_conv_bia, memory::x);
454         auto dw_conv_depthwise_weights = memory({dw_conv_depthwise_weights_desc, eng});
455         auto dw_conv_depthwise_bias = memory({dw_conv_depthwise_bias_desc, eng});
456
457         if (p.depthwise_algorithm != algorithm_undef) {
458             fill_data<float>(dw_conv_depthwise_weights.get_primitive_desc().get_size() / sizeof(float),
459                              (float *)dw_conv_depthwise_weights.get_data_handle(), 1., true);
460             fill_data<float>(dw_conv_depthwise_bias.get_primitive_desc().get_size() / sizeof(float),
461                              (float *)dw_conv_depthwise_bias.get_data_handle(), 1., true);
462
463             bin_conv_post_ops.append_depthwise(p.depthwise_algorithm, static_cast<const float*>(dw_conv_depthwise_weights.get_data_handle()),
464                                  static_cast<const float*>(dw_conv_depthwise_bias.get_data_handle()));
465         }
466
467         auto dw_conv_binarization_weights_desc = create_md({ cd.conv2_oc }, memory::data_type::f32, memory::x);
468         auto dw_conv_binarization_weights = memory({dw_conv_binarization_weights_desc, eng});
469
470         if (p.binarization_algorithm != algorithm_undef) {
471             fill_data<float>(dw_conv_binarization_weights.get_primitive_desc().get_size() / sizeof(float),
472                              (float *)dw_conv_binarization_weights.get_data_handle(), 0.f, p.sizes.conv2_oc * p.sizes.conv2_kh * p.sizes.conv2_kw);
473
474             bin_conv_post_ops.append_binarization(p.binarization_algorithm, static_cast<const float*>(dw_conv_binarization_weights.get_data_handle()));
475         }
476
477         mkldnn::primitive_attr bin_conv_attr;
478         bin_conv_attr.set_post_ops(bin_conv_post_ops);
479
480         auto bin_conv_primitive_desc = binary_convolution_forward::primitive_desc(bin_conv_desc, bin_conv_attr, eng);
481
482         auto bin_conv = binary_convolution_forward(bin_conv_primitive_desc, bin_conv_src.get(), bin_conv_weights.get(), dw_conv_dst.get());
483
484         auto bin_conv_dst_desc_ref = create_md({ cd.mb, cd.conv1_oc, bin_conv_oh, bin_conv_ow }, data_type_bin_conv_dst, p.formats.dst_format);
485         auto ref_bin_conv_dst = test_memory(bin_conv_dst_desc_ref, eng);
486         compute_ref_bin_conv_fwd(p, bin_conv_src_desc, bin_conv_weights_desc, bin_conv_dst_desc_ref,
487                                  bin_conv_src.get(), bin_conv_weights.get(), ref_bin_conv_dst.get(),
488                                  bin_conv_depthwise_weights, bin_conv_depthwise_bias);
489
490         if (with_binarization) {
491             auto ref_dw_conv_dst_desc = create_md({ cd.mb, cd.conv2_oc, dw_conv_oh, dw_conv_ow }, memory::data_type::f32, p.formats.dst_format);
492             auto ref_dw_conv_dst = test_memory(ref_dw_conv_dst_desc, eng);
493
494             compute_ref_dw_conv_fwd(p, ref_bin_conv_dst.get(), dw_conv_weights.get(), dw_conv_bias.get(),
495                                     ref_dw_conv_dst.get(),
496                                     dw_conv_depthwise_weights, dw_conv_depthwise_bias);
497
498             auto ref_binarization_dst = test_memory(dw_conv_dst_desc, eng);
499
500             compute_ref_binarization_fwd(p, ref_dw_conv_dst_desc, ref_dw_conv_dst.get(), dw_conv_binarization_weights, ref_binarization_dst.get());
501
502             std::vector<primitive> pipeline;
503             pipeline.push_back(bin_conv);
504             auto s = stream(stream::kind::lazy);
505             s.submit(pipeline).wait();
506
507             compare_data<uint8_t>(ref_binarization_dst.get(), dw_conv_dst.get(), 0, true);
508         } else {
509             auto ref_dw_conv_dst = test_memory(dw_conv_dst_desc, eng);
510             memcpy((float *) ref_dw_conv_dst.get().get_data_handle(), (float *) dw_conv_dst.get().get_data_handle(),
511                    ref_dw_conv_dst.get_size());
512             compute_ref_dw_conv_fwd(p, ref_bin_conv_dst.get(), dw_conv_weights.get(), dw_conv_bias.get(),
513                                     ref_dw_conv_dst.get(),
514                                     dw_conv_depthwise_weights, dw_conv_depthwise_bias);
515
516             std::vector<primitive> pipeline;
517             pipeline.push_back(bin_conv);
518             auto s = stream(stream::kind::lazy);
519             s.submit(pipeline).wait();
520
521             compare_data<float>(ref_dw_conv_dst.get(), dw_conv_dst.get(), 1e-3);
522         }
523     }
524 };
525
526 }
527
528 #endif