Publishing 2019 R1.1 content and Myriad plugin sources (#162)
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / gtests / test_binarization.cpp
1 /*******************************************************************************
2 * Copyright 2019 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <mkldnn_types.h>
18 #include "gtest/gtest.h"
19 #include "mkldnn_test_common.hpp"
20 #include "mkldnn.hpp"
21
22 namespace mkldnn {
23
24 template <typename data_t>
25 struct binarization_test_params {
26     engine::kind engine_kind;
27     algorithm alg_kind;
28     memory::format data_format;
29     memory::dims dims;
30 };
31
32 template <typename src_data_t>
33 void check_binarization_fwd(const binarization_test_params<src_data_t> &p,
34         const memory::desc &src_md, const memory &src, const memory &weights,
35         const memory &output_low, const memory &output_high, const memory &dst) {
36     auto src_data = (src_data_t*)src.get_data_handle();
37     auto weights_data = (src_data_t*)weights.get_data_handle();
38     auto output_low_data = (float*)output_low.get_data_handle();
39     auto output_high_data = (float*)output_high.get_data_handle();
40     auto dst_data = (uint8_t*)dst.get_data_handle();
41
42     const memory::desc src_d = src.get_primitive_desc().desc();
43     const memory::desc weights_d = weights.get_primitive_desc().desc();
44     const memory::desc output_low_d = output_low.get_primitive_desc().desc();
45     const memory::desc output_high_d = output_high.get_primitive_desc().desc();
46     const memory::desc dst_d = dst.get_primitive_desc().desc();
47
48     int N = src_md.data.ndims > 0 ? src_md.data.dims[0] : 1;
49     int C = src_md.data.ndims > 1 ? src_md.data.dims[1] : 1;
50     int H = src_md.data.ndims > 2 ? src_md.data.dims[2] : 1;
51     int W = src_md.data.ndims > 3 ? src_md.data.dims[3] : 1;
52
53     int nbits = 8;
54     int CB = div_up(C, nbits);
55
56     int padded_ic = src_d.data.layout_desc.blocking.padding_dims[1];
57     int padded_oc = dst_d.data.layout_desc.blocking.padding_dims[1];
58
59     for (int n = 0; n < N; ++n) {
60         for (int cb = 0; cb < CB; ++cb) {
61             for (int h = 0; h < H; ++h) {
62                 for (int w = 0; w < W; ++w) {
63
64                     uint8_t bin_val = 0x00;
65                     for (int c = cb * nbits, shift = 0; c < std::min(C, (cb + 1) * nbits); c++, shift++) {
66                         int src_idx = n*padded_ic*H*W + c*H*W + h*W + w;
67                         int wei_idx = c;
68
69                         src_data_t s_val = src_data[map_index(src_d, src_idx)];
70                         src_data_t w_val = weights_data[map_index(weights_d, wei_idx)];
71                         src_data_t out_low = output_low_data[map_index(output_low_d, wei_idx)];
72                         src_data_t out_high = output_high_data[map_index(output_high_d, wei_idx)];
73
74                         auto bit = uint8_t((s_val > w_val) ? out_high : out_low);
75                         bin_val |= (bit << shift);
76                     }
77
78                     int dst_idx = n*padded_oc*H*W + cb*nbits*H*W + h*W + w;
79                     dst_idx = map_index(dst_d, dst_idx);
80
81                     EXPECT_EQ(dst_data[dst_idx / nbits], bin_val);
82                 }
83             }
84         }
85     }
86 }
87
88 template <typename src_data_t>
89 class binarization_test : public ::testing::TestWithParam<binarization_test_params<src_data_t>> {
90 private:
91
92 protected:
93     virtual void SetUp() {
94         auto p = ::testing::TestWithParam<binarization_test_params<src_data_t>>::GetParam();
95
96         auto eng = engine(p.engine_kind, 0);
97         auto src_data_type = data_traits<src_data_t>::data_type;
98
99         memory::dims src_dims = memory::dims({p.dims[0], p.dims[1], p.dims[2], p.dims[3]});
100         memory::dims wei_dims = memory::dims({src_dims[1]});
101         memory::dims dst_dims = memory::dims({p.dims[0], p.dims[1], p.dims[2], p.dims[3]});
102
103         auto src_desc = create_md(src_dims, src_data_type, p.data_format);
104         auto weights_desc = create_md(wei_dims, src_data_type, memory::format::x);
105         auto output_low_desc = create_md(wei_dims, src_data_type, memory::format::x);
106         auto output_high_desc = create_md(wei_dims, src_data_type, memory::format::x);
107         auto output_mask_desc = create_md(wei_dims, src_data_type, memory::format::x);
108         auto dst_desc = create_md(dst_dims, memory::data_type::bin, p.data_format);
109
110         auto src = test_memory(src_desc, eng);
111         auto weights = test_memory(weights_desc, eng);
112         auto output_low = test_memory(output_low_desc, eng);
113         auto output_high = test_memory(output_high_desc, eng);
114         auto output_mask = test_memory(output_mask_desc, eng);
115         auto dst = test_memory(dst_desc, eng);
116
117         fill_data<src_data_t>(src.get_size() / sizeof(src_data_t), (src_data_t *)src.get().get_data_handle(),
118                               src_data_t(0), src_data_t(1));
119         fill_data<src_data_t>(weights.get_size() / sizeof(src_data_t), (src_data_t *)weights.get().get_data_handle(),
120                               src_data_t(0), src_data_t(1));
121         fill_data<src_data_t>(output_low.get_size() / sizeof(src_data_t), (src_data_t *)output_low.get().get_data_handle(),
122                               src_data_t(0), src_data_t(1));
123         fill_data<uint8_t>(dst.get_size() / sizeof(uint8_t), (uint8_t*)dst.get().get_data_handle());
124
125         src_data_t* p_output_low = (src_data_t *)output_low.get().get_data_handle();
126         src_data_t* p_output_high = (src_data_t *)output_high.get().get_data_handle();
127         uint32_t* p_output_mask = (uint32_t *)output_mask.get().get_data_handle();
128         for (int i = 0; i < src_dims[1]; i++) {
129             p_output_low[i] = p_output_low[i] >= 0 ? 1 : 0;
130             p_output_high[i] = p_output_low[i] == 1 ? 0 : 1;
131             p_output_mask[i] = p_output_high[i] == 1 ? 0xffffffff : 0x00000000;
132         }
133
134         std::vector<primitive> pipeline;
135         auto binarization_desc = binarization_forward::desc(prop_kind::forward_training, p.alg_kind, src_desc, weights_desc, output_high_desc, dst_desc);
136         auto binarization_prim_desc = binarization_forward::primitive_desc(binarization_desc, eng);
137         auto binarization = binarization_forward(binarization_prim_desc, src.get(), weights.get(), output_mask.get(), dst.get());
138
139         pipeline.push_back(binarization);
140         auto s = stream(stream::kind::lazy);
141         s.submit(pipeline).wait();
142
143         check_binarization_fwd(p, src_desc, src.get(), weights.get(), output_low.get(), output_high.get(), dst.get());
144     }
145 };
146
147 using binarization_test_float = binarization_test<float>;
148 using binarization_test_params_float = binarization_test_params<float>;
149
150 TEST_P(binarization_test_float, TestsBinarization)
151 {
152 }
153
154 #define EXPAND(args) args
155
156 #define EXPAND_FORMATS(data) memory::format::data
157
158 #define ENGINE engine::kind::cpu
159
160 #define PARAMS(alg, data, mb, c, h, w) \
161     binarization_test_params_float { ENGINE, algorithm::alg, \
162     EXPAND_FORMATS(data), {mb, c, h, w} }
163
164 #define PARAMS_ALL_ALG(...) \
165     EXPAND(PARAMS(binarization_depthwise, __VA_ARGS__))
166
167 #define INST_TEST_CASE(str, ...) INSTANTIATE_TEST_CASE_P( \
168         str, binarization_test_float, ::testing::Values(__VA_ARGS__))
169
170 INST_TEST_CASE(Simple_NHWC,
171     PARAMS_ALL_ALG(nhwc, 2, 8, 4, 4),
172     PARAMS_ALL_ALG(nhwc, 2, 16, 4, 4),
173     PARAMS_ALL_ALG(nhwc, 2, 16, 8, 8),
174     PARAMS_ALL_ALG(nhwc, 2, 16, 16, 8),
175     PARAMS_ALL_ALG(nhwc, 2, 16, 10, 8),
176     PARAMS_ALL_ALG(nhwc, 10, 10, 10, 10),
177     PARAMS_ALL_ALG(nhwc, 256, 64, 8, 16),
178     PARAMS_ALL_ALG(nhwc, 1, 1, 1, 1),
179     PARAMS_ALL_ALG(nhwc, 3, 5, 7, 11),
180     PARAMS_ALL_ALG(nhwc, 2, 129, 7, 4),
181     PARAMS_ALL_ALG(nhwc, 2, 333, 8, 3)
182 );
183
184 }