Publishing R5 content (#72)
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / gtests / test_sum.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "mkldnn_test_common.hpp"
18 #include "gtest/gtest.h"
19
20 #include "mkldnn.hpp"
21
22 namespace mkldnn {
23
24 struct sum_test_params {
25     const engine::kind engine_kind;
26     std::vector<memory::format> srcs_format;
27     memory::format dst_format;
28     memory::dims dims;
29     std::vector<float> scale;
30     bool is_output_omitted;
31     bool expect_to_fail;
32     mkldnn_status_t expected_status;
33 };
34
35
36 template <typename data_t, typename acc_t>
37 class sum_test: public ::testing::TestWithParam<sum_test_params> {
38     void check_data(const std::vector<memory> &srcs,
39                     const std::vector<float> scale,
40                     const memory &dst)
41     {
42         const data_t *dst_data = (const data_t *)dst.get_data_handle();
43         const auto &dst_d = dst.get_primitive_desc().desc();
44         const auto dst_dims = dst_d.data.dims;
45
46         mkldnn::impl::parallel_nd(dst_dims[0], dst_dims[1], dst_dims[2], dst_dims[3],
47             [&](int n, int c, int h, int w) {
48             acc_t src_sum = 0.0;
49             for (size_t num = 0; num < srcs.size(); num++) {
50                 const data_t *src_data =
51                     (const data_t *)srcs[num].get_data_handle();
52                 const auto &src_d = srcs[num].get_primitive_desc().desc();
53                 const auto src_dims = src_d.data.dims;
54
55                 auto src_idx = w
56                     + src_dims[3]*h
57                     + src_dims[2]*src_dims[3]*c
58                     + src_dims[1]*src_dims[2]*src_dims[3]*n;
59                 if (num == 0) {
60                     src_sum = data_t(scale[num]) * src_data[map_index(src_d, src_idx)];
61                 } else {
62                     src_sum += data_t(scale[num])* src_data[map_index(src_d, src_idx)];
63                 }
64
65                 src_sum = std::max(std::min(src_sum,
66                             std::numeric_limits<acc_t>::max()),
67                         std::numeric_limits<acc_t>::lowest());
68
69             }
70
71             auto dst_idx = w
72                 + dst_dims[3]*h
73                 + dst_dims[2]*dst_dims[3]*c
74                 + dst_dims[1]*dst_dims[2]*dst_dims[3]*n;
75             auto diff = src_sum - dst_data[map_index(dst_d, dst_idx)];
76             auto e = (std::abs(src_sum) > 1e-4) ? diff / src_sum : diff;
77             EXPECT_NEAR(e, 0.0, 1.2e-7);
78             }
79         );
80     }
81
82 protected:
83     virtual void SetUp() {
84         sum_test_params p
85             = ::testing::TestWithParam<sum_test_params>::GetParam();
86         catch_expected_failures([=](){Test();}, p.expect_to_fail,
87                     p.expected_status);
88     }
89
90     void Test() {
91         sum_test_params p
92             = ::testing::TestWithParam<sum_test_params>::GetParam();
93
94         const auto num_srcs = p.srcs_format.size();
95
96         ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
97         auto eng = engine(p.engine_kind, 0);
98         memory::data_type data_type = data_traits<data_t>::data_type;
99
100         std::vector<memory::primitive_desc> srcs_pd;
101         std::vector<memory> srcs;
102
103         for (size_t i = 0; i < num_srcs; i++) {
104             bool is_fmt_blocked = p.srcs_format[i] == memory::format::blocked;
105             auto desc = memory::desc(p.dims, data_type, is_fmt_blocked
106                 ? memory::format::nchw
107                 : p.srcs_format[i]);
108             if (is_fmt_blocked) desc.data.format = mkldnn_blocked;
109             auto mpd = memory::primitive_desc(desc, eng);
110             auto src_memory = memory(mpd);
111             const size_t sz =
112                 src_memory.get_primitive_desc().get_size() / sizeof(data_t);
113             fill_data<data_t>(sz, (data_t *)src_memory.get_data_handle());
114             srcs_pd.push_back(mpd);
115             srcs.push_back(src_memory);
116         }
117
118         std::shared_ptr<memory> dst;
119         std::shared_ptr<sum::primitive_desc> sum_pd;
120
121         if (p.is_output_omitted) {
122             ASSERT_NO_THROW(sum_pd.reset(
123                 new sum::primitive_desc(p.scale, srcs_pd)));
124         } else {
125             bool is_fmt_blocked = p.dst_format == memory::format::blocked;
126             auto dst_desc = memory::desc(p.dims, data_type, is_fmt_blocked
127                 ? memory::format::nchw
128                 : p.dst_format);
129             if (is_fmt_blocked) dst_desc.data.format = mkldnn_blocked;
130             sum_pd.reset(
131                 new sum::primitive_desc(dst_desc, p.scale, srcs_pd));
132
133             ASSERT_EQ(sum_pd->dst_primitive_desc().desc().data.format,
134                     dst_desc.data.format);
135             ASSERT_EQ(sum_pd->dst_primitive_desc().desc().data.ndims,
136                     dst_desc.data.ndims);
137         }
138         ASSERT_NO_THROW(dst.reset(new memory(sum_pd->dst_primitive_desc())));
139
140         data_t *dst_data = (data_t *)dst->get_data_handle();
141         const size_t sz =
142             dst->get_primitive_desc().get_size() / sizeof(data_t);
143         // overwriting dst to prevent false positives for test cases.
144         mkldnn::impl::parallel_nd((ptrdiff_t)sz,
145             [&](ptrdiff_t i) { dst_data[i] = -32; }
146         );
147
148         std::vector<primitive::at> inputs;
149         for (size_t i = 0; i < num_srcs; i++) {
150             inputs.push_back(srcs[i]);
151         }
152         auto c = sum(*sum_pd, inputs, *dst);
153         std::vector<primitive> pipeline;
154         pipeline.push_back(c);
155         auto s = stream(stream::kind::eager);
156         s.submit(pipeline).wait();
157
158         check_data(srcs, p.scale, *dst);
159     }
160 };
161
162 /* corner cases */
163 #define CASE_CC(ifmt0, ifmt1, ofmt, dims_, ef, st) \
164     sum_test_params{engine::kind::cpu, \
165         {memory::format::ifmt0, memory::format::ifmt1}, memory::format::ofmt, \
166         memory::dims dims_, {1.0f, 1.0f}, 0, ef, st}
167
168 #define INST_TEST_CASE(test, omit_output) \
169 TEST_P(test, TestsSum) {} \
170 INSTANTIATE_TEST_CASE_P(TestSum, test, ::testing::Values( \
171     sum_test_params{engine::kind::cpu, \
172     {memory::format::blocked, memory::format::blocked}, memory::format::blocked, \
173     {2, 8, 4, 4}, {1.0f, 1.0f}, omit_output}, \
174     sum_test_params{engine::kind::cpu, \
175     {memory::format::nchw, memory::format::blocked}, memory::format::blocked, \
176     {2, 8, 4, 4}, {1.0f, 1.0f}, omit_output}, \
177     sum_test_params{engine::kind::cpu, \
178     {memory::format::blocked, memory::format::nchw}, memory::format::blocked, \
179     {2, 8, 4, 4}, {1.0f, 1.0f}, omit_output}, \
180     sum_test_params{engine::kind::cpu, \
181     {memory::format::nchw, memory::format::nchw}, memory::format::blocked, \
182     {2, 8, 4, 4}, {1.0f, 1.0f}, omit_output}, \
183     sum_test_params{engine::kind::cpu, \
184     {memory::format::nchw, memory::format::nChw8c}, memory::format::nchw, \
185     {0, 7, 4, 4}, {1.0f, 1.0f}, omit_output}, \
186     sum_test_params{engine::kind::cpu, \
187     {memory::format::nchw, memory::format::nChw8c}, memory::format::nchw, \
188     {1, 0, 4, 4}, {1.0f, 1.0f}, omit_output}, \
189     sum_test_params{engine::kind::cpu, \
190     {memory::format::nchw, memory::format::nChw8c}, memory::format::nchw, \
191     {1, 8, 0, 4}, {1.0f, 1.0f}, omit_output}, \
192     sum_test_params{engine::kind::cpu, \
193     {memory::format::nchw, memory::format::nChw8c}, memory::format::nchw, \
194     {-1, 8, 4, 4}, {1.0f, 1.0f}, omit_output, true, mkldnn_invalid_arguments}, \
195     \
196     sum_test_params{engine::kind::cpu, \
197     {memory::format::nchw, memory::format::nChw8c}, memory::format::nchw, \
198     {1, 1024, 38, 50}, {1.0f, 1.0f}, omit_output}, \
199     sum_test_params{engine::kind::cpu, \
200     {memory::format::nchw, memory::format::nchw}, memory::format::nchw, \
201     {2, 8, 2, 2}, {1.0f, 1.0f}, omit_output}, \
202     sum_test_params{engine::kind::cpu, \
203     {memory::format::nChw8c, memory::format::nChw8c}, memory::format::nChw8c, \
204     {2, 16, 3, 4}, {1.0f, 1.0f}, omit_output}, \
205     sum_test_params{engine::kind::cpu, \
206     {memory::format::nchw, memory::format::nchw}, memory::format::nChw8c, \
207     {2, 16, 2, 2}, {1.0f, 1.0f}, omit_output}, \
208     sum_test_params{engine::kind::cpu, \
209     {memory::format::nChw8c, memory::format::nChw8c}, memory::format::nchw, \
210     {2, 16, 3, 4}, {1.0f, 1.0f}, omit_output}, \
211     sum_test_params{engine::kind::cpu, \
212     {memory::format::nchw, memory::format::nchw}, memory::format::nchw, \
213     {2, 8, 2, 2}, {2.0f, 3.0f}, omit_output}, \
214     sum_test_params{engine::kind::cpu, \
215     {memory::format::nChw8c, memory::format::nChw8c}, memory::format::nChw8c,\
216     {2, 16, 3, 4}, {2.0f, 3.0f}, omit_output}, \
217     sum_test_params{engine::kind::cpu, \
218     {memory::format::nchw, memory::format::nchw}, memory::format::nChw8c, \
219     {2, 16, 2, 2}, {2.0f, 3.0f}, omit_output}, \
220     sum_test_params{engine::kind::cpu, \
221     {memory::format::nChw8c, memory::format::nChw8c}, memory::format::nchw, \
222     {2, 16, 3, 4}, {2.0f, 3.0f}, omit_output}, \
223     sum_test_params{engine::kind::cpu, \
224     {memory::format::nchw, memory::format::nChw8c}, memory::format::nchw, \
225     {5, 8, 3, 3}, {2.0f, 3.0f}, omit_output}, \
226     sum_test_params{engine::kind::cpu, \
227     {memory::format::nchw, memory::format::nChw8c}, memory::format::nchw, \
228     {32, 32, 13, 14}, {2.0f, 3.0f}, omit_output}, \
229     sum_test_params{engine::kind::cpu, \
230     {memory::format::nChw16c, memory::format::nChw8c}, \
231     memory::format::nChw16c, \
232     {2, 16, 3, 3}, {2.0f, 3.0f}, omit_output} \
233 )); \
234 \
235 INSTANTIATE_TEST_CASE_P(TestSumEF, test, ::testing::Values( \
236     sum_test_params{engine::kind::cpu, \
237     {memory::format::nchw, memory::format::nChw8c}, memory::format::nchw, \
238     {1, 8, 4 ,4}, {1.0f}, 0, true, mkldnn_invalid_arguments}, \
239     sum_test_params{engine::kind::cpu, \
240     {memory::format::nchw, memory::format::nChw8c}, memory::format::nchw, \
241     {2, 8, 4 ,4}, {0.1f}, 0, true, mkldnn_invalid_arguments} \
242 ));
243
244 using sum_test_float_omit_output = sum_test<float,float>;
245 using sum_test_u8_omit_output = sum_test<uint8_t,float>;
246 using sum_test_s8_omit_output = sum_test<int8_t,float>;
247 using sum_test_s32_omit_output = sum_test<int32_t,float>;
248
249 using sum_test_float = sum_test<float,float>;
250 using sum_test_u8 = sum_test<uint8_t,float>;
251 using sum_test_s8 = sum_test<int8_t,float>;
252 using sum_test_s32 = sum_test<int32_t,float>;
253
254 using sum_cc_f32 = sum_test<float,float>;
255 TEST_P(sum_cc_f32, TestSumCornerCases) {}
256 INSTANTIATE_TEST_CASE_P(TestSumCornerCases, sum_cc_f32, ::testing::Values(
257     CASE_CC(nchw, nChw8c, nchw, ({0, 7, 4, 4}), false, mkldnn_success),
258     CASE_CC(nchw, nChw8c, nchw, ({1, 0, 4, 4}), false, mkldnn_success),
259     CASE_CC(nchw, nChw8c, nchw, ({1, 8, 0, 4}), false, mkldnn_success),
260     CASE_CC(nchw, nChw8c, nchw, ({-1, 8, 4, 4}), true, mkldnn_invalid_arguments)
261     ));
262 #undef CASE_CC
263
264 INST_TEST_CASE(sum_test_float_omit_output, 1)
265 INST_TEST_CASE(sum_test_u8_omit_output, 1)
266 INST_TEST_CASE(sum_test_s8_omit_output, 1)
267 INST_TEST_CASE(sum_test_s32_omit_output, 1)
268
269 INST_TEST_CASE(sum_test_float, 0)
270 INST_TEST_CASE(sum_test_u8, 0)
271 INST_TEST_CASE(sum_test_s8, 0)
272 INST_TEST_CASE(sum_test_s32, 0)
273
274 #undef INST_TEST_CASE
275 }