Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / gtests / test_depthwise.cpp
1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <mkldnn_types.h>
18 #include "gtest/gtest.h"
19 #include "mkldnn_test_common.hpp"
20 #include "mkldnn.hpp"
21
22 namespace mkldnn {
23
24 template <typename T> inline T scale_shift_fwd(T s_val, T w_val, T b_val) {
25     return s_val*w_val + b_val;
26 }
27
28 template <typename T> inline T prelu_fwd(T s_val, T w_val) {
29     return s_val >= 0 ? s_val : w_val*s_val;
30 }
31
32 template <typename data_t>
33 struct depthwise_test_params {
34     engine::kind engine_kind;
35     algorithm alg_kind;
36     memory::format data_format;
37     memory::dims dims;
38 };
39
40 template <typename data_t>
41 void check_depthwise_fwd(const depthwise_test_params<data_t> &p,
42         const memory::desc &md, const memory &src, const memory &weights,
43         const memory &bias, bool with_bias, const memory &dst)
44 {
45     data_t *src_data     = (data_t *)src.get_data_handle();
46     data_t *weights_data = (data_t *)weights.get_data_handle();
47     data_t *bias_data    = with_bias ? (data_t *)bias.get_data_handle() : nullptr;
48     data_t *dst_data     = (data_t *)dst.get_data_handle();
49
50     const memory::desc src_d = src.get_primitive_desc().desc();
51     const memory::desc weights_d = weights.get_primitive_desc().desc();
52     const memory::desc dst_d = dst.get_primitive_desc().desc();
53
54     ASSERT_EQ(md.data.data_type, memory::data_type::f32); // TODO: type assert
55
56     int N = md.data.ndims > 0 ? md.data.dims[0] : 1;
57     int C = md.data.ndims > 1 ? md.data.dims[1] : 1;
58     int H = md.data.ndims > 2 ? md.data.dims[2] : 1;
59     int W = md.data.ndims > 3 ? md.data.dims[3] : 1;
60
61     for (int n = 0; n < N; ++n) {
62         for (int c = 0; c < C; ++c) {
63             for (int h = 0; h < H; ++h) {
64                 for (int w = 0; w < W; ++w) {
65                     int idx = n*C*H*W + c*H*W + h*W + w;
66
67                     data_t s_val = src_data[map_index(src_d, idx)];
68                     data_t w_val = weights_data[map_index(weights_d, c)];
69                     data_t b_val = with_bias ? bias_data[map_index(bias.get_primitive_desc().desc(), c)] : 0;
70
71                     data_t ref_d = 0;
72                     switch (p.alg_kind) {
73                         case depthwise_scale_shift:
74                             ref_d = scale_shift_fwd(s_val, w_val, b_val);
75                             break;
76                         case depthwise_prelu:
77                             ref_d = prelu_fwd(s_val, w_val);
78                             break;
79                         default:
80                             assert(!"unknown alg_kind");
81                     }
82
83                     EXPECT_NEAR(dst_data[map_index(dst_d, idx)], ref_d, 1.e-6);
84                 }
85             }
86         }
87     }
88 }
89
90 template <typename data_t>
91 class depthwise_test : public ::testing::TestWithParam<depthwise_test_params<data_t>> {
92 private:
93     std::shared_ptr<memory> src;
94     std::shared_ptr<memory> weights;
95     std::shared_ptr<memory> bias;
96     std::shared_ptr<memory> dst;
97     std::shared_ptr<memory> workspace;
98     std::shared_ptr<memory::desc> src_desc;
99     std::shared_ptr<memory::desc> dst_desc;
100     std::shared_ptr<memory::desc> weights_desc;
101     std::shared_ptr<memory::desc> bias_desc;
102     std::shared_ptr<depthwise_forward::primitive_desc> depthwise_prim_desc;
103     depthwise_test_params<data_t> p;
104     std::shared_ptr<engine> eng;
105     memory::data_type data_type;
106     int data_size;
107     int weights_size;
108
109 protected:
110     virtual void SetUp() {
111         p = ::testing::TestWithParam<depthwise_test_params<data_t>>::GetParam();
112
113         ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
114         eng.reset(new engine(p.engine_kind, 0));
115
116         data_type = data_traits<data_t>::data_type;
117         ASSERT_EQ(data_type, mkldnn::memory::data_type::f32);
118
119         data_size = p.dims[0] * p.dims[1] * p.dims[2] * p.dims[3];
120         weights_size = p.dims[1];
121
122         Forward();
123     }
124
125     void Forward() {
126         bool with_bias = p.alg_kind == depthwise_scale_shift;
127
128         memory::dims dims = p.data_format == mkldnn_nc ? memory::dims({p.dims[0], p.dims[1]}) : p.dims;
129
130         src_desc.reset(new memory::desc(dims, data_type, p.data_format));
131         dst_desc.reset(new memory::desc(dims, data_type, p.data_format));
132         src.reset(new memory({*src_desc, *eng}));
133         dst.reset(new memory({*dst_desc, *eng}));
134         fill_data<data_t>(data_size, (data_t *)src->get_data_handle(),
135                           data_t(0), data_t(1));
136
137         weights_desc.reset(new memory::desc({dims[1]}, data_type, memory::format::x));
138         weights.reset(new memory({*weights_desc, *eng}));
139         fill_data<data_t>(weights_size, (data_t *)weights->get_data_handle(),
140                           data_t(0), data_t(1));
141
142         if (with_bias) {
143             bias_desc.reset(new memory::desc({dims[1]}, data_type, memory::format::x));
144             bias.reset(new memory({*bias_desc, *eng}));
145             fill_data<data_t>(weights_size, (data_t *) bias->get_data_handle(),
146                               data_t(0), data_t(1));
147         }
148
149         std::vector<primitive> pipeline;
150         auto depthwise_desc = with_bias
151                               ? depthwise_forward::desc(prop_kind::forward_training, p.alg_kind, *src_desc, *dst_desc, *weights_desc, *bias_desc)
152                               : depthwise_forward::desc(prop_kind::forward_training, p.alg_kind, *src_desc, *dst_desc, *weights_desc);
153         depthwise_prim_desc.reset(new depthwise_forward::primitive_desc(depthwise_desc, *eng));
154
155         auto depthwise = with_bias
156                          ? depthwise_forward(*depthwise_prim_desc, *src, *weights, *bias, *dst)
157                          : depthwise_forward(*depthwise_prim_desc, *src, *weights, *dst);
158
159         pipeline.push_back(depthwise);
160         auto s = stream(stream::kind::lazy);
161         s.submit(pipeline).wait();
162
163         check_depthwise_fwd(p, *src_desc, *src, *weights, *bias, with_bias, *dst);
164     }
165 };
166
167 using depthwise_test_float = depthwise_test<float>;
168 using depthwise_test_params_float = depthwise_test_params<float>;
169
170 TEST_P(depthwise_test_float, TestsDepthwise)
171 {
172 }
173
174 #define EXPAND(args) args
175
176 #define EXPAND_FORMATS(data) memory::format::data
177
178 #define ENGINE engine::kind::cpu
179
180 #define PARAMS(alg, data, mb, c, h, w) \
181     depthwise_test_params_float { ENGINE, algorithm::alg, \
182     EXPAND_FORMATS(data), {mb, c, h, w} }
183
184 #define PARAMS_ALL_ALG(...) \
185     EXPAND(PARAMS(depthwise_scale_shift, __VA_ARGS__)), \
186     EXPAND(PARAMS(depthwise_prelu, __VA_ARGS__))
187
188 #define INST_TEST_CASE(str, ...) INSTANTIATE_TEST_CASE_P( \
189         str, depthwise_test_float, ::testing::Values(__VA_ARGS__))
190
191 INST_TEST_CASE(Simple_NC,
192     PARAMS_ALL_ALG(nc, 2, 8, 1, 1),
193     PARAMS_ALL_ALG(nc, 2, 16, 1, 1),
194     PARAMS_ALL_ALG(nc, 2, 16, 1, 1),
195     PARAMS_ALL_ALG(nc, 2, 16, 1, 1),
196     PARAMS_ALL_ALG(nc, 2, 16, 1, 1),
197     PARAMS_ALL_ALG(nc, 10, 10, 1, 1),
198     PARAMS_ALL_ALG(nc, 256, 64, 1, 1),
199     PARAMS_ALL_ALG(nc, 1, 1, 1, 1),
200     PARAMS_ALL_ALG(nc, 3, 5, 1, 1)
201 );
202
203 INST_TEST_CASE(Simple_NCHW,
204     PARAMS_ALL_ALG(nchw, 2, 8, 4, 4),
205     PARAMS_ALL_ALG(nchw, 2, 16, 4, 4),
206     PARAMS_ALL_ALG(nchw, 2, 16, 8, 8),
207     PARAMS_ALL_ALG(nchw, 2, 16, 16, 8),
208     PARAMS_ALL_ALG(nchw, 2, 16, 10, 8),
209     PARAMS_ALL_ALG(nchw, 10, 10, 10, 10),
210     PARAMS_ALL_ALG(nchw, 256, 64, 8, 16),
211     PARAMS_ALL_ALG(nchw, 1, 1, 1, 1),
212     PARAMS_ALL_ALG(nchw, 3, 5, 7, 11)
213 );
214
215 INST_TEST_CASE(Simple_Blocked,
216     PARAMS_ALL_ALG(nChw8c, 2, 8, 4, 4),
217     PARAMS_ALL_ALG(nChw8c, 2, 16, 4, 4),
218     PARAMS_ALL_ALG(nChw8c, 2, 16, 8, 8),
219     PARAMS_ALL_ALG(nChw8c, 2, 16, 16, 8),
220     PARAMS_ALL_ALG(nChw8c, 2, 32, 10, 8),
221     PARAMS_ALL_ALG(nChw8c, 256, 64, 8, 16)
222 );
223
224 INST_TEST_CASE(Simple_Blocked16,
225     PARAMS_ALL_ALG(nChw16c, 2, 16, 4, 4),
226     PARAMS_ALL_ALG(nChw16c, 2, 16, 8, 8),
227     PARAMS_ALL_ALG(nChw16c, 2, 16, 16, 8),
228     PARAMS_ALL_ALG(nChw16c, 2, 32, 10, 8),
229     PARAMS_ALL_ALG(nChw16c, 256, 64, 8, 16)
230 );
231
232 }