Publishing R3
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / gtests / test_eltwise.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "gtest/gtest.h"
18 #include "mkldnn_test_common.hpp"
19
20 #include "mkldnn.hpp"
21
22 namespace mkldnn {
23
24 template <typename T, typename A> inline T relu_fwd(T s, A alpha) {
25     return s > 0 ? s : static_cast<T>(s * alpha);
26 }
27 template <typename T, typename A> inline T relu_bwd(T dd, T s, A alpha) {
28     return s > 0 ? dd : static_cast<T>(dd * alpha);
29 }
30
31 template <typename T> T tanh_fwd(T s) {
32     const float e = ::expf(2*s); /* maybe replace with -2*s? */
33     return static_cast<T>((e - 1.0) / (e + 1.0));
34 }
35 template <typename T> T tanh_bwd(T dd, T s) {
36     const float e = ::expf(2*s); /* maybe replace with -2*s? */
37     const float th = ((e - 1) / (e + 1));
38     return static_cast<T>(dd * (1 - th * th));
39 }
40
41 template <typename T, typename A> T elu_fwd(T s, A alpha) {
42     return s > 0 ? s : static_cast<T>(alpha * (::expf(s) - 1));
43 }
44 template <typename T, typename A> T elu_bwd(T dd, T s, A alpha) {
45     return static_cast<T>(dd * (s > 0 ? 1 : alpha * ::expf(s)));
46 }
47
48 template <typename T>
49 T square_fwd(T s) {
50     return s * s;
51 }
52
53 template <typename T>
54 T square_bwd(T dd, T s) {
55     return dd * 2*s;
56 }
57
58 template <typename T>
59 T abs_fwd(T s) {
60     return s > 0 ? s : -s;;
61 }
62
63 template <typename T>
64 T abs_bwd(T dd, T s) {
65     return dd * (s > 0 ? 1 : s < 0 ? -1 : 0);
66 }
67
68 template <typename T>
69 T sqrt_fwd(T s) {
70     return s > 0 ? ::sqrtf(s) : 0;
71 }
72
73 template <typename T>
74 T sqrt_bwd(T dd, T s) {
75     return s > 0 ? dd / (2 * ::sqrtf(s)) : 0;
76 }
77
78 template <typename T, typename A>
79 T linear_fwd(T s, A alpha, A beta) {
80     return alpha * s + beta;
81 }
82
83 template <typename T, typename A>
84 T linear_bwd(T dd, T s, A alpha, A beta) {
85     (void) s;
86     (void) beta;
87     return dd * alpha;
88 }
89
90 template <typename T, typename A>
91 T bounded_relu_fwd(T s, A alpha) {
92     s = s > 0 ? s : 0;
93     return s > alpha ? alpha : s;
94 }
95
96 template <typename T, typename A>
97 T bounded_relu_bwd(T dd, T s, A alpha) {
98     return dd * ((0 < s && s < alpha) ? 1 : 0);
99 }
100
101 template <typename T>
102 T soft_relu_fwd(T s) {
103     return logf(1 + ::expf(s));
104 }
105
106 template <typename T>
107 T soft_relu_bwd(T dd, T s) {
108     return dd / (1 + ::expf(-s));
109 }
110
111 template <typename T>
112 T logistic_fwd(T s) {
113     T v = ::expf(s);
114     return v / (v + 1);
115 }
116
117 template <typename T>
118 T logistic_bwd(T dd, T s) {
119     T v = ::expf(-s);
120     return dd * v / ((v + 1) * (v + 1));
121 }
122
123 template <typename T, typename A>
124 T clamp_fwd(T s, A alpha, A beta) {
125     return s > alpha ? (T)(alpha) : s < beta ? (T)(beta) : s;
126 }
127
128 template <typename T, typename A>
129 T clamp_bwd(T dd, T s, A alpha, A beta) {
130     return dd * ((beta < s && s < alpha) ? 1 : 0);
131 }
132
133 template <typename data_t>
134 struct eltwise_test_params {
135     engine::kind engine_kind;
136     algorithm alg_kind;
137     memory::format data_format;
138     memory::format diff_format;
139     data_t alpha, beta;
140     memory::dims dims;
141 };
142
143 size_t n_elems(const memory::desc &md) {
144     size_t p = 1;
145     const int *pdims = md.data.layout_desc.blocking.padding_dims;
146     for (int i = 0; i < md.data.ndims; ++i)
147         p *= (size_t)(pdims[i]);
148     return p;
149 }
150
151 template <typename data_t>
152 void check_eltwise_fwd(const eltwise_test_params<data_t> &p,
153         const memory::desc &md, const memory &src, const memory &dst)
154 {
155     data_t *src_data = (data_t *)src.get_data_handle();
156     data_t *dst_data = (data_t *)dst.get_data_handle();
157
158     ASSERT_EQ(md.data.data_type, memory::data_type::f32); // TODO: type assert
159
160     size_t n = n_elems(md);
161     for (size_t i = 0; i < n; ++i) {
162         data_t s = src_data[i];
163         data_t ref_d = 0;
164         switch (p.alg_kind) {
165         case eltwise_relu:        ref_d = relu_fwd(s, p.alpha);           break;
166         case eltwise_tanh:        ref_d = tanh_fwd(s);                    break;
167         case eltwise_elu:         ref_d = elu_fwd(s, p.alpha);            break;
168         case eltwise_square:      ref_d = square_fwd(s);                  break;
169         case eltwise_abs:         ref_d = abs_fwd(s);                     break;
170         case eltwise_sqrt:        ref_d = sqrt_fwd(s);                    break;
171         case eltwise_linear:      ref_d = linear_fwd(s, p.alpha, p.beta); break;
172         case eltwise_bounded_relu: ref_d = bounded_relu_fwd(s, p.alpha);  break;
173         case eltwise_soft_relu:   ref_d = soft_relu_fwd(s);               break;
174         case eltwise_logistic:    ref_d = logistic_fwd(s);                break;
175         case eltwise_clamp:       ref_d = clamp_fwd(s, p.alpha, p.beta);  break;
176         default: assert(!"unknown alg_kind");
177         }
178         dst_data[i] = ref_d;
179     }
180 }
181
182 template <typename data_t>
183 void compare_eltwise_fwd(const eltwise_test_params<data_t> &p,
184         const memory::desc &md, const memory &dst, const memory &ref_dst)
185 {
186     data_t *ref_dst_data = (data_t *)ref_dst.get_data_handle();
187     data_t *dst_data = (data_t *)dst.get_data_handle();
188
189     ASSERT_EQ(md.data.data_type, memory::data_type::f32); // TODO: type assert
190
191     size_t n = n_elems(md);
192     for (size_t i = 0; i < n; ++i) {
193         if (p.alg_kind == eltwise_soft_relu){
194             EXPECT_NEAR(dst_data[i], ref_dst_data[i], 2.e-6);
195         }
196         else{
197             EXPECT_NEAR(dst_data[i], ref_dst_data[i], 1.e-6);
198         }
199     }
200 }
201
202
203 template <typename data_t>
204 void check_eltwise_bwd(const eltwise_test_params<data_t> &p,
205         const memory::desc &md, const memory &src, const memory &diff_dst,
206         const memory &diff_src)
207 {
208     data_t *src_data = (data_t *)src.get_data_handle();
209     data_t *diff_dst_data = (data_t *)diff_dst.get_data_handle();
210     data_t *diff_src_data = (data_t *)diff_src.get_data_handle();
211
212     const memory::desc data_d = src.get_primitive_desc().desc();
213     const memory::desc diff_data_d = diff_src.get_primitive_desc().desc();
214
215     ASSERT_EQ(md.data.data_type, memory::data_type::f32); // TODO: type assert
216
217     size_t n = n_elems(md);
218     for (size_t i = 0; i < n; ++i) {
219         data_t ref_s = src_data[map_index(data_d, i)];
220         data_t ref_dd = diff_dst_data[map_index(diff_data_d, i)];
221         data_t ref_ds = 0;
222         switch (p.alg_kind) {
223         case eltwise_relu:   ref_ds = relu_bwd(ref_dd, ref_s, p.alpha); break;
224         case eltwise_tanh:   ref_ds = tanh_bwd(ref_dd, ref_s); break;
225         case eltwise_elu:    ref_ds = elu_bwd(ref_dd, ref_s, p.alpha); break;
226         case eltwise_square: ref_ds = square_bwd(ref_dd, ref_s); break;
227         case eltwise_abs:    ref_ds = abs_bwd(ref_dd, ref_s); break;
228         case eltwise_sqrt:   ref_ds = sqrt_bwd(ref_dd, ref_s); break;
229         case eltwise_linear:
230             ref_ds = linear_bwd(ref_dd, ref_s, p.alpha, p.beta);
231             break;
232         case eltwise_bounded_relu:
233             ref_ds = bounded_relu_bwd(ref_dd, ref_s, p.alpha);
234             break;
235         case eltwise_soft_relu:
236             ref_ds = soft_relu_bwd(ref_dd, ref_s);
237             break;
238         case eltwise_logistic: ref_ds = logistic_bwd(ref_dd, ref_s); break;
239         case eltwise_clamp: ref_ds = clamp_bwd(ref_dd, ref_s, p.alpha, p.beta); break;
240         default: assert(!"unknown alg_kind");
241         }
242         EXPECT_NEAR(diff_src_data[map_index(diff_data_d, i)], ref_ds, 1.e-6);
243     }
244 }
245
246 template <typename data_t>
247 class eltwise_test : public ::testing::TestWithParam<eltwise_test_params<data_t>> {
248 private:
249     std::shared_ptr<memory> src;
250     std::shared_ptr<memory> diff_src;
251     std::shared_ptr<memory> dst;
252     std::shared_ptr<memory> ref_dst;
253     std::shared_ptr<memory> diff_dst;
254     std::shared_ptr<memory> workspace;
255     std::shared_ptr<memory::desc> data_desc;
256     std::shared_ptr<memory::desc> diff_data_desc;
257     std::shared_ptr<eltwise_forward::primitive_desc> eltwise_prim_desc;
258     eltwise_test_params<data_t> p;
259     std::shared_ptr<engine> eng;
260     memory::data_type data_type;
261
262 protected:
263     virtual void SetUp() {
264         p = ::testing::TestWithParam<eltwise_test_params<data_t>>::GetParam();
265
266         ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
267         eng.reset(new engine(p.engine_kind, 0));
268
269         data_type = data_traits<data_t>::data_type;
270         ASSERT_EQ(data_type, mkldnn::memory::data_type::f32);
271
272         Forward();
273         Backward();
274     }
275
276     void Forward() {
277         data_desc.reset(new memory::desc(p.dims, data_type,
278             p.data_format));
279         diff_data_desc.reset(new memory::desc(p.dims, data_type,
280             p.diff_format));
281         src.reset(new memory({*data_desc, *eng}));
282         dst.reset(new memory({*data_desc, *eng}));
283         ref_dst.reset(new memory({*data_desc, *eng}));
284
285         fill_data<data_t>(n_elems(*data_desc), (data_t *)src->get_data_handle(),
286                 data_t(0), data_t(1));
287         check_zero_tail<data_t>(1, *src);
288
289         auto eltwise_desc = eltwise_forward::desc(prop_kind::forward_training,
290                 p.alg_kind, *data_desc, p.alpha, p.beta);
291         eltwise_prim_desc.reset(
292                 new eltwise_forward::primitive_desc(eltwise_desc, *eng));
293         auto eltwise = eltwise_forward(*eltwise_prim_desc, *src, *dst);
294
295         std::vector<primitive> pipeline;
296         pipeline.push_back(eltwise);
297         auto s = stream(stream::kind::lazy);
298         s.submit(pipeline).wait();
299         check_zero_tail<data_t>(0, *dst);
300         check_eltwise_fwd(p, *data_desc, *src, *ref_dst);
301         check_zero_tail<data_t>(1, *ref_dst);
302         compare_eltwise_fwd(p, *data_desc, *dst, *ref_dst);
303
304     }
305
306     void Backward() {
307         diff_src.reset(new memory({*diff_data_desc, *eng}));
308         diff_dst.reset(new memory({*diff_data_desc, *eng}));
309
310         fill_data<data_t>(n_elems(*diff_data_desc),
311                 (data_t *)diff_dst->get_data_handle(), data_t(0), data_t(1));
312         check_zero_tail<data_t>(1, *diff_dst);
313
314         auto eltwise_bwd_desc = eltwise_backward::desc(p.alg_kind,
315                 *diff_data_desc, *data_desc, p.alpha, p.beta);
316         auto eltwise_bwd_prim_desc = eltwise_backward::primitive_desc(
317                 eltwise_bwd_desc, *eng, *eltwise_prim_desc);
318         auto eltwise_bwd = eltwise_backward(eltwise_bwd_prim_desc, *src,
319                 *diff_dst, *diff_src);
320
321         std::vector<primitive> pipeline;
322         pipeline.push_back(eltwise_bwd);
323         auto s = stream(stream::kind::lazy);
324         s.submit(pipeline).wait();
325
326         check_zero_tail<data_t>(0, *diff_src);
327         check_eltwise_bwd(p, *data_desc, *src, *diff_dst, *diff_src);
328     }
329 };
330
331 using eltwise_test_float = eltwise_test<float>;
332 using eltwise_test_params_float = eltwise_test_params<float>;
333
334 TEST_P(eltwise_test_float, TestsEltwise)
335 {
336 }
337
338 #define EXPAND(args) args
339
340 #define EXPAND_FORMATS(data) memory::format::data
341 #define EXPAND_DIMS(...) { __VA_ARGS__ }
342
343 #define ENGINE engine::kind::cpu
344
345 #define PARAMS(alg, data, diff_data, alpha, beta, ...) \
346     eltwise_test_params_float { ENGINE, algorithm::alg, \
347     EXPAND_FORMATS(data), EXPAND_FORMATS(diff_data), \
348     alpha, beta, EXPAND_DIMS(__VA_ARGS__) }
349
350 #define PARAMS_ALL_ALG(...) \
351     EXPAND(PARAMS(eltwise_relu, __VA_ARGS__)), \
352     EXPAND(PARAMS(eltwise_tanh, __VA_ARGS__)), \
353     EXPAND(PARAMS(eltwise_elu, __VA_ARGS__)), \
354     EXPAND(PARAMS(eltwise_square, __VA_ARGS__)), \
355     EXPAND(PARAMS(eltwise_abs, __VA_ARGS__))
356
357 #define PARAMS_ALL_ALG_SDPART(...) \
358     EXPAND(PARAMS(eltwise_sqrt, __VA_ARGS__)), \
359     EXPAND(PARAMS(eltwise_linear, __VA_ARGS__)), \
360     EXPAND(PARAMS(eltwise_soft_relu, __VA_ARGS__)), \
361     EXPAND(PARAMS(eltwise_bounded_relu, __VA_ARGS__)), \
362     EXPAND(PARAMS(eltwise_logistic, __VA_ARGS__)), \
363     EXPAND(PARAMS(eltwise_clamp, __VA_ARGS__))
364
365 #define INST_TEST_CASE(str, ...) INSTANTIATE_TEST_CASE_P( \
366         str, eltwise_test_float, ::testing::Values(__VA_ARGS__))
367
368 INST_TEST_CASE(Simple_blocked_3d_padded,
369     PARAMS_ALL_ALG(nCdhw16c, nCdhw16c, 0.1f, 0.2f, 4, 15, 2, 2, 2),
370     PARAMS_ALL_ALG_SDPART(nCdhw16c, nCdhw16c, 0.1f, 0.2f, 4, 27, 2, 2, 2),
371     PARAMS_ALL_ALG(nCdhw16c, nCdhw16c, 0.1f, 0.2f, 4, 23, 2, 2, 2),
372     PARAMS_ALL_ALG_SDPART(nCdhw16c, nCdhw16c, 0.1f, 0.2f, 4, 23, 7, 7, 7)
373 );
374
375 INST_TEST_CASE(Simple_blocked_padded,
376     PARAMS_ALL_ALG(nChw16c, nChw16c, 0.1f, 0.2f, 4, 15, 2, 2),
377     PARAMS_ALL_ALG_SDPART(nChw16c, nChw16c, 0.1f, 0.2f, 4, 27, 2, 2),
378     PARAMS_ALL_ALG(nChw16c, nChw16c, 0.1f, 0.2f, 4, 23, 2, 2),
379     PARAMS_ALL_ALG_SDPART(nChw16c, nChw16c, 0.1f, 0.2f, 4, 17, 7, 7)
380 );
381
382 INST_TEST_CASE(Simple_NCDHW,
383     PARAMS_ALL_ALG(ncdhw, ncdhw, 0.f, 0.f, 2, 32, 28, 28, 28),
384     PARAMS_ALL_ALG(ncdhw, ncdhw, 1.f, 0.f, 2, 64, 13, 13, 13),
385     PARAMS_ALL_ALG(ncdhw, ncdhw, 1.f, 1.f, 1, 64, 27, 27, 27),
386     PARAMS_ALL_ALG(ncdhw, ncdhw, 0.f, 1.f, 1, 128, 11, 11, 11)
387 );
388
389 INST_TEST_CASE(SimpleZeroNegativeSlope_NCHW,
390     PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 8, 4, 4),
391     PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 16, 4, 4),
392     PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 16, 8, 8),
393     PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 16, 16, 8),
394     PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 16, 10, 8),
395     PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 10, 10, 10, 10),
396     PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 256, 64, 8, 16),
397     PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 1, 1, 1, 1),
398     PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 3, 5, 7, 11)
399 );
400
401 INST_TEST_CASE(Simple_NCHW,
402     PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 8, 4, 4),
403     PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 16, 4, 4),
404     PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 16, 8, 8),
405     PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 16, 16, 8),
406     PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 16, 10, 8),
407     PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 10, 10, 10, 10),
408     PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 256, 64, 8, 16),
409     PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 1, 1, 1, 1),
410     PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 3, 5, 7, 11)
411 );
412
413 INST_TEST_CASE(Simple_NCHW_SDPART,
414     PARAMS_ALL_ALG_SDPART(nchw, nchw, 0.1f, 0.f, 256, 64, 8, 16)
415 );
416
417 INST_TEST_CASE(Simple,
418     PARAMS_ALL_ALG(nchw, nChw8c, 0.1f, 0.f, 2, 8, 4, 4),
419     PARAMS_ALL_ALG(nChw8c, nchw, 0.1f, 0.f, 2, 16, 4, 4),
420     PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 16, 8, 8),
421     PARAMS_ALL_ALG(nChw8c, nChw8c, 0.1f, 0.f, 2, 16, 16, 8),
422     PARAMS_ALL_ALG(nhwc, nchw, 0.1f, 0.f, 2, 16, 10, 8),
423     PARAMS_ALL_ALG(nchw, nhwc, 0.1f, 0.f, 10, 10, 10, 10)
424 );
425
426 INST_TEST_CASE(Simple_SDPART,
427     PARAMS_ALL_ALG_SDPART(nchw, nChw8c, 0.1f, 0.f, 2, 8, 4, 4),
428     PARAMS_ALL_ALG_SDPART(nChw8c, nchw, 0.1f, 0.f, 2, 16, 4, 4),
429     PARAMS_ALL_ALG_SDPART(nchw, nchw, 0.1f, 0.f, 2, 16, 8, 8),
430     PARAMS_ALL_ALG_SDPART(nChw8c, nChw8c, 0.1f, 0.f, 2, 16, 16, 8),
431     PARAMS_ALL_ALG_SDPART(nhwc, nchw, 0.1f, 0.f, 2, 16, 10, 8),
432     PARAMS_ALL_ALG_SDPART(nchw, nhwc, 0.1f, 0.f, 10, 10, 10, 10)
433 );
434
435 INST_TEST_CASE(AlexNet_NCHW,
436     PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 96, 55, 55),
437     PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 256, 27, 27),
438     PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 384, 13, 13),
439     PARAMS_ALL_ALG_SDPART(nchw, nchw, 0.f, 0.f, 2, 96, 55, 55),
440     PARAMS_ALL_ALG_SDPART(nchw, nchw, 0.f, 0.f, 2, 256, 27, 27),
441     PARAMS_ALL_ALG_SDPART(nchw, nchw, 0.f, 0.f, 2, 384, 13, 13)
442 );
443
444 INST_TEST_CASE(Simple_X,
445     PARAMS_ALL_ALG(x, x, 0.f, 0.f, 55)
446 );
447
448 }