1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "mkldnn_test_common.hpp"
18 #include "gtest/gtest.h"
21 #include "cpu_isa_traits.hpp"
25 struct sum_test_params {
26 const engine::kind engine_kind;
27 std::vector<memory::format> srcs_format;
28 memory::format dst_format;
30 std::vector<float> scale;
31 bool is_output_omitted;
33 mkldnn_status_t expected_status;
36 template <typename data_t, typename acc_t>
37 void check_data(const std::vector<memory> &srcs,
38 const std::vector<float> scale,
42 const data_t *dst_data = (const data_t *)dst.get_data_handle();
43 const auto &dst_d = dst.get_primitive_desc().desc();
44 const auto dst_dims = dst_d.data.dims;
46 mkldnn::impl::parallel_nd(dst_dims[0], dst_dims[1], dst_dims[2], dst_dims[3],
47 [&](int n, int c, int h, int w) {
48 acc_t src_sum = (acc_t)0;
49 for (size_t num = 0; num < srcs.size(); num++) {
50 const data_t *src_data =
51 (const data_t *)srcs[num].get_data_handle();
52 const auto &src_d = srcs[num].get_primitive_desc().desc();
53 const auto src_dims = src_d.data.dims;
57 + src_dims[2]*src_dims[3]*c
58 + src_dims[1]*src_dims[2]*src_dims[3]*n;
59 src_sum += acc_t(scale[num])* src_data[map_index(src_d, src_idx)];
61 src_sum = (std::max)((std::min)(src_sum,
62 (acc_t)(std::numeric_limits<data_t>::max)()),
63 (acc_t)std::numeric_limits<data_t>::lowest());
67 + dst_dims[2]*dst_dims[3]*c
68 + dst_dims[1]*dst_dims[2]*dst_dims[3]*n;
69 auto diff = src_sum - dst_data[map_index(dst_d, dst_idx)];
70 auto e = (std::abs(src_sum) > 1e-4) ? diff / src_sum : diff;
71 EXPECT_NEAR(e, 0.0, eps);
76 template <typename data_t, typename acc_t>
77 class sum_test: public ::testing::TestWithParam<sum_test_params> {
79 virtual void SetUp() {
81 = ::testing::TestWithParam<sum_test_params>::GetParam();
82 catch_expected_failures([=](){Test();}, p.expect_to_fail,
88 = ::testing::TestWithParam<sum_test_params>::GetParam();
90 const auto num_srcs = p.srcs_format.size();
92 ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
93 auto eng = engine(p.engine_kind, 0);
94 memory::data_type data_type = data_traits<data_t>::data_type;
96 std::vector<memory::primitive_desc> srcs_pd;
97 std::vector<memory> srcs;
99 for (size_t i = 0; i < num_srcs; i++) {
100 bool is_fmt_blocked = p.srcs_format[i] == memory::format::blocked;
101 auto desc = memory::desc(p.dims, data_type, is_fmt_blocked
102 ? memory::format::nchw
104 if (is_fmt_blocked) desc.data.format = mkldnn_blocked;
105 auto mpd = memory::primitive_desc(desc, eng);
106 auto src_memory = memory(mpd);
108 src_memory.get_primitive_desc().get_size() / sizeof(data_t);
109 fill_data<data_t>(sz, (data_t *)src_memory.get_data_handle());
110 srcs_pd.push_back(mpd);
111 srcs.push_back(src_memory);
114 std::shared_ptr<memory> dst;
115 std::shared_ptr<sum::primitive_desc> sum_pd;
117 if (p.is_output_omitted) {
118 ASSERT_NO_THROW(sum_pd.reset(
119 new sum::primitive_desc(p.scale, srcs_pd)));
121 bool is_fmt_blocked = p.dst_format == memory::format::blocked;
122 auto dst_desc = memory::desc(p.dims, data_type, is_fmt_blocked
123 ? memory::format::nchw
125 if (is_fmt_blocked) dst_desc.data.format = mkldnn_blocked;
127 new sum::primitive_desc(dst_desc, p.scale, srcs_pd));
129 ASSERT_EQ(sum_pd->dst_primitive_desc().desc().data.format,
130 dst_desc.data.format);
131 ASSERT_EQ(sum_pd->dst_primitive_desc().desc().data.ndims,
132 dst_desc.data.ndims);
134 ASSERT_NO_THROW(dst.reset(new memory(sum_pd->dst_primitive_desc())));
136 data_t *dst_data = (data_t *)dst->get_data_handle();
138 dst->get_primitive_desc().get_size() / sizeof(data_t);
139 // overwriting dst to prevent false positives for test cases.
140 mkldnn::impl::parallel_nd((ptrdiff_t)sz,
141 [&](ptrdiff_t i) { dst_data[i] = -32; }
144 std::vector<primitive::at> inputs;
145 for (size_t i = 0; i < num_srcs; i++) {
146 inputs.push_back(srcs[i]);
148 auto c = sum(*sum_pd, inputs, *dst);
149 std::vector<primitive> pipeline;
150 pipeline.push_back(c);
151 auto s = stream(stream::kind::eager);
152 s.submit(pipeline).wait();
154 check_data<data_t, acc_t>(srcs, p.scale, *dst, 1.2e-7);
158 template <typename dst_data_t>
159 class sum_test_bf16: public ::testing::TestWithParam<sum_test_params> {
        /* Skip the test on systems that don't support avx512_core. */
163 SKIP_IF(!impl::cpu::mayiuse(impl::cpu::avx512_core),
164 "current ISA doesn't support bfloat16 data type");
166 = ::testing::TestWithParam<sum_test_params>::GetParam();
167 catch_expected_failures([=](){Test();}, p.expect_to_fail,
172 /*TODO: refactor to improve readability.
173 * Maybe share common code with sum_test by
174 * inheriting from sum_test<mkldnn_bfloat16_t>. */
176 = ::testing::TestWithParam<sum_test_params>::GetParam();
178 const auto num_srcs = p.srcs_format.size();
180 ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
181 auto eng = engine(p.engine_kind, 0);
182 memory::data_type data_type = data_traits<dst_data_t>::data_type;
183 memory::data_type bf16_data_type = mkldnn::memory::data_type::bf16;
184 memory::data_type f32_data_type = mkldnn::memory::data_type::f32;
186 std::vector<memory::primitive_desc> srcs_pd_f32;
187 std::vector<memory::primitive_desc> srcs_pd_bf16;
188 std::vector<memory> srcs_f32;
189 std::vector<memory> srcs_bf16;
191 for (size_t i = 0; i < num_srcs; i++) {
192 bool is_fmt_blocked = memory::format::blocked == p.srcs_format[i];
193 auto fmt = is_fmt_blocked ? memory::format::nchw : p.srcs_format[i];
194 auto desc_f32 = memory::desc(p.dims, f32_data_type, fmt);
195 auto desc_bf16 = memory::desc(p.dims, bf16_data_type, fmt);
196 if (is_fmt_blocked) {
197 desc_f32.data.format = mkldnn_blocked;
198 desc_bf16.data.format = mkldnn_blocked;
200 auto mpd_f32 = memory::primitive_desc(desc_f32, eng);
201 auto mpd_bf16 = memory::primitive_desc(desc_bf16, eng);
203 auto src_memory_f32 = memory(mpd_f32);
204 auto src_memory_bf16 = memory(mpd_bf16);
206 const size_t sz = src_memory_f32.get_primitive_desc().get_size()
208 fill_data_bf16(sz, src_memory_bf16, src_memory_f32,
209 float(i), 2e-1f * (i + 1));
211 srcs_pd_f32.push_back(mpd_f32);
212 srcs_pd_bf16.push_back(mpd_bf16);
213 srcs_f32.push_back(src_memory_f32);
214 srcs_bf16.push_back(src_memory_bf16);
217 std::shared_ptr<memory> dst;
218 std::shared_ptr<memory> dst_f32;
219 std::shared_ptr<sum::primitive_desc> sum_pd;
221 if (p.is_output_omitted) {
222 ASSERT_NO_THROW(sum_pd.reset(
223 new sum::primitive_desc(p.scale, srcs_pd_bf16)));
225 bool is_fmt_blocked = memory::format::blocked == p.dst_format;
226 auto fmt = is_fmt_blocked ? memory::format::nchw : p.dst_format;
227 auto dst_desc = memory::desc(p.dims, data_type, fmt);
228 if (is_fmt_blocked) dst_desc.data.format = mkldnn_blocked;
230 new sum::primitive_desc(dst_desc, p.scale, srcs_pd_bf16));
232 ASSERT_EQ(sum_pd->dst_primitive_desc().desc().data.format,
233 dst_desc.data.format);
234 ASSERT_EQ(sum_pd->dst_primitive_desc().desc().data.ndims,
235 dst_desc.data.ndims);
237 ASSERT_NO_THROW(dst.reset(new memory(sum_pd->dst_primitive_desc())));
238 // Check automatically created dst descriptor data type
240 dst->get_primitive_desc().desc().data.data_type == data_type);
242 dst_data_t *dst_data = (dst_data_t *)dst->get_data_handle();
244 dst->get_primitive_desc().get_size() / sizeof(dst_data_t);
245 // overwriting dst to prevent false positives for test cases.
246 mkldnn::impl::parallel_nd((ptrdiff_t)sz,
250 if (data_type == bf16_data_type)
251 dst_data[i] = t.i[1];
253 dst_data[i] = (dst_data_t)-32;
256 std::vector<primitive::at> inputs;
257 for (size_t i = 0; i < num_srcs; i++) {
258 inputs.push_back(srcs_bf16[i]);
260 auto c = sum(*sum_pd, inputs, *dst);
261 std::vector<primitive> pipeline;
262 pipeline.push_back(c);
263 auto s = stream(stream::kind::eager);
264 s.submit(pipeline).wait();
266 bool is_bf16_dst = data_type == bf16_data_type;
268 bool is_fmt_blocked = memory::format::blocked == p.dst_format;
269 auto fmt = is_fmt_blocked ? memory::format::nchw : p.dst_format;
270 auto dst_desc_f32 = memory::desc(p.dims, f32_data_type, fmt);
271 if (is_fmt_blocked) dst_desc_f32.data.format = mkldnn_blocked;
273 auto dst_mpd_f32 = memory::primitive_desc(dst_desc_f32, eng);
274 ASSERT_NO_THROW(dst_f32.reset(new memory(dst_mpd_f32)));
275 cvt_bf16_to_ps((float *)dst_f32->get_data_handle(),
276 (mkldnn_bfloat16_t *)dst->get_data_handle(),
282 const double eps = is_bf16_dst ? 1e-2 : 1e-7;
283 check_data<float, float>(srcs_f32, p.scale, *dst_f32, eps);
// Builds a two-source sum_test_params for the corner-case suite: both sources
// get a unit scale, the dst descriptor is passed explicitly (output not
// omitted), and the caller supplies the expected-failure flag and status.
// `shape` must be a parenthesized braced list so its commas survive macro
// expansion, e.g. CASE_CC(nchw, nChw8c, nchw, ({0, 7, 4, 4}), false, mkldnn_success).
#define CASE_CC(src0_fmt, src1_fmt, dst_fmt, shape, expect_fail, status) \
    sum_test_params{ \
        engine::kind::cpu, \
        {memory::format::src0_fmt, memory::format::src1_fmt}, \
        memory::format::dst_fmt, \
        memory::dims shape, \
        {1.0f, 1.0f}, \
        false, \
        expect_fail, \
        status}
293 #define INST_TEST_CASE(test, omit_output) \
294 TEST_P(test, TestsSum) {} \
295 INSTANTIATE_TEST_CASE_P(TestSum, test, ::testing::Values( \
296 sum_test_params{engine::kind::cpu, \
297 {memory::format::blocked, memory::format::blocked}, memory::format::blocked, \
298 {2, 8, 4, 4}, {1.0f, 1.0f}, omit_output}, \
299 sum_test_params{engine::kind::cpu, \
300 {memory::format::nchw, memory::format::blocked}, memory::format::blocked, \
301 {2, 8, 4, 4}, {1.0f, 1.0f}, omit_output}, \
302 sum_test_params{engine::kind::cpu, \
303 {memory::format::blocked, memory::format::nchw}, memory::format::blocked, \
304 {2, 8, 4, 4}, {1.0f, 1.0f}, omit_output}, \
305 sum_test_params{engine::kind::cpu, \
306 {memory::format::nchw, memory::format::nchw}, memory::format::blocked, \
307 {2, 8, 4, 4}, {1.0f, 1.0f}, omit_output}, \
308 sum_test_params{engine::kind::cpu, \
309 {memory::format::nchw, memory::format::nChw8c}, memory::format::nchw, \
310 {0, 7, 4, 4}, {1.0f, 1.0f}, omit_output}, \
311 sum_test_params{engine::kind::cpu, \
312 {memory::format::nchw, memory::format::nChw8c}, memory::format::nchw, \
313 {1, 0, 4, 4}, {1.0f, 1.0f}, omit_output}, \
314 sum_test_params{engine::kind::cpu, \
315 {memory::format::nchw, memory::format::nChw8c}, memory::format::nchw, \
316 {1, 8, 0, 4}, {1.0f, 1.0f}, omit_output}, \
317 sum_test_params{engine::kind::cpu, \
318 {memory::format::nchw, memory::format::nChw8c}, memory::format::nchw, \
319 {-1, 8, 4, 4}, {1.0f, 1.0f}, omit_output, true, mkldnn_invalid_arguments}, \
321 sum_test_params{engine::kind::cpu, \
322 {memory::format::nchw, memory::format::nChw8c}, memory::format::nchw, \
323 {1, 1024, 38, 50}, {1.0f, 1.0f}, omit_output}, \
324 sum_test_params{engine::kind::cpu, \
325 {memory::format::nchw, memory::format::nchw}, memory::format::nchw, \
326 {2, 8, 2, 2}, {1.0f, 1.0f}, omit_output}, \
327 sum_test_params{engine::kind::cpu, \
328 {memory::format::nChw8c, memory::format::nChw8c}, memory::format::nChw8c, \
329 {2, 16, 3, 4}, {1.0f, 1.0f}, omit_output}, \
330 sum_test_params{engine::kind::cpu, \
331 {memory::format::nchw, memory::format::nchw}, memory::format::nChw8c, \
332 {2, 16, 2, 2}, {1.0f, 1.0f}, omit_output}, \
333 sum_test_params{engine::kind::cpu, \
334 {memory::format::nChw8c, memory::format::nChw8c}, memory::format::nchw, \
335 {2, 16, 3, 4}, {1.0f, 1.0f}, omit_output}, \
336 sum_test_params{engine::kind::cpu, \
337 {memory::format::nchw, memory::format::nchw}, memory::format::nchw, \
338 {2, 8, 2, 2}, {2.0f, 3.0f}, omit_output}, \
339 sum_test_params{engine::kind::cpu, \
340 {memory::format::nChw8c, memory::format::nChw8c}, memory::format::nChw8c,\
341 {2, 16, 3, 4}, {2.0f, 3.0f}, omit_output}, \
342 sum_test_params{engine::kind::cpu, \
343 {memory::format::nchw, memory::format::nchw}, memory::format::nChw8c, \
344 {2, 16, 2, 2}, {2.0f, 3.0f}, omit_output}, \
345 sum_test_params{engine::kind::cpu, \
346 {memory::format::nChw8c, memory::format::nChw8c}, memory::format::nchw, \
347 {2, 16, 3, 4}, {2.0f, 3.0f}, omit_output}, \
348 sum_test_params{engine::kind::cpu, \
349 {memory::format::nchw, memory::format::nChw8c}, memory::format::nchw, \
350 {5, 8, 3, 3}, {2.0f, 3.0f}, omit_output}, \
351 sum_test_params{engine::kind::cpu, \
352 {memory::format::nchw, memory::format::nChw8c}, memory::format::nchw, \
353 {32, 32, 13, 14}, {2.0f, 3.0f}, omit_output}, \
354 sum_test_params{engine::kind::cpu, \
355 {memory::format::nChw16c, memory::format::nChw8c}, \
356 memory::format::nChw16c, \
357 {2, 16, 3, 3}, {2.0f, 3.0f}, omit_output} \
360 INSTANTIATE_TEST_CASE_P(TestSumEF, test, ::testing::Values( \
361 sum_test_params{engine::kind::cpu, \
362 {memory::format::nchw, memory::format::nChw8c}, memory::format::nchw, \
363 {1, 8, 4 ,4}, {1.0f}, 0, true, mkldnn_invalid_arguments}, \
364 sum_test_params{engine::kind::cpu, \
365 {memory::format::nchw, memory::format::nChw8c}, memory::format::nchw, \
366 {2, 8, 4 ,4}, {0.1f}, 0, true, mkldnn_invalid_arguments} \
369 using sum_test_float_omit_output = sum_test<float,float>;
370 using sum_test_u8_omit_output = sum_test<uint8_t,float>;
371 using sum_test_s8_omit_output = sum_test<int8_t,float>;
372 using sum_test_s32_omit_output = sum_test<int32_t,float>;
374 using sum_test_float = sum_test<float,float>;
375 using sum_test_u8 = sum_test<uint8_t,float>;
376 using sum_test_s8 = sum_test<int8_t,float>;
377 using sum_test_s32 = sum_test<int32_t,float>;
379 using sum_cc_f32 = sum_test<float,float>;
380 TEST_P(sum_cc_f32, TestSumCornerCases) {}
381 INSTANTIATE_TEST_CASE_P(TestSumCornerCases, sum_cc_f32, ::testing::Values(
382 CASE_CC(nchw, nChw8c, nchw, ({0, 7, 4, 4}), false, mkldnn_success),
383 CASE_CC(nchw, nChw8c, nchw, ({1, 0, 4, 4}), false, mkldnn_success),
384 CASE_CC(nchw, nChw8c, nchw, ({1, 8, 0, 4}), false, mkldnn_success),
385 CASE_CC(nchw, nChw8c, nchw, ({-1, 8, 4, 4}), true, mkldnn_invalid_arguments)
389 INST_TEST_CASE(sum_test_float_omit_output, 1)
390 INST_TEST_CASE(sum_test_u8_omit_output, 1)
391 INST_TEST_CASE(sum_test_s8_omit_output, 1)
392 INST_TEST_CASE(sum_test_s32_omit_output, 1)
394 INST_TEST_CASE(sum_test_float, 0)
395 INST_TEST_CASE(sum_test_u8, 0)
396 INST_TEST_CASE(sum_test_s8, 0)
397 INST_TEST_CASE(sum_test_s32, 0)
399 using sum_test_bf16f32 = sum_test_bf16<float>;
400 using sum_test_bf16bf16 =
401 sum_test_bf16<prec_traits<mkldnn::memory::data_type::bf16>::type>;
402 using sum_test_bf16f32_omit_output = sum_test_bf16<float>;
403 using sum_test_bf16bf16_omit_output =
404 sum_test_bf16<prec_traits<mkldnn::memory::data_type::bf16>::type>;
406 #define INST_TEST_CASE_BF16(test, omit_output) \
407 TEST_P(test, TestsSum) {} \
408 INSTANTIATE_TEST_CASE_P(TestSum, test, ::testing::Values( \
409 sum_test_params{engine::kind::cpu, \
410 {memory::format::nChw16c, memory::format::nChw16c}, \
411 memory::format::nChw16c, \
412 {1, 16, 1, 1}, {2.0f, 3.0f}, omit_output}, \
413 sum_test_params{engine::kind::cpu, \
414 {memory::format::nchw, memory::format::nchw}, \
415 memory::format::nchw, \
416 {1, 16, 1, 1}, {2.0f, 3.0f}, omit_output}, \
417 sum_test_params{engine::kind::cpu, \
418 {memory::format::nchw, memory::format::nchw}, \
419 memory::format::nchw, \
420 {2, 16, 13, 7}, {2.0f, 3.0f}, omit_output}, \
421 sum_test_params{engine::kind::cpu, \
422 {memory::format::nchw, memory::format::nchw, \
423 memory::format::nchw, memory::format::nchw}, \
424 memory::format::nchw, \
425 {2, 16, 13, 7}, {2.0f, 3.0f, 4.0f, 5.0f}, omit_output}, \
426 sum_test_params{engine::kind::cpu, \
427 {memory::format::nchw, memory::format::nchw, memory::format::nchw}, \
428 memory::format::nchw, \
429 {2, 16, 13, 7}, {2.0f, 3.0f, 4.0f}, omit_output}, \
430 sum_test_params{engine::kind::cpu, \
431 {memory::format::nchw, memory::format::nchw, memory::format::nchw, \
432 memory::format::nchw, memory::format::nchw}, \
433 memory::format::nchw, \
434 {2, 16, 13, 7}, {2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, omit_output}, \
435 sum_test_params{engine::kind::cpu, \
436 {memory::format::nchw, memory::format::nchw, memory::format::nchw}, \
437 memory::format::nchw, \
438 {2, 37, 13, 7}, {2.0f, 3.0f, 4.0f}, omit_output}, \
439 sum_test_params{engine::kind::cpu, \
440 {memory::format::nchw, memory::format::nchw, memory::format::nchw}, \
441 memory::format::nchw, \
442 {2, 16, 13, 7}, {2.0f, 3.0f, 4.0f}, omit_output}, \
443 sum_test_params{engine::kind::cpu, \
444 {memory::format::nChw16c, memory::format::nChw16c}, \
445 memory::format::nChw16c, \
446 {2, 16, 13, 7}, {2.0f, 3.0f}, omit_output}, \
447 sum_test_params{engine::kind::cpu, \
448 {memory::format::nChw16c, memory::format::nChw16c, \
449 memory::format::nChw16c}, \
450 memory::format::nChw16c, \
451 {2, 16, 13, 7}, {2.0f, 3.0f, 4.0f}, omit_output}, \
452 sum_test_params{engine::kind::cpu, \
453 {memory::format::nChw16c, memory::format::nChw16c, \
454 memory::format::nChw16c, memory::format::nChw16c, \
455 memory::format::nChw16c}, \
456 memory::format::nChw16c, \
457 {2, 16, 13, 7}, {2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, omit_output}, \
458 sum_test_params{engine::kind::cpu, \
459 {memory::format::nChw16c, memory::format::nChw16c}, \
460 memory::format::nChw16c, \
461 {2, 128, 23, 15}, {2.5f, 0.05f}, omit_output} \
464 // TODO: merge with INST_TEST_CASE
465 INST_TEST_CASE_BF16(sum_test_bf16f32, 0)
466 INST_TEST_CASE_BF16(sum_test_bf16bf16, 0)
468 // Automatically created dst descriptor has bf16 data type
469 // so this test is not valid
470 //INST_TEST_CASE_BF16(sum_test_bf16f32_omit_output, 1)
471 INST_TEST_CASE_BF16(sum_test_bf16bf16_omit_output, 1)
473 #undef INST_TEST_CASE_BF16
474 #undef INST_TEST_CASE