Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / gtests / test_binarization.cpp
1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <mkldnn_types.h>
18 #include "gtest/gtest.h"
19 #include "mkldnn_test_common.hpp"
20 #include "mkldnn.hpp"
21
22 namespace mkldnn {
23
24 template <typename data_t>
25 struct binarization_test_params {
26     engine::kind engine_kind;
27     algorithm alg_kind;
28     memory::format data_format;
29     memory::dims dims;
30 };
31
32 template <typename src_data_t>
33 void check_binarization_fwd(const binarization_test_params<src_data_t> &p,
34         const memory::desc &src_md, const memory &src, const memory &weights, const memory &dst) {
35     auto src_data = (src_data_t*)src.get_data_handle();
36     auto weights_data = (src_data_t*)weights.get_data_handle();
37     auto dst_data = (uint8_t*)dst.get_data_handle();
38
39     const memory::desc src_d = src.get_primitive_desc().desc();
40     const memory::desc weights_d = weights.get_primitive_desc().desc();
41     const memory::desc dst_d = dst.get_primitive_desc().desc();
42
43     int N = src_md.data.ndims > 0 ? src_md.data.dims[0] : 1;
44     int C = src_md.data.ndims > 1 ? src_md.data.dims[1] : 1;
45     int H = src_md.data.ndims > 2 ? src_md.data.dims[2] : 1;
46     int W = src_md.data.ndims > 3 ? src_md.data.dims[3] : 1;
47
48     int nbits = 8;
49     int CB = div_up(C, nbits);
50
51     int padded_ic = src_d.data.layout_desc.blocking.padding_dims[1];
52     int padded_oc = dst_d.data.layout_desc.blocking.padding_dims[1];
53
54     for (int n = 0; n < N; ++n) {
55         for (int cb = 0; cb < CB; ++cb) {
56             for (int h = 0; h < H; ++h) {
57                 for (int w = 0; w < W; ++w) {
58
59                     uint8_t bin_val = 0x00;
60                     for (int c = cb * nbits, shift = 0; c < std::min(C, (cb + 1) * nbits); c++, shift++) {
61                         int src_idx = n*padded_ic*H*W + c*H*W + h*W + w;
62                         int wei_idx = c;
63
64                         src_data_t s_val = src_data[map_index(src_d, src_idx)];
65                         src_data_t w_val = weights_data[map_index(weights_d, wei_idx)];
66
67                         auto bit = uint8_t((s_val > w_val) ? 0x01 : 0x00);
68                         bin_val |= (bit << shift);
69                     }
70
71                     int dst_idx = n*padded_oc*H*W + cb*nbits*H*W + h*W + w;
72                     dst_idx = map_index(dst_d, dst_idx);
73
74                     EXPECT_EQ(dst_data[dst_idx / nbits], bin_val);
75                 }
76             }
77         }
78     }
79 }
80
81 template <typename src_data_t>
82 class binarization_test : public ::testing::TestWithParam<binarization_test_params<src_data_t>> {
83 private:
84
85 protected:
86     virtual void SetUp() {
87         auto p = ::testing::TestWithParam<binarization_test_params<src_data_t>>::GetParam();
88
89         auto eng = engine(p.engine_kind, 0);
90         auto src_data_type = data_traits<src_data_t>::data_type;
91
92         memory::dims src_dims = memory::dims({p.dims[0], p.dims[1], p.dims[2], p.dims[3]});
93         memory::dims wei_dims = memory::dims({src_dims[1]});
94         memory::dims dst_dims = memory::dims({p.dims[0], p.dims[1], p.dims[2], p.dims[3]});
95
96         auto src_desc = create_md(src_dims, src_data_type, p.data_format);
97         auto weights_desc = create_md(wei_dims, src_data_type, memory::format::x);
98         auto dst_desc = create_md(dst_dims, memory::data_type::bin, p.data_format);
99
100         auto src = test_memory(src_desc, eng);
101         auto weights = test_memory(weights_desc, eng);
102         auto dst = test_memory(dst_desc, eng);
103
104         fill_data<src_data_t>(src.get_size() / sizeof(src_data_t), (src_data_t *)src.get().get_data_handle(),
105                               src_data_t(0), src_data_t(1));
106         fill_data<src_data_t>(weights.get_size() / sizeof(src_data_t), (src_data_t *)weights.get().get_data_handle(),
107                               src_data_t(0), src_data_t(1));
108         fill_data<uint8_t>(dst.get_size() / sizeof(uint8_t), (uint8_t*)dst.get().get_data_handle());
109
110         std::vector<primitive> pipeline;
111         auto binarization_desc = binarization_forward::desc(prop_kind::forward_training, p.alg_kind, src_desc, weights_desc, dst_desc);
112         auto binarization_prim_desc = binarization_forward::primitive_desc(binarization_desc, eng);
113         auto binarization = binarization_forward(binarization_prim_desc, src.get(), weights.get(), dst.get());
114
115         pipeline.push_back(binarization);
116         auto s = stream(stream::kind::lazy);
117         s.submit(pipeline).wait();
118
119         check_binarization_fwd(p, src_desc, src.get(), weights.get(), dst.get());
120     }
121 };
122
123 using binarization_test_float = binarization_test<float>;
124 using binarization_test_params_float = binarization_test_params<float>;
125
126 TEST_P(binarization_test_float, TestsBinarization)
127 {
128 }
129
130 #define EXPAND(args) args
131
132 #define EXPAND_FORMATS(data) memory::format::data
133
134 #define ENGINE engine::kind::cpu
135
136 #define PARAMS(alg, data, mb, c, h, w) \
137     binarization_test_params_float { ENGINE, algorithm::alg, \
138     EXPAND_FORMATS(data), {mb, c, h, w} }
139
140 #define PARAMS_ALL_ALG(...) \
141     EXPAND(PARAMS(binarization_depthwise, __VA_ARGS__))
142
143 #define INST_TEST_CASE(str, ...) INSTANTIATE_TEST_CASE_P( \
144         str, binarization_test_float, ::testing::Values(__VA_ARGS__))
145
146 INST_TEST_CASE(Simple_NHWC,
147     PARAMS_ALL_ALG(nhwc, 2, 8, 4, 4),
148     PARAMS_ALL_ALG(nhwc, 2, 16, 4, 4),
149     PARAMS_ALL_ALG(nhwc, 2, 16, 8, 8),
150     PARAMS_ALL_ALG(nhwc, 2, 16, 16, 8),
151     PARAMS_ALL_ALG(nhwc, 2, 16, 10, 8),
152     PARAMS_ALL_ALG(nhwc, 10, 10, 10, 10),
153     PARAMS_ALL_ALG(nhwc, 256, 64, 8, 16),
154     PARAMS_ALL_ALG(nhwc, 1, 1, 1, 1),
155     PARAMS_ALL_ALG(nhwc, 3, 5, 7, 11),
156     PARAMS_ALL_ALG(nhwc, 2, 129, 7, 4),
157     PARAMS_ALL_ALG(nhwc, 2, 333, 8, 3)
158 );
159
160 }