Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / gtests / test_eltwise.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "gtest/gtest.h"
18 #include "mkldnn_test_common.hpp"
19 #include "math_utils.hpp"
20 #include "mkldnn.hpp"
21
22 using namespace mkldnn::impl::math;
23
24 namespace mkldnn {
25
26 template <typename data_t>
27 struct eltwise_test_params {
28     engine::kind engine_kind;
29     algorithm alg_kind;
30     memory::format data_format;
31     memory::format diff_format;
32     data_t alpha, beta;
33     memory::dims dims;
34     bool expect_to_fail;
35     mkldnn_status_t expected_status;
36 };
37
38 size_t n_elems(const memory::desc &md) {
39     size_t p = 1;
40     const ptrdiff_t *pdims = md.data.layout_desc.blocking.padding_dims;
41     for (int i = 0; i < md.data.ndims; ++i)
42         p *= (size_t)(pdims[i]);
43     return p;
44 }
45
46 template <typename data_t>
47 void check_eltwise_fwd(const eltwise_test_params<data_t> &p,
48         const memory::desc &md, const memory &src, const memory &dst)
49 {
50     data_t *src_data = (data_t *)src.get_data_handle();
51     data_t *dst_data = (data_t *)dst.get_data_handle();
52
53     ASSERT_EQ(md.data.data_type, memory::data_type::f32); // TODO: type assert
54
55     size_t n = n_elems(md);
56     for (size_t i = 0; i < n; ++i) {
57         data_t s = src_data[i];
58         data_t ref_d = 0;
59         switch (p.alg_kind) {
60         case eltwise_relu:        ref_d = relu_fwd(s, p.alpha);           break;
61         case eltwise_tanh:        ref_d = tanh_fwd(s);                    break;
62         case eltwise_elu:         ref_d = elu_fwd(s, p.alpha);            break;
63         case eltwise_square:      ref_d = square_fwd(s);                  break;
64         case eltwise_abs:         ref_d = abs_fwd(s);                     break;
65         case eltwise_sqrt:        ref_d = sqrt_fwd(s);                    break;
66         case eltwise_linear:      ref_d = linear_fwd(s, p.alpha, p.beta); break;
67         case eltwise_bounded_relu: ref_d = bounded_relu_fwd(s, p.alpha);  break;
68         case eltwise_soft_relu:   ref_d = soft_relu_fwd(s);               break;
69         case eltwise_logistic:    ref_d = logistic_fwd(s);                break;
70         case eltwise_clamp:       ref_d = clamp_fwd(s, p.alpha, p.beta);  break;
71         case eltwise_exp:         ref_d = exp_fwd(s);                     break;
72         case eltwise_not:         ref_d = not_fwd(s);                     break;
73         default: assert(!"unknown alg_kind");
74         }
75         dst_data[i] = ref_d;
76     }
77 }
78
79 template <typename data_t>
80 void compare_eltwise_fwd(const eltwise_test_params<data_t> &p,
81         const memory::desc &md, const memory &dst, const memory &ref_dst)
82 {
83     data_t *ref_dst_data = (data_t *)ref_dst.get_data_handle();
84     data_t *dst_data = (data_t *)dst.get_data_handle();
85
86     ASSERT_EQ(md.data.data_type, memory::data_type::f32); // TODO: type assert
87
88     size_t n = n_elems(md);
89     for (size_t i = 0; i < n; ++i) {
90         if (p.alg_kind == eltwise_soft_relu){
91             EXPECT_NEAR(dst_data[i], ref_dst_data[i], 2.e-6);
92         }
93         else{
94             EXPECT_NEAR(dst_data[i], ref_dst_data[i], 1.e-6);
95         }
96     }
97 }
98
99
100 template <typename data_t>
101 void check_eltwise_bwd(const eltwise_test_params<data_t> &p,
102         const memory::desc &md, const memory &src, const memory &diff_dst,
103         const memory &diff_src)
104 {
105     data_t *src_data = (data_t *)src.get_data_handle();
106     data_t *diff_dst_data = (data_t *)diff_dst.get_data_handle();
107     data_t *diff_src_data = (data_t *)diff_src.get_data_handle();
108
109     const memory::desc data_d = src.get_primitive_desc().desc();
110     const memory::desc diff_data_d = diff_src.get_primitive_desc().desc();
111
112     ASSERT_EQ(md.data.data_type, memory::data_type::f32); // TODO: type assert
113
114     size_t n = n_elems(md);
115     for (size_t i = 0; i < n; ++i) {
116         data_t ref_s = src_data[map_index(data_d, i)];
117         data_t ref_dd = diff_dst_data[map_index(diff_data_d, i)];
118         data_t ref_ds = 0;
119         switch (p.alg_kind) {
120         case eltwise_relu:   ref_ds = relu_bwd(ref_dd, ref_s, p.alpha); break;
121         case eltwise_tanh:   ref_ds = tanh_bwd(ref_dd, ref_s); break;
122         case eltwise_elu:    ref_ds = elu_bwd(ref_dd, ref_s, p.alpha); break;
123         case eltwise_square: ref_ds = square_bwd(ref_dd, ref_s); break;
124         case eltwise_abs:    ref_ds = abs_bwd(ref_dd, ref_s); break;
125         case eltwise_sqrt:   ref_ds = sqrt_bwd(ref_dd, ref_s); break;
126         case eltwise_linear:
127             ref_ds = linear_bwd(ref_dd, ref_s, p.alpha, p.beta);
128             break;
129         case eltwise_bounded_relu:
130             ref_ds = bounded_relu_bwd(ref_dd, ref_s, p.alpha);
131             break;
132         case eltwise_soft_relu:
133             ref_ds = soft_relu_bwd(ref_dd, ref_s);
134             break;
135         case eltwise_logistic: ref_ds = logistic_bwd(ref_dd, ref_s); break;
136         case eltwise_clamp: ref_ds = clamp_bwd(ref_dd, ref_s, p.alpha, p.beta); break;
137         case eltwise_exp: ref_ds = exp_bwd(ref_dd, ref_s); break;
138         default: assert(!"unknown alg_kind");
139         }
140         EXPECT_NEAR(diff_src_data[map_index(diff_data_d, i)], ref_ds, 1.e-6);
141     }
142 }
143
144 template <typename data_t>
145 class eltwise_test : public ::testing::TestWithParam<eltwise_test_params<data_t>> {
146 private:
147     std::shared_ptr<memory> src;
148     std::shared_ptr<memory> diff_src;
149     std::shared_ptr<memory> dst;
150     std::shared_ptr<memory> ref_dst;
151     std::shared_ptr<memory> diff_dst;
152     std::shared_ptr<memory> workspace;
153     std::shared_ptr<memory::desc> data_desc;
154     std::shared_ptr<memory::desc> diff_data_desc;
155     std::shared_ptr<eltwise_forward::primitive_desc> eltwise_prim_desc;
156     eltwise_test_params<data_t> p;
157     std::shared_ptr<engine> eng;
158     memory::data_type data_type;
159
160 protected:
161     virtual void SetUp() {
162         p = ::testing::TestWithParam<decltype(p)>::GetParam();
163         catch_expected_failures([=](){Test();}, p.expect_to_fail,
164                     p.expected_status);
165     }
166
167     void Test() {
168         p = ::testing::TestWithParam<eltwise_test_params<data_t>>::GetParam();
169
170         ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
171         eng.reset(new engine(p.engine_kind, 0));
172
173         data_type = data_traits<data_t>::data_type;
174         ASSERT_EQ(data_type, mkldnn::memory::data_type::f32);
175
176         Forward();
177         Backward();
178     }
179
180     void Forward() {
181         data_desc.reset(new memory::desc(p.dims, data_type,
182             p.data_format));
183         diff_data_desc.reset(new memory::desc(p.dims, data_type,
184             p.diff_format));
185         src.reset(new memory({*data_desc, *eng}));
186         dst.reset(new memory({*data_desc, *eng}));
187         ref_dst.reset(new memory({*data_desc, *eng}));
188
189         data_t data_median = data_t(0);
190         data_t data_deviation
191                 = p.alg_kind == eltwise_elu || p.alg_kind == eltwise_exp ? data_t(1) : data_t(200);
192         fill_data<data_t>(n_elems(*data_desc), (data_t *)src->get_data_handle(),
193                 data_median, data_deviation);
194         check_zero_tail<data_t>(1, *src);
195
196         auto eltwise_desc = eltwise_forward::desc(prop_kind::forward_training,
197                 p.alg_kind, *data_desc, p.alpha, p.beta);
198         eltwise_prim_desc.reset(
199                 new eltwise_forward::primitive_desc(eltwise_desc, *eng));
200         auto eltwise = eltwise_forward(*eltwise_prim_desc, *src, *dst);
201
202         std::vector<primitive> pipeline;
203         pipeline.push_back(eltwise);
204         auto s = stream(stream::kind::lazy);
205         s.submit(pipeline).wait();
206         check_zero_tail<data_t>(0, *dst);
207         check_eltwise_fwd(p, *data_desc, *src, *ref_dst);
208         check_zero_tail<data_t>(1, *ref_dst);
209         compare_eltwise_fwd(p, *data_desc, *dst, *ref_dst);
210
211     }
212
213     void Backward() {
214         diff_src.reset(new memory({*diff_data_desc, *eng}));
215         diff_dst.reset(new memory({*diff_data_desc, *eng}));
216
217         data_t data_median = data_t(0);
218         data_t data_deviation
219                 = p.alg_kind == eltwise_elu ? data_t(1) : data_t(200);
220         fill_data<data_t>(n_elems(*diff_data_desc),
221                 (data_t *)diff_dst->get_data_handle(), data_median,
222                 data_deviation);
223         check_zero_tail<data_t>(1, *diff_dst);
224
225         auto eltwise_bwd_desc = eltwise_backward::desc(p.alg_kind,
226                 *diff_data_desc, *data_desc, p.alpha, p.beta);
227         auto eltwise_bwd_prim_desc = eltwise_backward::primitive_desc(
228                 eltwise_bwd_desc, *eng, *eltwise_prim_desc);
229         auto eltwise_bwd = eltwise_backward(eltwise_bwd_prim_desc, *src,
230                 *diff_dst, *diff_src);
231
232         std::vector<primitive> pipeline;
233         pipeline.push_back(eltwise_bwd);
234         auto s = stream(stream::kind::lazy);
235         s.submit(pipeline).wait();
236
237         check_zero_tail<data_t>(0, *diff_src);
238         check_eltwise_bwd(p, *data_desc, *src, *diff_dst, *diff_src);
239     }
240 };
241
242 using eltwise_test_float = eltwise_test<float>;
243 using eltwise_test_params_float = eltwise_test_params<float>;
244
245 TEST_P(eltwise_test_float, TestsEltwise)
246 {
247 }
248
249 #define EXPAND(args) args
250
251 #define EXPAND_FORMATS(data) memory::format::data
252 #define EXPAND_DIMS(...) { __VA_ARGS__ }
253
254 #define ENGINE engine::kind::cpu
255
256 #define PARAMS(alg, data, diff_data, alpha, beta, ...) \
257     eltwise_test_params_float { ENGINE, algorithm::alg, \
258     EXPAND_FORMATS(data), EXPAND_FORMATS(diff_data), \
259     alpha, beta, EXPAND_DIMS(__VA_ARGS__) }
260
261 #define PARAMS_ALL_ALG(...) \
262     EXPAND(PARAMS(eltwise_relu, __VA_ARGS__)), \
263     EXPAND(PARAMS(eltwise_tanh, __VA_ARGS__)), \
264     EXPAND(PARAMS(eltwise_elu, __VA_ARGS__)), \
265     EXPAND(PARAMS(eltwise_square, __VA_ARGS__)), \
266     EXPAND(PARAMS(eltwise_abs, __VA_ARGS__))
267
268
269 #define PARAMS_ALL_ALG_SDPART(...) \
270     EXPAND(PARAMS(eltwise_sqrt, __VA_ARGS__)), \
271     EXPAND(PARAMS(eltwise_linear, __VA_ARGS__)), \
272     EXPAND(PARAMS(eltwise_soft_relu, __VA_ARGS__)), \
273     EXPAND(PARAMS(eltwise_bounded_relu, __VA_ARGS__)), \
274     EXPAND(PARAMS(eltwise_logistic, __VA_ARGS__)), \
275     EXPAND(PARAMS(eltwise_clamp, __VA_ARGS__)), \
276     EXPAND(PARAMS(eltwise_exp, __VA_ARGS__))
277
278
279 #define INST_TEST_CASE(str, ...) INSTANTIATE_TEST_CASE_P( \
280         str, eltwise_test_float, ::testing::Values(__VA_ARGS__))
281
282 INST_TEST_CASE(SimpleZeroDim,
283     PARAMS_ALL_ALG(ncdhw, nCdhw8c, 0.1f, 0.f, 0, 2, 4, 4, 4),
284     PARAMS_ALL_ALG(ncdhw, nCdhw8c, 0.1f, 0.f, 2, 0, 4, 4, 4),
285     PARAMS_ALL_ALG_SDPART(nCdhw16c, nCdhw16c, 0.1f, 0.2f, 0, 4, 2, 2, 2),
286     PARAMS_ALL_ALG_SDPART(nCdhw16c, nCdhw16c, 0.1f, 0.2f, 4, 0, 2, 2, 2)
287 );
288
289 #define CASE_EF(alg, d0, d1, d2, d3) \
290         eltwise_test_params_float { ENGINE, algorithm::eltwise_##alg, \
291         EXPAND_FORMATS(nchw), EXPAND_FORMATS(nchw), 0.f, 0.f, {d0, d1, d2, d3}, \
292         true, mkldnn_invalid_arguments }
293 INST_TEST_CASE(SimpleExpectedFails,
294     CASE_EF(relu, -1, 2, 4, 4),
295     CASE_EF(sqrt, -1, 2, 4, 4),
296     CASE_EF(logistic, -1, 2, 4, 4),
297     CASE_EF(relu, 1, -2, 4, 4),
298     CASE_EF(sqrt, 1, -2, 4, 4),
299     CASE_EF(logistic, 1, -2, 4, 4)
300 );
301
302 INST_TEST_CASE(Simple_3D,
303     PARAMS_ALL_ALG(ncdhw, nCdhw8c, 0.1f, 0.f, 2, 8, 4, 4, 4),
304     PARAMS_ALL_ALG(nCdhw8c, ncdhw, 0.1f, 0.f, 2, 16, 4, 4, 4),
305     PARAMS_ALL_ALG(ncdhw, ncdhw, 0.1f, 0.f, 2, 16, 8, 8, 8),
306     PARAMS_ALL_ALG(nCdhw8c, nCdhw8c, 0.1f, 0.f, 2, 16, 16, 8, 6),
307     PARAMS_ALL_ALG(ndhwc, ncdhw, 0.1f, 0.f, 2, 16, 10, 8, 6),
308     PARAMS_ALL_ALG(ncdhw, ndhwc, 0.1f, 0.f, 10, 10, 10, 10, 10)
309 );
310
311 INST_TEST_CASE(Simple_blocked_3d_padded,
312     PARAMS_ALL_ALG(nCdhw16c, nCdhw16c, 0.1f, 0.2f, 4, 15, 2, 2, 2),
313     PARAMS_ALL_ALG_SDPART(nCdhw16c, nCdhw16c, 0.1f, 0.2f, 4, 27, 2, 2, 2),
314     PARAMS_ALL_ALG(nCdhw16c, nCdhw16c, 0.1f, 0.2f, 4, 23, 2, 2, 2),
315     PARAMS_ALL_ALG_SDPART(nCdhw16c, nCdhw16c, 0.1f, 0.2f, 4, 23, 7, 7, 7)
316 );
317
318 INST_TEST_CASE(Simple_blocked_padded,
319     PARAMS_ALL_ALG(nChw16c, nChw16c, 0.1f, 0.2f, 4, 15, 2, 2),
320     PARAMS_ALL_ALG_SDPART(nChw16c, nChw16c, 0.1f, 0.2f, 4, 27, 2, 2),
321     PARAMS_ALL_ALG(nChw16c, nChw16c, 0.1f, 0.2f, 4, 23, 2, 2),
322     PARAMS_ALL_ALG_SDPART(nChw16c, nChw16c, 0.1f, 0.2f, 4, 17, 7, 7),
323     PARAMS_ALL_ALG(nChw8c, nChw8c, 0.1f, 0.2f, 4, 15, 2, 2),
324     PARAMS_ALL_ALG_SDPART(nChw8c, nChw8c, 0.1f, 0.2f, 4, 27, 2, 2),
325     PARAMS_ALL_ALG(nChw8c, nChw8c, 0.1f, 0.2f, 4, 23, 2, 2),
326     PARAMS_ALL_ALG_SDPART(nChw8c, nChw8c, 0.1f, 0.2f, 4, 17, 7, 7)
327 );
328
329 INST_TEST_CASE(Simple_NCDHW,
330     PARAMS_ALL_ALG(ncdhw, ncdhw, 0.f, 0.f, 2, 32, 28, 28, 28),
331     PARAMS_ALL_ALG(ncdhw, ncdhw, 1.f, 0.f, 2, 64, 13, 13, 13),
332     PARAMS_ALL_ALG(ncdhw, ncdhw, 1.f, 1.f, 1, 64, 27, 27, 27),
333     PARAMS_ALL_ALG(ncdhw, ncdhw, 0.f, 1.f, 1, 128, 11, 11, 11)
334 );
335
336 INST_TEST_CASE(SimpleZeroNegativeSlope_NCHW,
337     PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 8, 4, 4),
338     PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 16, 4, 4),
339     PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 16, 8, 8),
340     PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 16, 16, 8),
341     PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 16, 10, 8),
342     PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 10, 10, 10, 10),
343     PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 256, 64, 8, 16),
344     PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 1, 1, 1, 1),
345     PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 3, 5, 7, 11)
346 );
347
348 INST_TEST_CASE(Simple_NCHW,
349     PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 8, 4, 4),
350     PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 16, 4, 4),
351     PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 16, 8, 8),
352     PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 16, 16, 8),
353     PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 16, 10, 8),
354     PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 10, 10, 10, 10),
355     PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 256, 64, 8, 16),
356     PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 1, 1, 1, 1),
357     PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 3, 5, 7, 11)
358 );
359
360 INST_TEST_CASE(Simple_NCHW_SDPART,
361     PARAMS_ALL_ALG_SDPART(nchw, nchw, 0.1f, 0.f, 256, 64, 8, 16)
362 );
363
364 INST_TEST_CASE(Simple,
365     PARAMS_ALL_ALG(nchw, nChw8c, 0.1f, 0.f, 2, 8, 4, 4),
366     PARAMS_ALL_ALG(nChw8c, nchw, 0.1f, 0.f, 2, 16, 4, 4),
367     PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 16, 8, 8),
368     PARAMS_ALL_ALG(nChw8c, nChw8c, 0.1f, 0.f, 2, 16, 16, 8),
369     PARAMS_ALL_ALG(nhwc, nchw, 0.1f, 0.f, 2, 16, 10, 8),
370     PARAMS_ALL_ALG(nchw, nhwc, 0.1f, 0.f, 10, 10, 10, 10)
371 );
372
373 INST_TEST_CASE(Simple_SDPART,
374     PARAMS_ALL_ALG_SDPART(nchw, nChw8c, 0.1f, 0.f, 2, 8, 4, 4),
375     PARAMS_ALL_ALG_SDPART(nChw8c, nchw, 0.1f, 0.f, 2, 16, 4, 4),
376     PARAMS_ALL_ALG_SDPART(nchw, nchw, 0.1f, 0.f, 2, 16, 8, 8),
377     PARAMS_ALL_ALG_SDPART(nChw8c, nChw8c, 0.1f, 0.f, 2, 16, 16, 8),
378     PARAMS_ALL_ALG_SDPART(nhwc, nchw, 0.1f, 0.f, 2, 16, 10, 8),
379     PARAMS_ALL_ALG_SDPART(nchw, nhwc, 0.1f, 0.f, 10, 10, 10, 10)
380 );
381
382 INST_TEST_CASE(AlexNet_NCHW,
383     PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 96, 55, 55),
384     PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 256, 27, 27),
385     PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 384, 13, 13),
386     PARAMS_ALL_ALG_SDPART(nchw, nchw, 0.f, 0.f, 2, 96, 55, 55),
387     PARAMS_ALL_ALG_SDPART(nchw, nchw, 0.f, 0.f, 2, 256, 27, 27),
388     PARAMS_ALL_ALG_SDPART(nchw, nchw, 0.f, 0.f, 2, 384, 13, 13)
389 );
390
391 INST_TEST_CASE(Simple_X,
392     PARAMS_ALL_ALG(x, x, 0.f, 0.f, 55)
393 );
394
395 }