1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "mkldnn_test_common.hpp"
18 #include "gtest/gtest.h"
// Parameter bundle describing one value-parameterized sum test case.
// NOTE(review): the fixture below also reads p.dims and p.expect_to_fail, and
// the initializer lists in this file pass a dims list between dst_format and
// scale; those members and the closing brace are not visible in this listing
// -- confirm against the complete file.
24 struct sum_test_params {
// Engine kind to create; Test() asserts that it equals engine::kind::cpu.
25 const engine::kind engine_kind;
// One memory format per source; its size is used as the number of sources.
26 std::vector<memory::format> srcs_format;
// Format used to build the destination descriptor when the output is not omitted.
27 memory::format dst_format;
// Scale factors handed to sum::primitive_desc and indexed per source in check_data().
29 std::vector<float> scale;
// When true, the sum primitive descriptor is created without a destination descriptor.
30 bool is_output_omitted;
// Status associated with the case; presumably forwarded to
// catch_expected_failures() -- that call is truncated in this listing, TODO confirm.
32 mkldnn_status_t expected_status;
// Value-parameterized gtest fixture that runs the sum primitive and compares
// its output with a reference computed on the host.
// data_t is the element type used to read source and destination buffers;
// acc_t is the type whose numeric limits clamp the reference sum.
// NOTE(review): this listing omits some lines of the class (declarations,
// branch conditions, closing braces); the comments describe only the code
// that is visible.
36 template <typename data_t, typename acc_t>
37 class sum_test: public ::testing::TestWithParam<sum_test_params> {
// check_data: for every (n, c, h, w) position of dst, recomputes the scaled
// sum over all sources and checks it against the value stored in dst.
//   srcs  - source memories that were fed to the primitive
//   scale - per-source scale factors (taken by value)
// The third parameter (dst, used below) is declared on a line not shown here.
38 void check_data(const std::vector<memory> &srcs,
39 const std::vector<float> scale,
42 const data_t *dst_data = (const data_t *)dst.get_data_handle();
43 const auto &dst_d = dst.get_primitive_desc().desc();
44 const auto dst_dims = dst_d.data.dims;
// Iterate over the four destination dimensions; the lambda captures by reference.
46 mkldnn::impl::parallel_nd(dst_dims[0], dst_dims[1], dst_dims[2], dst_dims[3],
47 [&](int n, int c, int h, int w) {
49 for (size_t num = 0; num < srcs.size(); num++) {
50 const data_t *src_data =
51 (const data_t *)srcs[num].get_data_handle();
52 const auto &src_d = srcs[num].get_primitive_desc().desc();
53 const auto src_dims = src_d.data.dims;
// Visible tail of the source offset computation: the c and n contributions,
// each multiplied by the product of the trailing dims. The start of the
// expression (presumably the w and h terms) is not shown -- TODO confirm.
57 + src_dims[2]*src_dims[3]*c
58 + src_dims[1]*src_dims[2]*src_dims[3]*n;
// The offset is translated through map_index() with the source descriptor
// before indexing. The float scale is converted to data_t before multiplying.
// One statement assigns and the other accumulates; the condition choosing
// between them is not visible here.
60 src_sum = data_t(scale[num]) * src_data[map_index(src_d, src_idx)];
62 src_sum += data_t(scale[num])* src_data[map_index(src_d, src_idx)];
// Clamp the reference value to the [lowest, max] range of acc_t.
65 src_sum = std::max(std::min(src_sum,
66 std::numeric_limits<acc_t>::max()),
67 std::numeric_limits<acc_t>::lowest());
// Visible tail of the destination offset computation, same shape as for the sources.
73 + dst_dims[2]*dst_dims[3]*c
74 + dst_dims[1]*dst_dims[2]*dst_dims[3]*n;
75 auto diff = src_sum - dst_data[map_index(dst_d, dst_idx)];
// Relative error when |src_sum| exceeds 1e-4, otherwise the plain difference.
76 auto e = (std::abs(src_sum) > 1e-4) ? diff / src_sum : diff;
77 EXPECT_NEAR(e, 0.0, 1.2e-7);
// SetUp: fetches the test parameters and runs Test() through
// catch_expected_failures(), passing p.expect_to_fail. The remaining
// arguments of that call are not visible in this listing.
83 virtual void SetUp() {
85 = ::testing::TestWithParam<sum_test_params>::GetParam();
86 catch_expected_failures([=](){Test();}, p.expect_to_fail,
// Visible body of the test routine: build sources, create the sum primitive
// descriptor and destination, execute on an eager stream, then verify.
92 = ::testing::TestWithParam<sum_test_params>::GetParam();
94 const auto num_srcs = p.srcs_format.size();
96 ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
97 auto eng = engine(p.engine_kind, 0);
98 memory::data_type data_type = data_traits<data_t>::data_type;
100 std::vector<memory::primitive_desc> srcs_pd;
101 std::vector<memory> srcs;
// Create and fill one memory per source format.
103 for (size_t i = 0; i < num_srcs; i++) {
// For a "blocked" source the descriptor is built with nchw and its
// data.format field is then overwritten with mkldnn_blocked.
104 bool is_fmt_blocked = p.srcs_format[i] == memory::format::blocked;
105 auto desc = memory::desc(p.dims, data_type, is_fmt_blocked
106 ? memory::format::nchw
108 if (is_fmt_blocked) desc.data.format = mkldnn_blocked;
109 auto mpd = memory::primitive_desc(desc, eng);
110 auto src_memory = memory(mpd);
// Element count derived from the byte size reported by the primitive descriptor.
112 src_memory.get_primitive_desc().get_size() / sizeof(data_t);
113 fill_data<data_t>(sz, (data_t *)src_memory.get_data_handle());
114 srcs_pd.push_back(mpd);
115 srcs.push_back(src_memory);
118 std::shared_ptr<memory> dst;
119 std::shared_ptr<sum::primitive_desc> sum_pd;
// With the output omitted, the primitive descriptor is built from scales and
// sources only; otherwise an explicit destination descriptor is passed too.
121 if (p.is_output_omitted) {
122 ASSERT_NO_THROW(sum_pd.reset(
123 new sum::primitive_desc(p.scale, srcs_pd)));
// Same nchw-then-overwrite handling of the "blocked" format as for the sources.
125 bool is_fmt_blocked = p.dst_format == memory::format::blocked;
126 auto dst_desc = memory::desc(p.dims, data_type, is_fmt_blocked
127 ? memory::format::nchw
129 if (is_fmt_blocked) dst_desc.data.format = mkldnn_blocked;
131 new sum::primitive_desc(dst_desc, p.scale, srcs_pd));
// The destination descriptor reported by sum_pd must match the requested
// format and number of dimensions.
133 ASSERT_EQ(sum_pd->dst_primitive_desc().desc().data.format,
134 dst_desc.data.format);
135 ASSERT_EQ(sum_pd->dst_primitive_desc().desc().data.ndims,
136 dst_desc.data.ndims);
// Allocate the destination from the descriptor chosen by sum_pd.
138 ASSERT_NO_THROW(dst.reset(new memory(sum_pd->dst_primitive_desc())));
140 data_t *dst_data = (data_t *)dst->get_data_handle();
142 dst->get_primitive_desc().get_size() / sizeof(data_t);
143 // overwriting dst to prevent false positives for test cases.
144 mkldnn::impl::parallel_nd((ptrdiff_t)sz,
145 [&](ptrdiff_t i) { dst_data[i] = -32; }
// Gather the sources as primitive inputs.
148 std::vector<primitive::at> inputs;
149 for (size_t i = 0; i < num_srcs; i++) {
150 inputs.push_back(srcs[i]);
// Build the sum primitive, submit it to an eager stream, and wait for completion.
152 auto c = sum(*sum_pd, inputs, *dst);
153 std::vector<primitive> pipeline;
154 pipeline.push_back(c);
155 auto s = stream(stream::kind::eager);
156 s.submit(pipeline).wait();
// Compare the primitive's output with the host-side reference.
158 check_data(srcs, p.scale, *dst);
// CASE_CC(ifmt0, ifmt1, ofmt, dims_, ef, st): expands to a sum_test_params
// initializer for a two-source CPU case with both scales equal to 1.0f and a
// literal 0 in the is_output_omitted position.
//   ifmt0, ifmt1, ofmt - memory::format enumerator names (source 0, source 1, destination)
//   dims_              - parenthesized brace list appended to `memory::dims`
//   ef, st             - trailing fields; st lands in expected_status, ef in the
//                        field before it (presumably expect_to_fail -- that
//                        member is not visible in this listing)
163 #define CASE_CC(ifmt0, ifmt1, ofmt, dims_, ef, st) \
164 sum_test_params{engine::kind::cpu, \
165 {memory::format::ifmt0, memory::format::ifmt1}, memory::format::ofmt, \
166 memory::dims dims_, {1.0f, 1.0f}, 0, ef, st}
// INST_TEST_CASE(test, omit_output): declares an empty TEST_P body named
// TestsSum for fixture `test` (the work is started from the fixture's SetUp())
// and instantiates it with two parameter lists:
//   TestSum   - two-source CPU cases over several format/dims/scale
//               combinations; omit_output is placed in the is_output_omitted
//               position. One entry uses a -1 dimension and is marked with
//               `true, mkldnn_invalid_arguments`; the others leave the trailing
//               fields to their default initialization.
//   TestSumEF - cases that supply a single scale for two sources, with 0 in
//               the is_output_omitted position, marked with
//               `true, mkldnn_invalid_arguments`.
// NOTE(review): the lines terminating each INSTANTIATE_TEST_CASE_P call are
// not visible in this listing -- confirm against the complete file.
168 #define INST_TEST_CASE(test, omit_output) \
169 TEST_P(test, TestsSum) {} \
170 INSTANTIATE_TEST_CASE_P(TestSum, test, ::testing::Values( \
171 sum_test_params{engine::kind::cpu, \
172 {memory::format::blocked, memory::format::blocked}, memory::format::blocked, \
173 {2, 8, 4, 4}, {1.0f, 1.0f}, omit_output}, \
174 sum_test_params{engine::kind::cpu, \
175 {memory::format::nchw, memory::format::blocked}, memory::format::blocked, \
176 {2, 8, 4, 4}, {1.0f, 1.0f}, omit_output}, \
177 sum_test_params{engine::kind::cpu, \
178 {memory::format::blocked, memory::format::nchw}, memory::format::blocked, \
179 {2, 8, 4, 4}, {1.0f, 1.0f}, omit_output}, \
180 sum_test_params{engine::kind::cpu, \
181 {memory::format::nchw, memory::format::nchw}, memory::format::blocked, \
182 {2, 8, 4, 4}, {1.0f, 1.0f}, omit_output}, \
183 sum_test_params{engine::kind::cpu, \
184 {memory::format::nchw, memory::format::nChw8c}, memory::format::nchw, \
185 {0, 7, 4, 4}, {1.0f, 1.0f}, omit_output}, \
186 sum_test_params{engine::kind::cpu, \
187 {memory::format::nchw, memory::format::nChw8c}, memory::format::nchw, \
188 {1, 0, 4, 4}, {1.0f, 1.0f}, omit_output}, \
189 sum_test_params{engine::kind::cpu, \
190 {memory::format::nchw, memory::format::nChw8c}, memory::format::nchw, \
191 {1, 8, 0, 4}, {1.0f, 1.0f}, omit_output}, \
192 sum_test_params{engine::kind::cpu, \
193 {memory::format::nchw, memory::format::nChw8c}, memory::format::nchw, \
194 {-1, 8, 4, 4}, {1.0f, 1.0f}, omit_output, true, mkldnn_invalid_arguments}, \
196 sum_test_params{engine::kind::cpu, \
197 {memory::format::nchw, memory::format::nChw8c}, memory::format::nchw, \
198 {1, 1024, 38, 50}, {1.0f, 1.0f}, omit_output}, \
199 sum_test_params{engine::kind::cpu, \
200 {memory::format::nchw, memory::format::nchw}, memory::format::nchw, \
201 {2, 8, 2, 2}, {1.0f, 1.0f}, omit_output}, \
202 sum_test_params{engine::kind::cpu, \
203 {memory::format::nChw8c, memory::format::nChw8c}, memory::format::nChw8c, \
204 {2, 16, 3, 4}, {1.0f, 1.0f}, omit_output}, \
205 sum_test_params{engine::kind::cpu, \
206 {memory::format::nchw, memory::format::nchw}, memory::format::nChw8c, \
207 {2, 16, 2, 2}, {1.0f, 1.0f}, omit_output}, \
208 sum_test_params{engine::kind::cpu, \
209 {memory::format::nChw8c, memory::format::nChw8c}, memory::format::nchw, \
210 {2, 16, 3, 4}, {1.0f, 1.0f}, omit_output}, \
211 sum_test_params{engine::kind::cpu, \
212 {memory::format::nchw, memory::format::nchw}, memory::format::nchw, \
213 {2, 8, 2, 2}, {2.0f, 3.0f}, omit_output}, \
214 sum_test_params{engine::kind::cpu, \
215 {memory::format::nChw8c, memory::format::nChw8c}, memory::format::nChw8c,\
216 {2, 16, 3, 4}, {2.0f, 3.0f}, omit_output}, \
217 sum_test_params{engine::kind::cpu, \
218 {memory::format::nchw, memory::format::nchw}, memory::format::nChw8c, \
219 {2, 16, 2, 2}, {2.0f, 3.0f}, omit_output}, \
220 sum_test_params{engine::kind::cpu, \
221 {memory::format::nChw8c, memory::format::nChw8c}, memory::format::nchw, \
222 {2, 16, 3, 4}, {2.0f, 3.0f}, omit_output}, \
223 sum_test_params{engine::kind::cpu, \
224 {memory::format::nchw, memory::format::nChw8c}, memory::format::nchw, \
225 {5, 8, 3, 3}, {2.0f, 3.0f}, omit_output}, \
226 sum_test_params{engine::kind::cpu, \
227 {memory::format::nchw, memory::format::nChw8c}, memory::format::nchw, \
228 {32, 32, 13, 14}, {2.0f, 3.0f}, omit_output}, \
229 sum_test_params{engine::kind::cpu, \
230 {memory::format::nChw16c, memory::format::nChw8c}, \
231 memory::format::nChw16c, \
232 {2, 16, 3, 3}, {2.0f, 3.0f}, omit_output} \
235 INSTANTIATE_TEST_CASE_P(TestSumEF, test, ::testing::Values( \
236 sum_test_params{engine::kind::cpu, \
237 {memory::format::nchw, memory::format::nChw8c}, memory::format::nchw, \
238 {1, 8, 4 ,4}, {1.0f}, 0, true, mkldnn_invalid_arguments}, \
239 sum_test_params{engine::kind::cpu, \
240 {memory::format::nchw, memory::format::nChw8c}, memory::format::nchw, \
241 {2, 8, 4 ,4}, {0.1f}, 0, true, mkldnn_invalid_arguments} \
// Fixture aliases, one per element type (float, uint8_t, int8_t, int32_t),
// all using float as acc_t. The *_omit_output names alias the same template
// instantiations as the plain names; the separate names let INST_TEST_CASE be
// applied once with omit_output = 1 and once with omit_output = 0.
244 using sum_test_float_omit_output = sum_test<float,float>;
245 using sum_test_u8_omit_output = sum_test<uint8_t,float>;
246 using sum_test_s8_omit_output = sum_test<int8_t,float>;
247 using sum_test_s32_omit_output = sum_test<int32_t,float>;
249 using sum_test_float = sum_test<float,float>;
250 using sum_test_u8 = sum_test<uint8_t,float>;
251 using sum_test_s8 = sum_test<int8_t,float>;
252 using sum_test_s32 = sum_test<int32_t,float>;
// Corner cases for the float fixture, built with CASE_CC (two sources,
// nchw + nChw8c into nchw, unit scales). Three cases contain a zero-sized
// dimension and are passed `false, mkldnn_success`; the case with a -1
// dimension is passed `true, mkldnn_invalid_arguments`.
// The TEST_P body is empty because the fixture's SetUp() runs the test.
// NOTE(review): the line terminating INSTANTIATE_TEST_CASE_P is not visible
// in this listing.
254 using sum_cc_f32 = sum_test<float,float>;
255 TEST_P(sum_cc_f32, TestSumCornerCases) {}
256 INSTANTIATE_TEST_CASE_P(TestSumCornerCases, sum_cc_f32, ::testing::Values(
257 CASE_CC(nchw, nChw8c, nchw, ({0, 7, 4, 4}), false, mkldnn_success),
258 CASE_CC(nchw, nChw8c, nchw, ({1, 0, 4, 4}), false, mkldnn_success),
259 CASE_CC(nchw, nChw8c, nchw, ({1, 8, 0, 4}), false, mkldnn_success),
260 CASE_CC(nchw, nChw8c, nchw, ({-1, 8, 4, 4}), true, mkldnn_invalid_arguments)
// Instantiate the shared parameter lists for every element type: first with
// the output descriptor omitted (omit_output = 1), then with an explicit
// destination descriptor (omit_output = 0).
264 INST_TEST_CASE(sum_test_float_omit_output, 1)
265 INST_TEST_CASE(sum_test_u8_omit_output, 1)
266 INST_TEST_CASE(sum_test_s8_omit_output, 1)
267 INST_TEST_CASE(sum_test_s32_omit_output, 1)
269 INST_TEST_CASE(sum_test_float, 0)
270 INST_TEST_CASE(sum_test_u8, 0)
271 INST_TEST_CASE(sum_test_s8, 0)
272 INST_TEST_CASE(sum_test_s32, 0)
// The helper macro is no longer needed past this point.
274 #undef INST_TEST_CASE