1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "gtest/gtest.h"
18 #include "mkldnn_test_common.hpp"
19 #include "math_utils.hpp"
22 using namespace mkldnn::impl::math;
// Parameter bundle describing one eltwise test case, templated on the
// element type. The fields visible here are the engine kind, the two memory
// formats (named for the data and the diff tensors) and the status an
// expected failure should report.
// NOTE(review): other code in this file also reads alg_kind, alpha, beta,
// dims and expect_to_fail through this struct; their declarations are not
// visible in this excerpt.
26 template <typename data_t>
27 struct eltwise_test_params {
28 engine::kind engine_kind;
30 memory::format data_format;
31 memory::format diff_format;
// Presumably forwarded to catch_expected_failures() together with
// expect_to_fail -- TODO confirm against the call in SetUp().
35 mkldnn_status_t expected_status;
// Multiplies together the first md.data.ndims entries of the blocking
// layout's padding_dims, i.e. an element count based on the padded dimensions
// rather than the logical ones.
// NOTE(review): the declaration of the accumulator `p` and the return
// statement are not visible in this excerpt.
38 size_t n_elems(const memory::desc &md) {
40 const ptrdiff_t *pdims = md.data.layout_desc.blocking.padding_dims;
41 for (int i = 0; i < md.data.ndims; ++i)
42 p *= (size_t)(pdims[i]);
// Evaluates the reference forward function for every element of `src`,
// choosing the math helper from the alg_kind case labels below and forwarding
// p.alpha / p.beta to the helpers that take them.
//   p   - test parameters (alpha and beta are read here)
//   md  - descriptor used for the f32 assertion and for the element count
//   src - input values, read through the raw data handle
//   dst - memory whose raw data handle is fetched as dst_data
// NOTE(review): the switch header, the declaration of ref_d and whatever is
// done with ref_d / dst_data after the switch are not visible here. The test
// body passes ref_dst as `dst`, so this presumably fills the reference
// output -- confirm.
46 template <typename data_t>
47 void check_eltwise_fwd(const eltwise_test_params<data_t> &p,
48 const memory::desc &md, const memory &src, const memory &dst)
50 data_t *src_data = (data_t *)src.get_data_handle();
51 data_t *dst_data = (data_t *)dst.get_data_handle();
// Abort this helper unless the descriptor says the data is f32.
53 ASSERT_EQ(md.data.data_type, memory::data_type::f32); // TODO: type assert
// Iterate linearly over the padded element count (see n_elems).
55 size_t n = n_elems(md);
56 for (size_t i = 0; i < n; ++i) {
57 data_t s = src_data[i];
60 case eltwise_relu: ref_d = relu_fwd(s, p.alpha); break;
61 case eltwise_tanh: ref_d = tanh_fwd(s); break;
62 case eltwise_elu: ref_d = elu_fwd(s, p.alpha); break;
63 case eltwise_square: ref_d = square_fwd(s); break;
64 case eltwise_abs: ref_d = abs_fwd(s); break;
65 case eltwise_sqrt: ref_d = sqrt_fwd(s); break;
66 case eltwise_linear: ref_d = linear_fwd(s, p.alpha, p.beta); break;
67 case eltwise_bounded_relu: ref_d = bounded_relu_fwd(s, p.alpha); break;
68 case eltwise_soft_relu: ref_d = soft_relu_fwd(s); break;
69 case eltwise_logistic: ref_d = logistic_fwd(s); break;
70 case eltwise_clamp: ref_d = clamp_fwd(s, p.alpha, p.beta); break;
71 case eltwise_exp: ref_d = exp_fwd(s); break;
72 case eltwise_not: ref_d = not_fwd(s); break;
// Negating a string literal yields false, so this assert fires whenever an
// unhandled algorithm reaches it (only in builds where assert is enabled).
73 default: assert(!"unknown alg_kind");
// Compares the primitive's forward output `dst` against the reference output
// `ref_dst` element by element over the padded element count of `md`.
// Mismatches are reported with EXPECT_NEAR, so the loop keeps going after a
// failure.
79 template <typename data_t>
80 void compare_eltwise_fwd(const eltwise_test_params<data_t> &p,
81 const memory::desc &md, const memory &dst, const memory &ref_dst)
83 data_t *ref_dst_data = (data_t *)ref_dst.get_data_handle();
84 data_t *dst_data = (data_t *)dst.get_data_handle();
// Abort this helper unless the descriptor says the data is f32.
86 ASSERT_EQ(md.data.data_type, memory::data_type::f32); // TODO: type assert
88 size_t n = n_elems(md);
89 for (size_t i = 0; i < n; ++i) {
// soft_relu is checked with a looser absolute tolerance (2e-6).
90 if (p.alg_kind == eltwise_soft_relu){
91 EXPECT_NEAR(dst_data[i], ref_dst_data[i], 2.e-6);
// Tighter 1e-6 tolerance; presumably the else branch for every other
// algorithm (the `else` line is not visible here).
94 EXPECT_NEAR(dst_data[i], ref_dst_data[i], 1.e-6);
// Verifies the backward result: for each index the reference gradient ref_ds
// is computed from the source value and the incoming gradient, then compared
// with the value stored in `diff_src` (absolute tolerance 1e-6).
//   p        - test parameters (alg_kind, alpha, beta are read here)
//   md       - descriptor used for the f32 assertion and the element count
//   src      - forward input, indexed through its own memory descriptor
//   diff_dst - incoming gradient, indexed through diff_src's descriptor
//   diff_src - gradient produced by the primitive under test
// NOTE(review): the declaration of ref_ds, the `case eltwise_linear:` label
// and several `break` statements are not visible in this excerpt. Unlike the
// forward switch, no eltwise_not case is visible here.
100 template <typename data_t>
101 void check_eltwise_bwd(const eltwise_test_params<data_t> &p,
102 const memory::desc &md, const memory &src, const memory &diff_dst,
103 const memory &diff_src)
105 data_t *src_data = (data_t *)src.get_data_handle();
106 data_t *diff_dst_data = (data_t *)diff_dst.get_data_handle();
107 data_t *diff_src_data = (data_t *)diff_src.get_data_handle();
// Descriptors used by map_index() to translate the loop index into an offset
// for the data tensor and for the diff tensors respectively.
109 const memory::desc data_d = src.get_primitive_desc().desc();
110 const memory::desc diff_data_d = diff_src.get_primitive_desc().desc();
// Abort this helper unless the descriptor says the data is f32.
112 ASSERT_EQ(md.data.data_type, memory::data_type::f32); // TODO: type assert
114 size_t n = n_elems(md);
115 for (size_t i = 0; i < n; ++i) {
116 data_t ref_s = src_data[map_index(data_d, i)];
117 data_t ref_dd = diff_dst_data[map_index(diff_data_d, i)];
// Select the reference backward formula for the algorithm under test.
119 switch (p.alg_kind) {
120 case eltwise_relu: ref_ds = relu_bwd(ref_dd, ref_s, p.alpha); break;
121 case eltwise_tanh: ref_ds = tanh_bwd(ref_dd, ref_s); break;
122 case eltwise_elu: ref_ds = elu_bwd(ref_dd, ref_s, p.alpha); break;
123 case eltwise_square: ref_ds = square_bwd(ref_dd, ref_s); break;
124 case eltwise_abs: ref_ds = abs_bwd(ref_dd, ref_s); break;
125 case eltwise_sqrt: ref_ds = sqrt_bwd(ref_dd, ref_s); break;
127 ref_ds = linear_bwd(ref_dd, ref_s, p.alpha, p.beta);
129 case eltwise_bounded_relu:
130 ref_ds = bounded_relu_bwd(ref_dd, ref_s, p.alpha);
132 case eltwise_soft_relu:
133 ref_ds = soft_relu_bwd(ref_dd, ref_s);
135 case eltwise_logistic: ref_ds = logistic_bwd(ref_dd, ref_s); break;
136 case eltwise_clamp: ref_ds = clamp_bwd(ref_dd, ref_s, p.alpha, p.beta); break;
137 case eltwise_exp: ref_ds = exp_bwd(ref_dd, ref_s); break;
// Negating a string literal yields false, so this assert fires whenever an
// unhandled algorithm reaches it (only in builds where assert is enabled).
138 default: assert(!"unknown alg_kind");
// Non-fatal comparison of the primitive's gradient with the reference.
140 EXPECT_NEAR(diff_src_data[map_index(diff_data_d, i)], ref_ds, 1.e-6);
// Value-parameterized gtest fixture that runs one eltwise primitive forward
// and then backward on the CPU engine and checks both results against the
// reference helpers defined earlier in this file.
// NOTE(review): access specifiers, the header of Test() and the format
// arguments of the two memory::desc constructions are not visible in this
// excerpt.
144 template <typename data_t>
145 class eltwise_test : public ::testing::TestWithParam<eltwise_test_params<data_t>> {
// Memories for the forward pass (src, dst, ref_dst) and the backward pass
// (diff_src, diff_dst). `workspace` is not referenced in the visible code.
147 std::shared_ptr<memory> src;
148 std::shared_ptr<memory> diff_src;
149 std::shared_ptr<memory> dst;
150 std::shared_ptr<memory> ref_dst;
151 std::shared_ptr<memory> diff_dst;
152 std::shared_ptr<memory> workspace;
// Descriptors for the data tensors and for the diff tensors.
153 std::shared_ptr<memory::desc> data_desc;
154 std::shared_ptr<memory::desc> diff_data_desc;
// Forward primitive descriptor; kept because the backward primitive
// descriptor is constructed from it further down.
155 std::shared_ptr<eltwise_forward::primitive_desc> eltwise_prim_desc;
156 eltwise_test_params<data_t> p;
157 std::shared_ptr<engine> eng;
158 memory::data_type data_type;
// gtest hook: fetches the parameters and runs Test() through
// catch_expected_failures(), passing p.expect_to_fail (the remaining
// arguments are on a line not visible here).
161 virtual void SetUp() {
162 p = ::testing::TestWithParam<decltype(p)>::GetParam();
163 catch_expected_failures([=](){Test();}, p.expect_to_fail,
// Presumably the first statement of Test() (its header is not visible):
// the parameters are fetched again.
168 p = ::testing::TestWithParam<eltwise_test_params<data_t>>::GetParam();
// Only the CPU engine is accepted; the engine is created with index 0.
170 ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
171 eng.reset(new engine(p.engine_kind, 0));
// The fixture is restricted to f32 element types.
173 data_type = data_traits<data_t>::data_type;
174 ASSERT_EQ(data_type, mkldnn::memory::data_type::f32);
// Build both descriptors from p.dims and data_type (the trailing argument of
// each call is on a line not visible here).
181 data_desc.reset(new memory::desc(p.dims, data_type,
183 diff_data_desc.reset(new memory::desc(p.dims, data_type,
// Forward-pass memories all use the data descriptor.
185 src.reset(new memory({*data_desc, *eng}));
186 dst.reset(new memory({*data_desc, *eng}));
187 ref_dst.reset(new memory({*data_desc, *eng}));
// Fill src via fill_data() with median 0; the deviation argument is 1 for
// elu and exp and 200 for every other algorithm.
189 data_t data_median = data_t(0);
190 data_t data_deviation
191 = p.alg_kind == eltwise_elu || p.alg_kind == eltwise_exp ? data_t(1) : data_t(200);
192 fill_data<data_t>(n_elems(*data_desc), (data_t *)src->get_data_handle(),
193 data_median, data_deviation);
// NOTE(review): check_zero_tail is called with 1 on inputs/reference buffers
// and with 0 on primitive outputs -- presumably 1 zeroes the padded tail and
// 0 verifies it; confirm against the helper's definition.
194 check_zero_tail<data_t>(1, *src);
// Create the forward primitive (forward_training) mapping src to dst.
196 auto eltwise_desc = eltwise_forward::desc(prop_kind::forward_training,
197 p.alg_kind, *data_desc, p.alpha, p.beta);
198 eltwise_prim_desc.reset(
199 new eltwise_forward::primitive_desc(eltwise_desc, *eng));
200 auto eltwise = eltwise_forward(*eltwise_prim_desc, *src, *dst);
// Submit the single primitive on a lazy stream and wait for completion.
202 std::vector<primitive> pipeline;
203 pipeline.push_back(eltwise);
204 auto s = stream(stream::kind::lazy);
205 s.submit(pipeline).wait();
// Run the reference helper on ref_dst, then compare it with the primitive's
// output in dst.
206 check_zero_tail<data_t>(0, *dst);
207 check_eltwise_fwd(p, *data_desc, *src, *ref_dst);
208 check_zero_tail<data_t>(1, *ref_dst);
209 compare_eltwise_fwd(p, *data_desc, *dst, *ref_dst);
// Backward-pass memories use the diff descriptor.
214 diff_src.reset(new memory({*diff_data_desc, *eng}));
215 diff_dst.reset(new memory({*diff_data_desc, *eng}));
// Fill diff_dst via fill_data(); here only elu gets deviation 1 (exp is not
// special-cased, unlike the forward fill above).
217 data_t data_median = data_t(0);
218 data_t data_deviation
219 = p.alg_kind == eltwise_elu ? data_t(1) : data_t(200);
220 fill_data<data_t>(n_elems(*diff_data_desc),
221 (data_t *)diff_dst->get_data_handle(), data_median,
223 check_zero_tail<data_t>(1, *diff_dst);
// Create the backward primitive; its primitive descriptor is built from the
// engine and the forward primitive descriptor.
225 auto eltwise_bwd_desc = eltwise_backward::desc(p.alg_kind,
226 *diff_data_desc, *data_desc, p.alpha, p.beta);
227 auto eltwise_bwd_prim_desc = eltwise_backward::primitive_desc(
228 eltwise_bwd_desc, *eng, *eltwise_prim_desc);
229 auto eltwise_bwd = eltwise_backward(eltwise_bwd_prim_desc, *src,
230 *diff_dst, *diff_src);
// Submit the backward primitive on a lazy stream and wait for completion.
232 std::vector<primitive> pipeline;
233 pipeline.push_back(eltwise_bwd);
234 auto s = stream(stream::kind::lazy);
235 s.submit(pipeline).wait();
// Check diff_src against the reference backward computation.
237 check_zero_tail<data_t>(0, *diff_src);
238 check_eltwise_bwd(p, *data_desc, *src, *diff_dst, *diff_src);
// float instantiations of the fixture and of its parameter struct; all test
// cases below are registered against these.
242 using eltwise_test_float = eltwise_test<float>;
243 using eltwise_test_params_float = eltwise_test_params<float>;
// Parameterized test entry point (its body is not visible in this excerpt).
245 TEST_P(eltwise_test_float, TestsEltwise)
// Identity macro wrapped around the nested PARAMS(...) invocations below.
// NOTE(review): presumably there to force a re-scan of __VA_ARGS__ on
// preprocessors that pass it on as a single argument -- confirm.
249 #define EXPAND(args) args
// Qualifies a bare format name with memory::format.
251 #define EXPAND_FORMATS(data) memory::format::data
// Wraps the trailing dimension list in braces.
252 #define EXPAND_DIMS(...) { __VA_ARGS__ }
// Engine kind used by every test case in this file.
254 #define ENGINE engine::kind::cpu
// Builds one eltwise_test_params_float initializer: engine, algorithm, data
// format, diff format, alpha, beta, followed by the braced dimension list
// taken from the variadic arguments.
256 #define PARAMS(alg, data, diff_data, alpha, beta, ...) \
257 eltwise_test_params_float { ENGINE, algorithm::alg, \
258 EXPAND_FORMATS(data), EXPAND_FORMATS(diff_data), \
259 alpha, beta, EXPAND_DIMS(__VA_ARGS__) }
// Expands one (formats, alpha, beta, dims) tuple into five parameter sets,
// one each for relu, tanh, elu, square and abs.
261 #define PARAMS_ALL_ALG(...) \
262 EXPAND(PARAMS(eltwise_relu, __VA_ARGS__)), \
263 EXPAND(PARAMS(eltwise_tanh, __VA_ARGS__)), \
264 EXPAND(PARAMS(eltwise_elu, __VA_ARGS__)), \
265 EXPAND(PARAMS(eltwise_square, __VA_ARGS__)), \
266 EXPAND(PARAMS(eltwise_abs, __VA_ARGS__))
// Companion of PARAMS_ALL_ALG covering the remaining algorithms: sqrt,
// linear, soft_relu, bounded_relu, logistic, clamp and exp.
269 #define PARAMS_ALL_ALG_SDPART(...) \
270 EXPAND(PARAMS(eltwise_sqrt, __VA_ARGS__)), \
271 EXPAND(PARAMS(eltwise_linear, __VA_ARGS__)), \
272 EXPAND(PARAMS(eltwise_soft_relu, __VA_ARGS__)), \
273 EXPAND(PARAMS(eltwise_bounded_relu, __VA_ARGS__)), \
274 EXPAND(PARAMS(eltwise_logistic, __VA_ARGS__)), \
275 EXPAND(PARAMS(eltwise_clamp, __VA_ARGS__)), \
276 EXPAND(PARAMS(eltwise_exp, __VA_ARGS__))
// Registers a named group of parameter sets for eltwise_test_float through
// gtest's INSTANTIATE_TEST_CASE_P and ::testing::Values.
279 #define INST_TEST_CASE(str, ...) INSTANTIATE_TEST_CASE_P( \
280 str, eltwise_test_float, ::testing::Values(__VA_ARGS__))
// 5D shapes whose first or second dimension is zero, for both algorithm sets.
282 INST_TEST_CASE(SimpleZeroDim,
283 PARAMS_ALL_ALG(ncdhw, nCdhw8c, 0.1f, 0.f, 0, 2, 4, 4, 4),
284 PARAMS_ALL_ALG(ncdhw, nCdhw8c, 0.1f, 0.f, 2, 0, 4, 4, 4),
285 PARAMS_ALL_ALG_SDPART(nCdhw16c, nCdhw16c, 0.1f, 0.2f, 0, 4, 2, 2, 2),
286 PARAMS_ALL_ALG_SDPART(nCdhw16c, nCdhw16c, 0.1f, 0.2f, 4, 0, 2, 2, 2)
// Builds a 4D nchw parameter set (alpha = beta = 0) for eltwise_<alg> with
// two trailing initializers: `true` and mkldnn_invalid_arguments.
// NOTE(review): these presumably initialize expect_to_fail and
// expected_status -- confirm against the struct's field order.
289 #define CASE_EF(alg, d0, d1, d2, d3) \
290 eltwise_test_params_float { ENGINE, algorithm::eltwise_##alg, \
291 EXPAND_FORMATS(nchw), EXPAND_FORMATS(nchw), 0.f, 0.f, {d0, d1, d2, d3}, \
292 true, mkldnn_invalid_arguments }
// Shapes with a negative first or second dimension, built with CASE_EF for
// relu, sqrt and logistic.
293 INST_TEST_CASE(SimpleExpectedFails,
294 CASE_EF(relu, -1, 2, 4, 4),
295 CASE_EF(sqrt, -1, 2, 4, 4),
296 CASE_EF(logistic, -1, 2, 4, 4),
297 CASE_EF(relu, 1, -2, 4, 4),
298 CASE_EF(sqrt, 1, -2, 4, 4),
299 CASE_EF(logistic, 1, -2, 4, 4)
// 5D shapes mixing the ncdhw, nCdhw8c and ndhwc formats between the data and
// diff tensors.
302 INST_TEST_CASE(Simple_3D,
303 PARAMS_ALL_ALG(ncdhw, nCdhw8c, 0.1f, 0.f, 2, 8, 4, 4, 4),
304 PARAMS_ALL_ALG(nCdhw8c, ncdhw, 0.1f, 0.f, 2, 16, 4, 4, 4),
305 PARAMS_ALL_ALG(ncdhw, ncdhw, 0.1f, 0.f, 2, 16, 8, 8, 8),
306 PARAMS_ALL_ALG(nCdhw8c, nCdhw8c, 0.1f, 0.f, 2, 16, 16, 8, 6),
307 PARAMS_ALL_ALG(ndhwc, ncdhw, 0.1f, 0.f, 2, 16, 10, 8, 6),
308 PARAMS_ALL_ALG(ncdhw, ndhwc, 0.1f, 0.f, 10, 10, 10, 10, 10)
// 5D nCdhw16c shapes whose channel counts (15, 27, 23) are not multiples
// of 16.
311 INST_TEST_CASE(Simple_blocked_3d_padded,
312 PARAMS_ALL_ALG(nCdhw16c, nCdhw16c, 0.1f, 0.2f, 4, 15, 2, 2, 2),
313 PARAMS_ALL_ALG_SDPART(nCdhw16c, nCdhw16c, 0.1f, 0.2f, 4, 27, 2, 2, 2),
314 PARAMS_ALL_ALG(nCdhw16c, nCdhw16c, 0.1f, 0.2f, 4, 23, 2, 2, 2),
315 PARAMS_ALL_ALG_SDPART(nCdhw16c, nCdhw16c, 0.1f, 0.2f, 4, 23, 7, 7, 7)
// 4D nChw16c and nChw8c shapes whose channel counts (15, 27, 23, 17) are not
// multiples of 8 or 16.
318 INST_TEST_CASE(Simple_blocked_padded,
319 PARAMS_ALL_ALG(nChw16c, nChw16c, 0.1f, 0.2f, 4, 15, 2, 2),
320 PARAMS_ALL_ALG_SDPART(nChw16c, nChw16c, 0.1f, 0.2f, 4, 27, 2, 2),
321 PARAMS_ALL_ALG(nChw16c, nChw16c, 0.1f, 0.2f, 4, 23, 2, 2),
322 PARAMS_ALL_ALG_SDPART(nChw16c, nChw16c, 0.1f, 0.2f, 4, 17, 7, 7),
323 PARAMS_ALL_ALG(nChw8c, nChw8c, 0.1f, 0.2f, 4, 15, 2, 2),
324 PARAMS_ALL_ALG_SDPART(nChw8c, nChw8c, 0.1f, 0.2f, 4, 27, 2, 2),
325 PARAMS_ALL_ALG(nChw8c, nChw8c, 0.1f, 0.2f, 4, 23, 2, 2),
326 PARAMS_ALL_ALG_SDPART(nChw8c, nChw8c, 0.1f, 0.2f, 4, 17, 7, 7)
// Plain ncdhw shapes covering all four alpha/beta combinations of 0 and 1.
329 INST_TEST_CASE(Simple_NCDHW,
330 PARAMS_ALL_ALG(ncdhw, ncdhw, 0.f, 0.f, 2, 32, 28, 28, 28),
331 PARAMS_ALL_ALG(ncdhw, ncdhw, 1.f, 0.f, 2, 64, 13, 13, 13),
332 PARAMS_ALL_ALG(ncdhw, ncdhw, 1.f, 1.f, 1, 64, 27, 27, 27),
333 PARAMS_ALL_ALG(ncdhw, ncdhw, 0.f, 1.f, 1, 128, 11, 11, 11)
// Plain nchw shapes with alpha = 0 and beta = 0, from 1x1x1x1 up to
// 256x64x8x16.
336 INST_TEST_CASE(SimpleZeroNegativeSlope_NCHW,
337 PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 8, 4, 4),
338 PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 16, 4, 4),
339 PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 16, 8, 8),
340 PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 16, 16, 8),
341 PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 16, 10, 8),
342 PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 10, 10, 10, 10),
343 PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 256, 64, 8, 16),
344 PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 1, 1, 1, 1),
345 PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 3, 5, 7, 11)
// Same nchw shapes as SimpleZeroNegativeSlope_NCHW, but with alpha = 0.1.
348 INST_TEST_CASE(Simple_NCHW,
349 PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 8, 4, 4),
350 PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 16, 4, 4),
351 PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 16, 8, 8),
352 PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 16, 16, 8),
353 PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 16, 10, 8),
354 PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 10, 10, 10, 10),
355 PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 256, 64, 8, 16),
356 PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 1, 1, 1, 1),
357 PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 3, 5, 7, 11)
// Single 256x64x8x16 nchw shape for the second algorithm set.
360 INST_TEST_CASE(Simple_NCHW_SDPART,
361 PARAMS_ALL_ALG_SDPART(nchw, nchw, 0.1f, 0.f, 256, 64, 8, 16)
// 4D shapes mixing the nchw, nChw8c and nhwc formats between the data and
// diff tensors (first algorithm set).
364 INST_TEST_CASE(Simple,
365 PARAMS_ALL_ALG(nchw, nChw8c, 0.1f, 0.f, 2, 8, 4, 4),
366 PARAMS_ALL_ALG(nChw8c, nchw, 0.1f, 0.f, 2, 16, 4, 4),
367 PARAMS_ALL_ALG(nchw, nchw, 0.1f, 0.f, 2, 16, 8, 8),
368 PARAMS_ALL_ALG(nChw8c, nChw8c, 0.1f, 0.f, 2, 16, 16, 8),
369 PARAMS_ALL_ALG(nhwc, nchw, 0.1f, 0.f, 2, 16, 10, 8),
370 PARAMS_ALL_ALG(nchw, nhwc, 0.1f, 0.f, 10, 10, 10, 10)
// Same format/shape combinations as the Simple group, for the second
// algorithm set.
373 INST_TEST_CASE(Simple_SDPART,
374 PARAMS_ALL_ALG_SDPART(nchw, nChw8c, 0.1f, 0.f, 2, 8, 4, 4),
375 PARAMS_ALL_ALG_SDPART(nChw8c, nchw, 0.1f, 0.f, 2, 16, 4, 4),
376 PARAMS_ALL_ALG_SDPART(nchw, nchw, 0.1f, 0.f, 2, 16, 8, 8),
377 PARAMS_ALL_ALG_SDPART(nChw8c, nChw8c, 0.1f, 0.f, 2, 16, 16, 8),
378 PARAMS_ALL_ALG_SDPART(nhwc, nchw, 0.1f, 0.f, 2, 16, 10, 8),
379 PARAMS_ALL_ALG_SDPART(nchw, nhwc, 0.1f, 0.f, 10, 10, 10, 10)
// Three nchw shapes (2x96x55x55, 2x256x27x27, 2x384x13x13) run with both
// algorithm sets, alpha = beta = 0.
382 INST_TEST_CASE(AlexNet_NCHW,
383 PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 96, 55, 55),
384 PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 256, 27, 27),
385 PARAMS_ALL_ALG(nchw, nchw, 0.f, 0.f, 2, 384, 13, 13),
386 PARAMS_ALL_ALG_SDPART(nchw, nchw, 0.f, 0.f, 2, 96, 55, 55),
387 PARAMS_ALL_ALG_SDPART(nchw, nchw, 0.f, 0.f, 2, 256, 27, 27),
388 PARAMS_ALL_ALG_SDPART(nchw, nchw, 0.f, 0.f, 2, 384, 13, 13)
// One-dimensional case: format x with a single dimension of 55.
391 INST_TEST_CASE(Simple_X,
392 PARAMS_ALL_ALG(x, x, 0.f, 0.f, 55)