Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / gtests / mkldnn_test_common.hpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #ifndef MKLDNN_TEST_COMMON_HPP
18 #define MKLDNN_TEST_COMMON_HPP
19
20 #include <limits>
21 #include <numeric>
22 #include <vector>
23 #include <cmath>
24 #include <stdint.h>
25
26 #include "gtest/gtest.h"
27
28 #if defined(_MSC_VER) && !defined(__clang__) && !defined(__INTEL_COMPILER)
29 #define collapse(x)
30 #endif
31
32 #include "mkldnn.hpp"
33
34 #include "src/common/mkldnn_thread.hpp"
35
36 template <typename data_t> struct data_traits { };
37 template <> struct data_traits<float> {
38     static const auto data_type = mkldnn::memory::data_type::f32;
39 };
40 template <> struct data_traits<uint8_t> {
41     static const auto data_type = mkldnn::memory::data_type::u8;
42 };
43 template <> struct data_traits<int8_t> {
44     static const auto data_type = mkldnn::memory::data_type::s8;
45 };
46 template <> struct data_traits<int16_t> {
47     static const auto data_type = mkldnn::memory::data_type::s16;
48 };
49 template <> struct data_traits<int32_t> {
50     static const auto data_type = mkldnn::memory::data_type::s32;
51 };
52
53 template <typename T> inline void assert_eq(T a, T b);
54 template <> inline void assert_eq<float>(float a, float b) {
55     ASSERT_FLOAT_EQ(a, b);
56 }
57
58 template <typename data_t> inline data_t out_round(float x,
59         mkldnn_round_mode_t rmode = mkldnn_round_nearest)
60 { return (data_t)(rmode == mkldnn_round_down ? floorf(x) : nearbyintf(x)); }
61 template <> inline float out_round<float>(float x, mkldnn_round_mode_t rmode)
62 { (void)rmode; return x; }
63
64 template <typename data_t, typename out_t>
65 out_t saturate(const out_t &x) {
66     out_t v = x;
67     if (v <= std::numeric_limits<data_t>::min())
68         v = std::numeric_limits<data_t>::min();
69     if (v > std::numeric_limits<data_t>::max())
70         v = std::numeric_limits<data_t>::max();
71     return v;
72 }
73
74 inline int right_padding(int i, int o, int k, int p, int s, int d = 0) {
75     return (o - 1) * s + (k - 1) * (d + 1) - (p + i - 1);
76 }
77
78 template <typename data_t> struct acc_t { typedef data_t type; };
79 template<> struct acc_t<int8_t> { typedef int type; };
80 template<> struct acc_t<uint8_t> { typedef int type; };
81
82 inline size_t map_index(const mkldnn::memory::desc &md, size_t index,
83     bool with_padding = true) {
84     using fmt = mkldnn::memory::format;
85
86     const fmt fwd_weights_g_qvnni = fmt::gOIhw8i16o2i;
87     const fmt fwd_weights_qvnni = fmt::OIhw8i16o2i;
88     const fmt bwd_weights_g_qvnni = fmt::gOIhw8o16i2o;
89     const fmt bwd_weights_qvnni = fmt::OIhw8o16i2o;
90
91     const fmt fwd_weights_g_vnni = fmt::gOIhw4i16o4i;
92     const fmt fwd_weights_vnni = fmt::OIhw4i16o4i;
93
94     const bool with_groups = (md.data.format == fwd_weights_g_qvnni)
95                           || (md.data.format == bwd_weights_g_qvnni)
96                           || (md.data.format == fwd_weights_g_vnni);
97
98     const bool qvnni = (md.data.format == fwd_weights_g_qvnni)
99                     || (md.data.format == bwd_weights_g_qvnni)
100                     || (md.data.format == fwd_weights_qvnni)
101                     || (md.data.format == bwd_weights_qvnni);
102
103     const bool vnni = (md.data.format == fwd_weights_g_vnni)
104                    || (md.data.format == fwd_weights_vnni);
105
106     const bool fwd_wei = (md.data.format == fwd_weights_g_qvnni)
107                       || (md.data.format == fwd_weights_qvnni)
108                       || (md.data.format == fwd_weights_g_vnni)
109                       || (md.data.format == fwd_weights_vnni);
110
111     const bool bwd_wei = (md.data.format == bwd_weights_g_qvnni)
112                       || (md.data.format == bwd_weights_qvnni);
113
114     const int ndims = md.data.ndims;
115     const ptrdiff_t *dims = md.data.dims;
116     const ptrdiff_t *pdims = md.data.layout_desc.blocking.padding_dims;
117     const ptrdiff_t *optd = md.data.layout_desc.blocking.offset_padding_to_data;
118
119     auto *strides_block = md.data.layout_desc.blocking.strides[0];
120     auto *strides_within_block = md.data.layout_desc.blocking.strides[1];
121
122     size_t ph_index = 0;
123     size_t oc_lb = 0, ic_sb = 0,
124         oc_sb = 0, ic_lb = 0;
125
126     for (int rd = 0; rd < ndims; ++rd) {
127         int d = ndims - rd - 1;
128
129         EXPECT_LE(dims[d], pdims[d]);
130
131         int cur_dim = with_padding ? pdims[d] : dims[d];
132         EXPECT_GT(cur_dim, 0);
133         int cur_block = md.data.layout_desc.blocking.block_dims[d];
134
135         size_t pos_d = /*static_cast<ssize_t>*/(index % cur_dim);
136         EXPECT_GE(optd[d], 0);
137         size_t cur_pos = optd[d] + pos_d;
138
139         size_t cur_pos_block = cur_pos / cur_block;
140         size_t cur_pos_within_block = cur_pos % cur_block;
141
142         if (d == (with_groups + 0)) {
143             if (qvnni) { oc_lb = pos_d % 16;  oc_sb = pos_d % 2; }
144             else  if (vnni) { oc_lb = pos_d % 16; }
145         }
146         if (d == (with_groups + 1)) {
147             if (qvnni) { ic_sb = pos_d % 2; ic_lb = pos_d % 16; }
148             else if (vnni) { ic_sb = pos_d % 4; }
149         }
150         ph_index += cur_pos_block*strides_block[d];
151         ph_index += cur_pos_within_block*strides_within_block[d];
152
153         index /= cur_dim;
154     }
155     int scale = (vnni) ? 3 : 1;
156     if (fwd_wei) {
157         //ph_index += -16 * ic_2 + oc_16 + ic_2;
158         ph_index += scale * oc_lb + ic_sb;
159         EXPECT_GE(ph_index, 16 * ic_sb);
160         ph_index -= 16 * ic_sb;
161     } else
162         if (bwd_wei) {
163             //ph_index += -16 * oc_2 + ic_16 + oc_2;
164             ph_index += ic_lb + oc_sb;
165             EXPECT_GE(ph_index, 16 * oc_sb);
166             ph_index -= 16 * oc_sb;
167         }
168     ph_index += md.data.layout_desc.blocking.offset_padding;
169
170     return ph_index;
171 }
172
173 #define MAX_NDIMS 12
174 // check_zero_tail - check on zero or set to zero padded memory
175 template <typename data_t>
176 void check_zero_tail(int set_zero_flag, mkldnn::memory &src) {
177
178     data_t *src_data = (data_t *)src.get_data_handle();
179
180     const mkldnn::memory::desc src_d = src.get_primitive_desc().desc();
181     const int ndims = src_d.data.ndims;
182     const ptrdiff_t *dims = src_d.data.dims;
183     const ptrdiff_t *pdims = src_d.data.layout_desc.blocking.padding_dims;
184
185     size_t idx[MAX_NDIMS] = {}, str[MAX_NDIMS] = {};
186     size_t nelems = 1;
187     int tail_flag = 0;
188     for (int i = 0; i < ndims; ++i) {
189         if (dims[ndims-i-1] != pdims[ndims-i-1]) tail_flag = 1;
190         nelems *= pdims[ndims-i-1];
191         idx[i] = 0;
192         str[i] = (i==0) ? 1 : str[i-1] * pdims[ndims-i];
193     }
194     if (tail_flag == 0) return;
195
196     for (size_t i = 0; i < nelems; ++i) {
197         size_t off = 0;
198         bool flag = 0;
199         for (int j = 0; j < ndims; ++j) {
200             off += idx[j] * str[j];
201             if (idx[j] >= (size_t)dims[ndims-j-1]) flag = 1;
202         }
203         if (flag == 1) {
204             size_t blk_off = map_index(src_d,off);
205             if (set_zero_flag) {
206                 src_data[blk_off] = 0.0;
207             } else {
208                 EXPECT_EQ(src_data[blk_off], 0.0) << " blk_off = " << blk_off
209                 << "off = " << off;
210             }
211         }
212         /*Update idx*/
213         for (int j = 0; j < ndims; ++j) {
214             idx[j] ++;
215             if (idx[j] < (size_t)pdims[ndims-j-1]) break;
216             idx[j] = 0;
217         }
218     }
219 }
220
221 inline mkldnn::memory::desc create_md(mkldnn::memory::dims dims,
222         mkldnn::memory::data_type data_type, mkldnn::memory::format fmt) {
223     using f = mkldnn::memory::format;
224     size_t ndims = 0;
225
226     switch (fmt) {
227     case f::x:
228         ndims = 1; break;
229     case f::nc:
230     case f::oi:
231     case f::io:
232         ndims = 2; break;
233     case f::nchw:
234     case f::nhwc:
235     case f::chwn:
236     case f::nChw8c:
237     case f::nChw16c:
238     case f::oihw:
239     case f::hwio:
240     case f::iohw:
241     case f::oIhw8i:
242     case f::oIhw16i:
243     case f::OIhw8i8o:
244     case f::OIhw16i16o:
245     case f::OIhw8i16o2i:
246     case f::OIhw8o16i2o:
247     case f::OIhw4i16o4i:
248     case f::OIhw8o8i:
249     case f::OIhw16o16i:
250     case f::IOhw16o16i:
251     case f::Ohwi8o:
252     case f::Ohwi16o:
253     case f::OhIw8o4i:
254     case f::OIhw4i16o4i_s8s8:
255     case f::OhIw8o4i_s8s8:
256     case f::OhIw8o32i:
257     case f::OhIw16o32i:
258         ndims = 4; break;
259     case f::ncdhw:
260     case f::ndhwc:
261     case f::nCdhw8c:
262     case f::nCdhw16c:
263     case f::dhwio:
264     case f::oidhw:
265     case f::goihw:
266     case f::hwigo:
267     case f::giohw:
268     case f::oIdhw8i:
269     case f::oIdhw16i:
270     case f::OIdhw8i8o:
271     case f::OIdhw16i16o:
272     case f::OIdhw8o8i:
273     case f::OIdhw16o16i:
274     case f::gOhwi8o:
275     case f::Goihw8g:
276     case f::Goihw16g:
277     case f::gOhwi16o:
278     case f::gOIhw8i8o:
279     case f::gOIhw16i16o:
280     case f::gOIhw8i16o2i:
281     case f::gOIhw8o16i2o:
282     case f::gOIhw4i16o4i:
283     case f::gOIhw8o8i:
284     case f::gOIhw16o16i:
285     case f::gIOhw16o16i:
286     case f::gOhIw8o4i:
287     case f::Goihw16g_s8s8:
288         ndims = 5; break;
289     case f::gOIdhw8i8o:
290     case f::gOIdhw16i16o:
291     case f::gOIdhw8o8i:
292     case f::gOIdhw16o16i:
293     case f::gOdhwi16o:
294     case f::goidhw:
295         ndims = 6; break;
296     case f::format_undef:
297         ndims = 0; break;
298     case f::any:
299         return mkldnn::memory::desc(dims, data_type, fmt);
300     default: EXPECT_TRUE(false) << "test does not support format: " << int(fmt);
301     }
302
303     EXPECT_EQ(dims.size(), ndims) << "dims and format are inconsistent";
304
305     return mkldnn::memory::desc(dims, data_type, fmt);
306 }
307
308 template <typename data_t>
309 static inline data_t set_value(size_t index, data_t mean, data_t deviation,
310         double sparsity)
311 {
312     if (data_traits<data_t>::data_type == mkldnn::memory::data_type::f32) {
313         const size_t group_size = (size_t)(1. / sparsity);
314         const size_t group = index / group_size;
315         const size_t in_group = index % group_size;
316         const bool fill = in_group == ((group % 1637) % group_size);
317         return fill ? static_cast<data_t>(mean + deviation * sinf(float(index % 37)))
318             : data_t{0};
319     } else if (data_traits<data_t>::data_type == mkldnn::memory::data_type::s32
320         || data_traits<data_t>::data_type == mkldnn::memory::data_type::s16
321         || data_traits<data_t>::data_type == mkldnn::memory::data_type::s8) {
322         return data_t(rand() % 21 - 10);
323     } else if (data_traits<data_t>::data_type == mkldnn::memory::data_type::u8) {
324         return data_t(rand() % 17);
325     } else {
326         return data_t(0);
327     }
328 }
329
330 template <typename data_t>
331 static void fill_data(const size_t size, data_t *data, data_t mean,
332         data_t deviation, double sparsity = 1.)
333 {
334     mkldnn::impl::parallel_nd((ptrdiff_t)size, [&](ptrdiff_t n) {
335             data[n] = set_value<data_t>(n, mean, deviation, sparsity);
336     });
337 }
338
339 template <typename data_t>
340 static void fill_data(const size_t size, data_t *data, double sparsity = 1.,
341         bool init_negs = false)
342 {
343     mkldnn::impl::parallel_nd((ptrdiff_t)size, [&](ptrdiff_t n) {
344         data[n] = set_value<data_t>(n, data_t(1), data_t(2e-1), sparsity);
345
346         if (init_negs && n%4 == 0U)
347             data[n] = static_cast<data_t>(-data[n]); // weird for unsigned types!
348     });
349 }
350
351 int div_up(const int a, const int b) {
352     return (a + b - 1) / b;
353 }
354
355 template <typename data_t>
356 static void compare_data(mkldnn::memory& ref, mkldnn::memory& dst,
357         data_t threshold = (data_t)1e-4, bool isBinary = false)
358 {
359     using data_type = mkldnn::memory::data_type;
360
361     ASSERT_TRUE(data_traits<data_t>::data_type == data_type::f32 ||
362                 data_traits<data_t>::data_type == data_type::s32 ||
363                 data_traits<data_t>::data_type == data_type::u8);
364
365     /* Note: size_t incompatible with MSVC++ */
366     auto ref_desc = ref.get_primitive_desc().desc();
367     auto dst_desc = dst.get_primitive_desc().desc();
368
369     ASSERT_TRUE(ref_desc.data.ndims == dst_desc.data.ndims);
370
371     auto ndims = ref_desc.data.ndims;
372
373     for (auto d = 0; d < ndims; ++d) {
374         ASSERT_TRUE(ref_desc.data.dims[d] == dst_desc.data.dims[d]);
375     }
376
377     auto dims = ref_desc.data.dims;
378
379     ptrdiff_t num = 1;
380     for (auto d = 0; d < ndims; ++d) {
381         if (isBinary && d == 1) {
382             num *= div_up(dims[d], 8);
383         } else {
384             num *= dims[d];
385         }
386     }
387
388     data_t *ref_data = (data_t *)ref.get_data_handle();
389     data_t *dst_data = (data_t *)dst.get_data_handle();
390
391     mkldnn::impl::parallel_nd(num, [&](ptrdiff_t i) {
392         int divider = isBinary ? 8 : 1;
393
394         data_t ref = ref_data[map_index(ref_desc, i) / divider];
395         data_t got = dst_data[map_index(dst_desc, i) / divider];
396
397         if (data_traits<data_t>::data_type == data_type::f32) {
398             data_t diff = got - ref;
399             data_t e = (std::abs(ref) > threshold) ? diff / ref : diff;
400             EXPECT_NEAR(e, (data_t) 0.0, threshold)
401                                 << "Index: " << i << " Total: " << num;
402         } else {
403             EXPECT_EQ(ref, got) << "Index: " << i << " Total: " << num;
404         }
405     });
406 }
407
408 inline const char *query_impl_info(const_mkldnn_primitive_desc_t pd) {
409     const char *str;
410     mkldnn_primitive_desc_query(pd, mkldnn_query_impl_info_str, 0, &str);
411     return str;
412 };
413
414 mkldnn_status_t get_conv_impl_status(const_mkldnn_primitive_desc_t pd, const char *match_str){
415     const char* conv_str = query_impl_info(pd);
416
417     if( strstr(conv_str, match_str) != NULL)
418         return mkldnn_status_t::mkldnn_success;
419     return mkldnn_status_t::mkldnn_unimplemented;
420 };
421
422 struct test_convolution_sizes_t {
423     test_convolution_sizes_t(
424         int mb,
425         int ng,
426         int ic, int ih, int iw,
427         int oc, int oh, int ow,
428         int kh, int kw,
429         int padh, int padw,
430         int strh, int strw,
431         int dilh=0, int dilw=0
432     ) :
433         mb(mb),
434         ng(ng),
435         ic(ic), ih(ih), iw(iw),
436         oc(oc), oh(oh), ow(ow),
437         kh(kh), kw(kw),
438         padh(padh), padw(padw),
439         strh(strh), strw(strw),
440         dilh(dilh), dilw(dilw) {}
441     int mb;
442     int ng;
443     int ic, ih, iw;
444     int oc, oh, ow;
445     int kh, kw;
446     int padh, padw;
447     int strh, strw;
448     int dilh, dilw;
449 };
450
451 struct test_convolution_sizes_t_3d {
452     test_convolution_sizes_t_3d(
453         int mb,
454         int ng,
455         int ic, int id, int ih, int iw,
456         int oc, int od, int oh, int ow,
457         int kd, int kh, int kw,
458         int padd, int padh, int padw,
459         int strd, int strh, int strw,
460         int dild=0, int dilh=0, int dilw=0
461     ) :
462         mb(mb),
463         ng(ng),
464         ic(ic), id(id), ih(ih), iw(iw),
465         oc(oc), od(od), oh(oh), ow(ow),
466         kd(kd), kh(kh), kw(kw),
467         padd(padd), padh(padh), padw(padw),
468         strd(strd), strh(strh), strw(strw),
469         dild(dild), dilh(dilh), dilw(dilw) {}
470     int mb;
471     int ng;
472     int ic, id, ih, iw;
473     int oc, od, oh, ow;
474     int kd, kh, kw;
475     int padd, padh, padw;
476     int strd, strh, strw;
477     int dild, dilh, dilw;
478 };
479
480 struct test_convolution_attr_t {
481     struct scale_t {
482         enum policy_t { NONE = 0, COMMON };
483
484         bool is_def() const { return policy != NONE; }
485
486         scale_t (float s, policy_t p = NONE) :
487             scale(s) { policy = p; }
488
489         policy_t policy;
490         float scale;
491     };
492
493     void mkldnn_attr_recreate() {
494         mkl_attr = mkldnn::primitive_attr();
495         mkl_attr.set_int_output_round_mode(rmode);
496         if (oscale.is_def()) {
497             const int count = 1;
498             const int mask = 0;
499             std::vector<float> s(count, oscale.scale);
500             mkl_attr.set_output_scales(mask, s);
501         }
502     }
503
504     test_convolution_attr_t(mkldnn::round_mode rm, float s,
505         scale_t::policy_t p = scale_t::policy_t::NONE) :
506             rmode(rm), oscale(s, p), mkl_attr() {}
507
508     test_convolution_attr_t() :
509         rmode(mkldnn::round_mode::round_nearest),
510         oscale(1.0), mkl_attr() {}
511
512     mkldnn::round_mode rmode;
513     scale_t oscale;
514     mkldnn::primitive_attr mkl_attr;
515 };
516
517 struct test_convolution_formats_t {
518     mkldnn::memory::format src_format;
519     mkldnn::memory::format weights_format;
520     mkldnn::memory::format bias_format;
521     mkldnn::memory::format dst_format;
522 };
523
524 struct test_convolution_params_t {
525     const mkldnn::engine::kind engine_kind;
526     mkldnn::algorithm aalgorithm;
527     test_convolution_formats_t formats;
528     test_convolution_attr_t attr;
529     test_convolution_sizes_t sizes;
530     bool expect_to_fail;
531     mkldnn_status_t expected_status;
532 };
533
534 struct test_convolution_params_t_3d {
535     const mkldnn::engine::kind engine_kind;
536     mkldnn::algorithm aalgorithm;
537     test_convolution_formats_t formats;
538     test_convolution_attr_t attr;
539     test_convolution_sizes_t_3d sizes;
540     bool expect_to_fail;
541     mkldnn_status_t expected_status;
542 };
543
544 struct test_convolution_eltwise_params_t {
545     const mkldnn::algorithm alg;
546     const mkldnn::engine::kind engine_kind;
547     mkldnn::algorithm aalgorithm;
548     const float eltwise_alpha;
549     const float eltwise_beta;
550     test_convolution_formats_t formats;
551     test_convolution_attr_t attr;
552     test_convolution_sizes_t sizes;
553     bool expect_to_fail;
554     mkldnn_status_t expected_status;
555 };
556
557 struct test_convolution_depthwise_params_t {
558     const mkldnn::algorithm alg;
559     const mkldnn::engine::kind engine_kind;
560     mkldnn::algorithm aalgorithm;
561     test_convolution_formats_t formats;
562     test_convolution_attr_t attr;
563     test_convolution_sizes_t sizes;
564     bool expect_to_fail;
565     mkldnn_status_t expected_status;
566 };
567
568 struct test_convolution_dw_conv_sizes_t {
569     test_convolution_dw_conv_sizes_t(
570             int mb, int ic, int ih, int iw,
571             int conv1_oc,
572             int conv1_kh, int conv1_kw,
573             int conv1_padh, int conv1_padw,
574             int conv1_strh, int conv1_strw,
575             int conv2_oc,
576             int conv2_kh, int conv2_kw,
577             int conv2_padh, int conv2_padw,
578             int conv2_strh, int conv2_strw
579     ) :
580             mb(mb), ic(ic), ih(ih), iw(iw),
581             conv1_oc(conv1_oc),
582             conv1_kh(conv1_kh), conv1_kw(conv1_kw),
583             conv1_padh(conv1_padh), conv1_padw(conv1_padw),
584             conv1_strh(conv1_strh), conv1_strw(conv1_strw),
585             conv2_oc(conv2_oc),
586             conv2_kh(conv2_kh), conv2_kw(conv2_kw),
587             conv2_padh(conv2_padh), conv2_padw(conv2_padw),
588             conv2_strh(conv2_strh), conv2_strw(conv2_strw) {}
589     int mb, ic, ih, iw;
590     int conv1_oc;
591     int conv1_kh,   conv1_kw;
592     int conv1_padh, conv1_padw;
593     int conv1_strh, conv1_strw;
594     int conv2_oc;
595     int conv2_kh,   conv2_kw;
596     int conv2_padh, conv2_padw;
597     int conv2_strh, conv2_strw;
598 };
599
600 struct test_convolution_dw_conv_formats_t {
601     mkldnn::memory::format src_format;
602     mkldnn::memory::format conv1_weights_format;
603     mkldnn::memory::format conv1_bias_format;
604     mkldnn::memory::format conv2_weights_format;
605     mkldnn::memory::format conv2_bias_format;
606     mkldnn::memory::format dst_format;
607 };
608
609 struct test_convolution_dw_conv_params_t {
610     const mkldnn::engine::kind engine_kind;
611     mkldnn::algorithm aalgorithm;
612     test_convolution_dw_conv_formats_t formats;
613     test_convolution_dw_conv_sizes_t sizes;
614 };
615
616 struct test_roi_pool_desc_t {
617     struct {
618         int mb, c;
619         int h, w;
620     } data;
621
622     struct {
623         int mb, c;
624         int h, w;
625     } roi;
626
627     int pooled_h, pooled_w;
628     double spatial_scale;
629 };
630
631 struct roi_pool_test_params {
632     mkldnn::prop_kind aprop_kind;
633     mkldnn::algorithm algorithm_kind;
634     const mkldnn::engine::kind engine_kind;
635     mkldnn::memory::format data_format;
636     mkldnn::memory::format roi_format;
637     mkldnn::memory::format dst_format;
638     test_roi_pool_desc_t test_pd;
639 };
640
641 struct test_binary_convolution_params_t {
642     const mkldnn::engine::kind engine_kind;
643     mkldnn::algorithm aalgorithm;
644     float pad_value;
645     mkldnn::algorithm eltwise_algorithm;
646     const float eltwise_alpha;
647     const float eltwise_beta;
648     mkldnn::algorithm depthwise_algorithm;
649     bool with_sum;
650     mkldnn::algorithm binarization_algorithm;
651     test_convolution_formats_t formats;
652     test_convolution_sizes_t sizes;
653 };
654
655 struct test_binary_convolution_dw_conv_params_t {
656     const mkldnn::engine::kind engine_kind;
657     mkldnn::algorithm aalgorithm;
658     mkldnn::algorithm eltwise_algorithm;
659     const float eltwise_alpha;
660     const float eltwise_beta;
661     mkldnn::algorithm depthwise_algorithm;
662     bool with_sum;
663     mkldnn::algorithm binarization_algorithm;
664     test_convolution_dw_conv_formats_t formats;
665     test_convolution_dw_conv_sizes_t sizes;
666 };
667
668 std::ostream &operator<<(std::ostream &stream,
669                          const roi_pool_test_params &tp)
670 {
671     return stream << "(" << "input_data:" << " mb = " << tp.test_pd.data.mb  << ", c = " << tp.test_pd.data.c
672                   << ", h = " << tp.test_pd.data.h   << ", w = " << tp.test_pd.data.w
673                   << ", rois_num: " << tp.test_pd.roi.mb
674                   << ", pooled_h: " << tp.test_pd.pooled_h
675                   << ", pooled_w: " << tp.test_pd.pooled_w
676                   << ", spatial_scale: " << tp.test_pd.spatial_scale
677                   << ")";
678 }
679
680 template<typename F> bool catch_expected_failures(const F &f,
681         bool expect_to_fail, mkldnn_status_t expected_status, bool ignore_unimplemented = true)
682 {
683     try {
684         f();
685     } catch (const mkldnn::error &e) {
686         // Rethrow the exception if it is not expected or the error status did
687         // not match.
688         if (!(expect_to_fail) || e.status != (expected_status)) {
689             // Ignore unimplemented
690             if ( ignore_unimplemented && (e.status == mkldnn_unimplemented))
691                 return true;
692             else
693                 throw e;
694         }
695         // Return normally if the failure is expected
696         if (expect_to_fail)
697             return true;
698     }
699
700     // Throw an exception if the failure is expected but did not happen
701     if (expect_to_fail)
702         throw std::exception();
703
704     return false;
705 }
706
707 #define TEST_MALLOC_OFFSET 8
708 char *test_malloc(size_t size) {
709     void *ptr;
710     const size_t align = 64;
711     const size_t padded_size = TEST_MALLOC_OFFSET + size;
712 #ifdef _WIN32
713     ptr = _aligned_malloc(padded_size, align);
714     int rc = ((ptr) ? 0 : errno);
715 #else
716     int rc = ::posix_memalign(&ptr, align, padded_size);
717 #endif /* _WIN32 */
718     return rc == 0 ? (char*)ptr + TEST_MALLOC_OFFSET: 0;
719 }
720
721 void test_free(char *ptr) {
722     char *base_ptr = ptr - TEST_MALLOC_OFFSET;
723 #ifdef _WIN32
724     _aligned_free(base_ptr);
725 #else
726     return ::free(base_ptr);
727 #endif /* _WIN32 */
728 }
729 #undef TEST_MALLOC_OFFSET
730
731 class test_memory {
732 public:
733     test_memory(const mkldnn::memory::desc &d, const mkldnn::engine &e) {
734         auto pd = mkldnn::memory::primitive_desc(d, e);
735         pd_size_ = pd.get_size();
736         data_.reset(test_malloc(pd_size_), test_free);
737         mem_.reset(new mkldnn::memory(pd, data_.get()));
738     }
739     size_t get_size() const { return pd_size_; }
740     mkldnn::memory &get() { return *mem_; }
741
742 private:
743     std::shared_ptr<mkldnn::memory> mem_;
744     std::shared_ptr<char> data_;
745     size_t pd_size_;
746 };
747
748 #endif