1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "gtest/gtest.h"
18 #include "mkldnn_test_common.hpp"
// Reference forward ReLU with a negative slope: strictly positive inputs pass
// through unchanged; every other input (zero, negatives, NaN) is scaled by
// `alpha` and converted back to T.
template <typename T, typename A> inline T relu_fwd(T s, A alpha) {
    if (s > 0)
        return s;
    return static_cast<T>(s * alpha);
}
// Reference backward ReLU: the incoming gradient `dd` flows through unchanged
// where the forward input `s` was strictly positive, and is scaled by `alpha`
// everywhere else.
template <typename T, typename A> inline T relu_bwd(T dd, T s, A alpha) {
    if (!(s > 0))
        return static_cast<T>(dd * alpha);
    return dd;
}
// Reference forward tanh, evaluated in single precision.
//
// Uses tanh(s) = sign(s) * (1 - e^{-2|s|}) / (1 + e^{-2|s|}). The exponent is
// never positive, so expf cannot overflow. The previous expf(2*s) form
// overflowed to +inf for s > ~44.4 and produced inf/inf = NaN.
template <typename T> T tanh_fwd(T s) {
    const T a = s < 0 ? -s : s;           // |s|
    const float e = ::expf(-2 * a);       // in (0, 1]; may underflow to 0
    const float th = (1 - e) / (1 + e);   // tanh(|s|)
    return static_cast<T>(s < 0 ? -th : th);
}
// Reference backward tanh: dd * (1 - tanh(s)^2), evaluated in single precision.
//
// tanh(|s|) is computed as (1 - e^{-2|s|}) / (1 + e^{-2|s|}), so expf never
// overflows. The sign of tanh is irrelevant here because the value is squared.
// The previous expf(2*s) form produced NaN gradients for s > ~44.4.
template <typename T> T tanh_bwd(T dd, T s) {
    const T a = s < 0 ? -s : s;           // |s|
    const float e = ::expf(-2 * a);       // in (0, 1]; may underflow to 0
    const float th = (1 - e) / (1 + e);   // |tanh(s)|
    return static_cast<T>(dd * (1 - th * th));
}
// Reference forward ELU: identity for strictly positive inputs, otherwise
// alpha * (e^s - 1) computed via expf and converted back to T.
template <typename T, typename A> T elu_fwd(T s, A alpha) {
    if (s > 0)
        return s;
    return static_cast<T>(alpha * (::expf(s) - 1));
}
// Reference backward ELU: the local slope is 1 where the forward input was
// strictly positive and alpha * e^s elsewhere; the result is dd times that slope.
template <typename T, typename A> T elu_bwd(T dd, T s, A alpha) {
    if (s > 0)
        return dd;
    return static_cast<T>(dd * (alpha * ::expf(s)));
}
54 T square_bwd(T dd, T s) {
60 return s > 0 ? s : -s;;
64 T abs_bwd(T dd, T s) {
65 return dd * (s > 0 ? 1 : s < 0 ? -1 : 0);
70 return s > 0 ? ::sqrtf(s) : 0;
74 T sqrt_bwd(T dd, T s) {
75 return s > 0 ? dd / (2 * ::sqrtf(s)) : 0;
// Reference forward linear activation: alpha * s + beta, converted to T.
template <typename T, typename A>
T linear_fwd(T s, A alpha, A beta) {
    return static_cast<T>(beta + alpha * s);
}
83 template <typename T, typename A>
84 T linear_bwd(T dd, T s, A alpha, A beta) {
90 template <typename T, typename A>
91 T bounded_relu_fwd(T s, A alpha) {
93 return s > alpha ? alpha : s;
// Reference backward bounded ReLU: the gradient passes only where the forward
// input lay strictly inside (0, alpha); elsewhere dd is multiplied by 0.
template <typename T, typename A>
T bounded_relu_bwd(T dd, T s, A alpha) {
    const bool inside = s > 0 && s < alpha;
    const int slope = inside ? 1 : 0;
    return dd * slope;
}
// Reference forward soft ReLU (softplus): log(1 + e^s).
//
// For s > 0 the function uses the identity log(1 + e^s) = s + log(1 + e^-s),
// so expf only sees non-positive arguments. The previous logf(1 + expf(s))
// overflowed to +inf for s > ~88.7.
template <typename T>
T soft_relu_fwd(T s) {
    if (s > 0)
        return static_cast<T>(s + logf(1 + ::expf(-s)));
    return static_cast<T>(logf(1 + ::expf(s)));
}
// Reference backward soft ReLU: dd * sigmoid(s), with sigmoid(s) written as
// 1 / (1 + e^-s). For very negative s, expf(-s) becomes +inf and the result is 0.
template <typename T>
T soft_relu_bwd(T dd, T s) {
    const auto denom = 1 + ::expf(-s);
    return dd / denom;
}
111 template <typename T>
112 T logistic_fwd(T s) {
117 template <typename T>
118 T logistic_bwd(T dd, T s) {
120 return dd * v / ((v + 1) * (v + 1));
// Reference forward clamp. In this reference alpha acts as the upper bound
// and beta as the lower bound; the upper bound is checked first.
template <typename T, typename A>
T clamp_fwd(T s, A alpha, A beta) {
    if (s > alpha)
        return static_cast<T>(alpha);
    if (s < beta)
        return static_cast<T>(beta);
    return s;
}
// Reference backward clamp: the gradient passes only where the forward input
// lay strictly between beta (lower) and alpha (upper); elsewhere dd is
// multiplied by 0.
template <typename T, typename A>
T clamp_bwd(T dd, T s, A alpha, A beta) {
    const bool unclamped = s > beta && s < alpha;
    const int slope = unclamped ? 1 : 0;
    return dd * slope;
}
133 template <typename data_t>
134 struct eltwise_test_params {
135 engine::kind engine_kind;
137 memory::format data_format;
138 memory::format diff_format;
143 size_t n_elems(const memory::desc &md) {
145 const int *pdims = md.data.layout_desc.blocking.padding_dims;
146 for (int i = 0; i < md.data.ndims; ++i)
147 p *= (size_t)(pdims[i]);
151 template <typename data_t>
152 void check_eltwise_fwd(const eltwise_test_params<data_t> &p,
153 const memory::desc &md, const memory &src, const memory &dst)
155 data_t *src_data = (data_t *)src.get_data_handle();
156 data_t *dst_data = (data_t *)dst.get_data_handle();
158 ASSERT_EQ(md.data.data_type, memory::data_type::f32); // TODO: type assert
160 size_t n = n_elems(md);
161 for (size_t i = 0; i < n; ++i) {
162 data_t s = src_data[i];
164 switch (p.alg_kind) {
165 case eltwise_relu: ref_d = relu_fwd(s, p.alpha); break;
166 case eltwise_tanh: ref_d = tanh_fwd(s); break;
167 case eltwise_elu: ref_d = elu_fwd(s, p.alpha); break;
168 case eltwise_square: ref_d = square_fwd(s); break;
169 case eltwise_abs: ref_d = abs_fwd(s); break;
170 case eltwise_sqrt: ref_d = sqrt_fwd(s); break;
171 case eltwise_linear: ref_d = linear_fwd(s, p.alpha, p.beta); break;
172 case eltwise_bounded_relu: ref_d = bounded_relu_fwd(s, p.alpha); break;
173 case eltwise_soft_relu: ref_d = soft_relu_fwd(s); break;
174 case eltwise_logistic: ref_d = logistic_fwd(s); break;
175 case eltwise_clamp: ref_d = clamp_fwd(s, p.alpha, p.beta); break;
176 default: assert(!"unknown alg_kind");
182 template <typename data_t>
183 void compare_eltwise_fwd(const eltwise_test_params<data_t> &p,
184 const memory::desc &md, const memory &dst, const memory &ref_dst)
186 data_t *ref_dst_data = (data_t *)ref_dst.get_data_handle();
187 data_t *dst_data = (data_t *)dst.get_data_handle();
189 ASSERT_EQ(md.data.data_type, memory::data_type::f32); // TODO: type assert
191 size_t n = n_elems(md);
192 for (size_t i = 0; i < n; ++i) {
193 if (p.alg_kind == eltwise_soft_relu){
194 EXPECT_NEAR(dst_data[i], ref_dst_data[i], 2.e-6);
197 EXPECT_NEAR(dst_data[i], ref_dst_data[i], 1.e-6);
203 template <typename data_t>
204 void check_eltwise_bwd(const eltwise_test_params<data_t> &p,
205 const memory::desc &md, const memory &src, const memory &diff_dst,
206 const memory &diff_src)
208 data_t *src_data = (data_t *)src.get_data_handle();
209 data_t *diff_dst_data = (data_t *)diff_dst.get_data_handle();
210 data_t *diff_src_data = (data_t *)diff_src.get_data_handle();
212 const memory::desc data_d = src.get_primitive_desc().desc();
213 const memory::desc diff_data_d = diff_src.get_primitive_desc().desc();
215 ASSERT_EQ(md.data.data_type, memory::data_type::f32); // TODO: type assert
217 size_t n = n_elems(md);
218 for (size_t i = 0; i < n; ++i) {
219 data_t ref_s = src_data[map_index(data_d, i)];
220 data_t ref_dd = diff_dst_data[map_index(diff_data_d, i)];
222 switch (p.alg_kind) {
223 case eltwise_relu: ref_ds = relu_bwd(ref_dd, ref_s, p.alpha); break;
224 case eltwise_tanh: ref_ds = tanh_bwd(ref_dd, ref_s); break;
225 case eltwise_elu: ref_ds = elu_bwd(ref_dd, ref_s, p.alpha); break;
226 case eltwise_square: ref_ds = square_bwd(ref_dd, ref_s); break;
227 case eltwise_abs: ref_ds = abs_bwd(ref_dd, ref_s); break;
228 case eltwise_sqrt: ref_ds = sqrt_bwd(ref_dd, ref_s); break;
230 ref_ds = linear_bwd(ref_dd, ref_s, p.alpha, p.beta);
232 case eltwise_bounded_relu:
233 ref_ds = bounded_relu_bwd(ref_dd, ref_s, p.alpha);
235 case eltwise_soft_relu:
236 ref_ds = soft_relu_bwd(ref_dd, ref_s);
238 case eltwise_logistic: ref_ds = logistic_bwd(ref_dd, ref_s); break;
239 case eltwise_clamp: ref_ds = clamp_bwd(ref_dd, ref_s, p.alpha, p.beta); break;
240 default: assert(!"unknown alg_kind");
242 EXPECT_NEAR(diff_src_data[map_index(diff_data_d, i)], ref_ds, 1.e-6);
246 template <typename data_t>
247 class eltwise_test : public ::testing::TestWithParam<eltwise_test_params<data_t>> {
249 std::shared_ptr<memory> src;
250 std::shared_ptr<memory> diff_src;
251 std::shared_ptr<memory> dst;
252 std::shared_ptr<memory> ref_dst;
253 std::shared_ptr<memory> diff_dst;
254 std::shared_ptr<memory> workspace;
255 std::shared_ptr<memory::desc> data_desc;
256 std::shared_ptr<memory::desc> diff_data_desc;
257 std::shared_ptr<eltwise_forward::primitive_desc> eltwise_prim_desc;
258 eltwise_test_params<data_t> p;
259 std::shared_ptr<engine> eng;
260 memory::data_type data_type;
263 virtual void SetUp() {
264 p = ::testing::TestWithParam<eltwise_test_params<data_t>>::GetParam();
266 ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
267 eng.reset(new engine(p.engine_kind, 0));
269 data_type = data_traits<data_t>::data_type;
270 ASSERT_EQ(data_type, mkldnn::memory::data_type::f32);
277 data_desc.reset(new memory::desc(p.dims, data_type,
279 diff_data_desc.reset(new memory::desc(p.dims, data_type,
281 src.reset(new memory({*data_desc, *eng}));
282 dst.reset(new memory({*data_desc, *eng}));
283 ref_dst.reset(new memory({*data_desc, *eng}));
285 fill_data<data_t>(n_elems(*data_desc), (data_t *)src->get_data_handle(),
286 data_t(0), data_t(1));
287 check_zero_tail<data_t>(1, *src);
289 auto eltwise_desc = eltwise_forward::desc(prop_kind::forward_training,
290 p.alg_kind, *data_desc, p.alpha, p.beta);
291 eltwise_prim_desc.reset(
292 new eltwise_forward::primitive_desc(eltwise_desc, *eng));
293 auto eltwise = eltwise_forward(*eltwise_prim_desc, *src, *dst);
295 std::vector<primitive> pipeline;
296 pipeline.push_back(eltwise);
297 auto s = stream(stream::kind::lazy);
298 s.submit(pipeline).wait();
299 check_zero_tail<data_t>(0, *dst);
300 check_eltwise_fwd(p, *data_desc, *src, *ref_dst);
301 check_zero_tail<data_t>(1, *ref_dst);
302 compare_eltwise_fwd(p, *data_desc, *dst, *ref_dst);
307 diff_src.reset(new memory({*diff_data_desc, *eng}));
308 diff_dst.reset(new memory({*diff_data_desc, *eng}));
310 fill_data<data_t>(n_elems(*diff_data_desc),
311 (data_t *)diff_dst->get_data_handle(), data_t(0), data_t(1));
312 check_zero_tail<data_t>(1, *diff_dst);
314 auto eltwise_bwd_desc = eltwise_backward::desc(p.alg_kind,
315 *diff_data_desc, *data_desc, p.alpha, p.beta);
316 auto eltwise_bwd_prim_desc = eltwise_backward::primitive_desc(
317 eltwise_bwd_desc, *eng, *eltwise_prim_desc);
318 auto eltwise_bwd = eltwise_backward(eltwise_bwd_prim_desc, *src,
319 *diff_dst, *diff_src);
321 std::vector<primitive> pipeline;
322 pipeline.push_back(eltwise_bwd);
323 auto s = stream(stream::kind::lazy);
324 s.submit(pipeline).wait();
326 check_zero_tail<data_t>(0, *diff_src);
327 check_eltwise_bwd(p, *data_desc, *src, *diff_dst, *diff_src);
331 using eltwise_test_float = eltwise_test<float>;
332 using eltwise_test_params_float = eltwise_test_params<float>;
334 TEST_P(eltwise_test_float, TestsEltwise)
// Identity macro that forces one extra round of expansion on its argument.
// NOTE(review): presumably a workaround for compilers (e.g. MSVC) that forward
// __VA_ARGS__ to a nested macro as a single argument -- confirm.
#define EXPAND(args) args

// Spells a layout tag as memory::format::<tag>.
#define EXPAND_FORMATS(FMT) memory::format::FMT
// Wraps the trailing dimension list in braces for aggregate initialization.
#define EXPAND_DIMS(...) { __VA_ARGS__ }

// Engine kind used by every parameter set built by PARAMS.
#define ENGINE engine::kind::cpu

// Builds one eltwise_test_params_float aggregate from: algorithm name, data
// format, diff format, alpha, beta, followed by the tensor dimensions.
#define PARAMS(ALG, SRC_FMT, DIFF_FMT, ALPHA, BETA, ...) \
    eltwise_test_params_float { \
        ENGINE, algorithm::ALG, \
        EXPAND_FORMATS(SRC_FMT), EXPAND_FORMATS(DIFF_FMT), \
        ALPHA, BETA, EXPAND_DIMS(__VA_ARGS__) \
    }

// First algorithm group: relu, tanh, elu, square, abs (in this order).
#define PARAMS_ALL_ALG(...) \
    EXPAND(PARAMS(eltwise_relu, __VA_ARGS__)), \
    EXPAND(PARAMS(eltwise_tanh, __VA_ARGS__)), \
    EXPAND(PARAMS(eltwise_elu, __VA_ARGS__)), \
    EXPAND(PARAMS(eltwise_square, __VA_ARGS__)), \
    EXPAND(PARAMS(eltwise_abs, __VA_ARGS__))

// Second algorithm group: sqrt, linear, soft_relu, bounded_relu, logistic,
// clamp (in this order).
#define PARAMS_ALL_ALG_SDPART(...) \
    EXPAND(PARAMS(eltwise_sqrt, __VA_ARGS__)), \
    EXPAND(PARAMS(eltwise_linear, __VA_ARGS__)), \
    EXPAND(PARAMS(eltwise_soft_relu, __VA_ARGS__)), \
    EXPAND(PARAMS(eltwise_bounded_relu, __VA_ARGS__)), \
    EXPAND(PARAMS(eltwise_logistic, __VA_ARGS__)), \
    EXPAND(PARAMS(eltwise_clamp, __VA_ARGS__))

// Registers a named instantiation of eltwise_test_float over the given values.
#define INST_TEST_CASE(str, ...) INSTANTIATE_TEST_CASE_P( \
    str, eltwise_test_float, ::testing::Values(__VA_ARGS__))
368 INST_TEST_CASE(Simple_blocked_3d_padded,
369 PARAMS_ALL_ALG(nCdhw16c, nCdhw16c, 0.1f, 0.2f, 4, 15, 2, 2, 2),
370 PARAMS_ALL_ALG_SDPART(nCdhw16c, nCdhw16c, 0.1f, 0.2f, 4, 27, 2, 2, 2),
371 PARAMS_ALL_ALG(nCdhw16c, nCdhw16c, 0.1f, 0.2f, 4, 23, 2, 2, 2),
372 PARAMS_ALL_ALG_SDPART(nCdhw16c, nCdhw16c, 0.1f, 0.2f, 4, 23, 7, 7, 7)
375 INST_TEST_CASE(Simple_blocked_padded,
376 PARAMS_ALL_ALG(nChw16c, nChw16c, 0.1f, 0.2f, 4, 15, 2, 2),
377 PARAMS_ALL_ALG_SDPART(nChw16c, nChw16c, 0.1f, 0.2f, 4, 27, 2, 2),
378 PARAMS_ALL_ALG(nChw16c, nChw16c, 0.1f, 0.2f, 4, 23, 2, 2),
379 PARAMS_ALL_ALG_SDPART(nChw16c, nChw16c, 0.1f, 0.2f, 4, 17, 7, 7)
382 INST_TEST_CASE(Simple_NCDHW,
383 PARAMS_ALL_ALG(ncdhw, ncdhw, 0.f, 0.f, 2, 32, 28, 28, 28),
384 PARAMS_ALL_ALG(ncdhw, ncdhw, 1.f, 0.f, 2, 64, 13, 13, 13),
385 PARAMS_ALL_ALG(ncdhw, ncdhw, 1.f, 1.f, 1, 64, 27, 27, 27),
386 PARAMS_ALL_ALG(ncdhw, ncdhw, 0.f, 1.f, 1, 128, 11, 11, 11)
389 INST_TEST_CASE(SimpleZeroNegativeSlope_NCHW,
390 PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 8, 4, 4),
391 PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 16, 4, 4),
392 PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 16, 8, 8),
393 PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 16, 16, 8),
394 PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 16, 10, 8),
395 PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 10, 10, 10, 10),
396 PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 256, 64, 8, 16),
397 PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 1, 1, 1, 1),
398 PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 3, 5, 7, 11)
401 INST_TEST_CASE(Simple_NCHW,
402 PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 8, 4, 4),
403 PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 16, 4, 4),
404 PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 16, 8, 8),
405 PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 16, 16, 8),
406 PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 16, 10, 8),
407 PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 10, 10, 10, 10),
408 PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 256, 64, 8, 16),
409 PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 1, 1, 1, 1),
410 PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 3, 5, 7, 11)
413 INST_TEST_CASE(Simple_NCHW_SDPART,
414 PARAMS_ALL_ALG_SDPART(nchw, nchw, 0.1f, 0.f, 256, 64, 8, 16)
417 INST_TEST_CASE(Simple,
418 PARAMS_ALL_ALG(nchw, nChw8c, 0.1f, 0.f, 2, 8, 4, 4),
419 PARAMS_ALL_ALG(nChw8c, nchw, 0.1f, 0.f, 2, 16, 4, 4),
420 PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 16, 8, 8),
421 PARAMS_ALL_ALG(nChw8c, nChw8c, 0.1f, 0.f, 2, 16, 16, 8),
422 PARAMS_ALL_ALG(nhwc, nchw, 0.1f, 0.f, 2, 16, 10, 8),
423 PARAMS_ALL_ALG(nchw, nhwc, 0.1f, 0.f, 10, 10, 10, 10)
426 INST_TEST_CASE(Simple_SDPART,
427 PARAMS_ALL_ALG_SDPART(nchw, nChw8c, 0.1f, 0.f, 2, 8, 4, 4),
428 PARAMS_ALL_ALG_SDPART(nChw8c, nchw, 0.1f, 0.f, 2, 16, 4, 4),
429 PARAMS_ALL_ALG_SDPART(nchw, nchw, 0.1f, 0.f, 2, 16, 8, 8),
430 PARAMS_ALL_ALG_SDPART(nChw8c, nChw8c, 0.1f, 0.f, 2, 16, 16, 8),
431 PARAMS_ALL_ALG_SDPART(nhwc, nchw, 0.1f, 0.f, 2, 16, 10, 8),
432 PARAMS_ALL_ALG_SDPART(nchw, nhwc, 0.1f, 0.f, 10, 10, 10, 10)
435 INST_TEST_CASE(AlexNet_NCHW,
436 PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 96, 55, 55),
437 PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 256, 27, 27),
438 PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 384, 13, 13),
439 PARAMS_ALL_ALG_SDPART(nchw, nchw, 0.f, 0.f, 2, 96, 55, 55),
440 PARAMS_ALL_ALG_SDPART(nchw, nchw, 0.f, 0.f, 2, 256, 27, 27),
441 PARAMS_ALL_ALG_SDPART(nchw, nchw, 0.f, 0.f, 2, 384, 13, 13)
444 INST_TEST_CASE(Simple_X,
445 PARAMS_ALL_ALG(x, x, 0.f, 0.f, 55)