Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / tests / gtests / test_concat.cpp
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "mkldnn_test_common.hpp"
18 #include "gtest/gtest.h"
19
20 #include "mkldnn.hpp"
21
22 namespace mkldnn {
23
24 struct concat_test_params {
25     const engine::kind engine_kind;
26     size_t concat_dimension;
27     std::vector<memory::format> srcs_format;
28     memory::format dst_format;
29     std::vector<memory::dims> srcs_cds;
30     memory::dims dst_cds;
31     bool expect_to_fail;
32     mkldnn_status_t expected_status;
33 };
34
35 template <typename data_t>
36 class concat_test: public ::testing::TestWithParam<concat_test_params> {
37     void check_data(const std::vector<memory> &srcs, const memory &dst,
38             int concat_dim) {
39         const data_t *dst_data = (const data_t *)dst.get_data_handle();
40         const auto &dst_d = dst.get_primitive_desc().desc();
41         const auto dst_dims = dst_d.data.dims;
42         const ptrdiff_t* dst_pdims = dst_d.data.layout_desc.blocking.padding_dims;
43
44         int acc_concat_dim = 0;
45         const auto ndims = dst_d.data.ndims;
46
47         for (size_t num = 0; num < srcs.size(); num++) {
48             const data_t *src_data = (const data_t *)srcs[num].get_data_handle();
49             const auto &src_d = srcs[num].get_primitive_desc().desc();
50             const ptrdiff_t* src_dims = src_d.data.dims;
51             const ptrdiff_t* src_pdims = src_d.data.layout_desc.blocking.padding_dims;
52
53             auto N = src_dims[0];
54             auto C = src_dims[1];
55             auto C_PADDED = src_pdims[1];
56             auto D = (ndims == 5) ? src_dims[2] : 1;
57             auto H = src_dims[ndims-2];
58             auto W = src_dims[ndims-1];
59
60             auto DST_C_PADDED = dst_pdims[1];
61             auto DST_D = (ndims == 5) ? dst_dims[2] : 1;
62             auto DST_H = dst_dims[ndims-2];
63             auto DST_W = dst_dims[ndims-1];
64
65             for (auto n = 0; n < N; n++)
66             for (auto c = 0; c < C; c++)
67             for (auto d = 0; d < D; d++)
68             for (auto h = 0; h < H; h++)
69             for (auto w = 0; w < W; w++) {
70                 auto src_idx = w + W*h + H*W*d + D*H*W*c + C_PADDED*D*H*W*n;
71
72                 auto adj_dst_dim = [&](int dim, int dim_sz) {
73                     if (concat_dim == dim) return dim_sz + acc_concat_dim;
74                     return dim_sz;
75                 };
76                 auto dst_idx = adj_dst_dim(ndims-1, w)
77                     + DST_W*adj_dst_dim(ndims-2, h)
78                     + DST_D*DST_H*DST_W*adj_dst_dim(1, c)
79                     + DST_C_PADDED*DST_D*DST_H*DST_W*adj_dst_dim(0, n);
80                 if (ndims == 5) dst_idx += DST_H*DST_W*adj_dst_dim(2, d);
81                 EXPECT_NEAR(src_data[map_index(src_d, src_idx)],
82                             dst_data[map_index(dst_d, dst_idx)],
83                             1e-7);
84             }
85
86             acc_concat_dim += src_dims[concat_dim];
87         }
88     }
89
90 protected:
91     virtual void SetUp() {
92         concat_test_params p
93             = ::testing::TestWithParam<decltype(p)>::GetParam();
94         catch_expected_failures([=](){Test();}, p.expect_to_fail,
95                     p.expected_status);
96     }
97
98     virtual void Test() {
99         concat_test_params p
100             = ::testing::TestWithParam<concat_test_params>::GetParam();
101
102         int src_dim_sum = 0;
103         for (size_t i = 0; i < p.srcs_cds.size(); i++) {
104             for (size_t dim = 0; dim < p.dst_cds.size(); dim++) {
105                 if (dim == p.concat_dimension)
106                     src_dim_sum += p.srcs_cds[i][dim];
107                 else if (p.expect_to_fail == false) {
108                     ASSERT_TRUE(p.srcs_cds[i][dim] == p.dst_cds[dim]);
109                 }
110             }
111         }
112
113         if (p.expect_to_fail == false) {
114             ASSERT_TRUE(src_dim_sum == p.dst_cds[p.concat_dimension]);
115         }
116
117         ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
118         auto eng = engine(p.engine_kind, 0);
119         memory::data_type data_type = data_traits<data_t>::data_type;
120
121         std::vector<memory::primitive_desc> srcs_pd;
122         std::vector<memory> srcs;
123         for (size_t i = 0; i < p.srcs_cds.size(); i++) {
124             auto desc = memory::desc(p.srcs_cds[i], data_type, p.srcs_format[i]);
125             auto mpd = memory::primitive_desc(desc, eng);
126             auto src_memory = memory(mpd);
127             const size_t sz = src_memory.get_primitive_desc().get_size() / sizeof(data_t);
128             fill_data<data_t>(sz, (data_t *)src_memory.get_data_handle());
129             check_zero_tail<data_t>(1, src_memory);
130             srcs_pd.push_back(mpd);
131             srcs.push_back(src_memory);
132         }
133
134         auto dst_desc = memory::desc(p.dst_cds, data_type, p.dst_format);
135         auto concat_pd = concat::primitive_desc(dst_desc, static_cast<int>(p.concat_dimension), srcs_pd);
136         auto dst = memory(concat_pd.dst_primitive_desc());
137         fill_data<data_t>(dst.get_primitive_desc().get_size() / sizeof(data_t),
138             (data_t *)dst.get_data_handle());
139         check_zero_tail<data_t>(1, dst);
140
141         std::vector<primitive::at> inputs;
142         for (size_t i = 0; i < p.srcs_cds.size(); i++) {
143             inputs.push_back(srcs[i]);
144         }
145         auto c = concat(concat_pd, inputs, dst);
146
147         ASSERT_EQ(concat_pd.dst_primitive_desc().desc().data.format,
148                 dst_desc.data.format);
149         ASSERT_EQ(concat_pd.dst_primitive_desc().desc().data.ndims,
150                 dst_desc.data.ndims);
151
152         std::vector<primitive> pipeline;
153         pipeline.push_back(c);
154         auto s = stream(stream::kind::eager);
155         s.submit(pipeline).wait();
156
157         check_data(srcs, dst, static_cast<int>(p.concat_dimension));
158         check_zero_tail<data_t>(0, dst);
159     }
160 };
161
162 using concat_test_float = concat_test<float>;
163 using concat_test_s8 = concat_test<int8_t>;
164
165 TEST_P(concat_test_float, TestsConcat) {}
166 TEST_P(concat_test_s8, TestsConcat) {}
167
168 using fmt = memory::format;
169
170 INSTANTIATE_TEST_CASE_P(TestConcat_ZeroDim, concat_test_float, ::testing::Values(
171     concat_test_params{engine::kind::cpu, 1, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw,  {{4, 0, 5, 5}, {4, 5, 5, 5}}, {4, 5, 5, 5}},
172     concat_test_params{engine::kind::cpu, 1, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw,  {{4, 4, 5, 5}, {4, 0, 5, 5}}, {4, 4, 5, 5}},
173     concat_test_params{engine::kind::cpu, 1, {fmt::nChw8c, fmt::nChw8c}, fmt::nChw8c, {{4, 0, 5, 5}, {4, 5, 5, 5}}, {4, 5, 5, 5}},
174     concat_test_params{engine::kind::cpu, 1, {fmt::nChw8c, fmt::nChw8c}, fmt::nChw8c, {{4, 4, 5, 5}, {4, 0, 5, 5}}, {4, 4, 5, 5}},
175     concat_test_params{engine::kind::cpu, 1, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw,  {{0, 4, 5, 5}, {0, 2, 5, 5}}, {0, 6, 5, 5}},
176     concat_test_params{engine::kind::cpu, 1, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw,  {{2, 4, 0, 5}, {2, 2, 0, 5}}, {2, 6, 0, 5}},
177     concat_test_params{engine::kind::cpu, 1, {fmt::nhwc, fmt::nhwc}, fmt::nhwc,  {{0, 4, 5, 5}, {0, 2, 5, 5}}, {0, 6, 5, 5}},
178     concat_test_params{engine::kind::cpu, 1, {fmt::nchw, fmt::nchw}, fmt::nchw,  {{0, 4, 5, 5}, {0, 2, 5, 5}}, {0, 6, 5, 5}},
179     concat_test_params{engine::kind::cpu, 1, {fmt::nhwc, fmt::nhwc}, fmt::nhwc,  {{2, 4, 0, 5}, {2, 2, 0, 5}}, {2, 6, 0, 5}},
180     concat_test_params{engine::kind::cpu, 1, {fmt::nchw, fmt::nchw}, fmt::nchw,  {{2, 4, 0, 5}, {2, 2, 0, 5}}, {2, 6, 0, 5}}
181 ));
182
183 INSTANTIATE_TEST_CASE_P(TestConcat_EF, concat_test_float, ::testing::Values(
184     concat_test_params{engine::kind::cpu, 1, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw,  {{4, 2, 5, 5}, {4, 5, 5, 5}}, {4, 5, 5, 5}, true, mkldnn_invalid_arguments},
185     concat_test_params{engine::kind::cpu, 2, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw,  {{4, 2, 5, 5}, {4, 3, 5, 5}}, {4, 5, 5, 5}, true, mkldnn_invalid_arguments},
186     concat_test_params{engine::kind::cpu, 5, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw,  {{4, 4, 5, 5}, {4, 0, 5, 5}}, {4, 4, 5, 5}, true, mkldnn_invalid_arguments},
187     concat_test_params{engine::kind::cpu, 1, {fmt::nChw8c, fmt::nChw8c}, fmt::nChw8c, {{4, -1, 5, 5}, {4, 5, 5, 5}}, {4, 5, 5, 5}, true, mkldnn_invalid_arguments},
188     concat_test_params{engine::kind::cpu, 1, {fmt::nChw8c, fmt::nChw8c}, fmt::nChw8c, {{4, 4, 5, 5}, {4, 4, 5, 5}}, {4, 4, 5, 5}, true, mkldnn_invalid_arguments},
189     concat_test_params{engine::kind::cpu, 1, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw,  {{0, 4, 5, 5}, {0, 4, 5, 5}}, {0, 6, 5, 5}, true, mkldnn_invalid_arguments},
190     concat_test_params{engine::kind::cpu, 1, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw,  {{2, 4, 2, 5}, {2, 2, 1, 5}}, {2, 6, 2, 5}, true, mkldnn_invalid_arguments},
191     concat_test_params{engine::kind::cpu, 1, {fmt::nhwc, fmt::nhwc}, fmt::nhwc,  {{1, 4, 5, 5}, {1, 2, 5, 5}}, {1, 7, 5, 5}, true, mkldnn_invalid_arguments},
192     concat_test_params{engine::kind::cpu, 1, {fmt::nchw, fmt::nchw}, fmt::nchw,  {{1, 4, 5, 5}, {1, 2, 5, 5}}, {1, 6, 6, 5}, true, mkldnn_invalid_arguments}
193 ));
194
195 INSTANTIATE_TEST_CASE_P(TestConcat_padded, concat_test_float, ::testing::Values(
196     concat_test_params{engine::kind::cpu, 1, {fmt::nChw16c, fmt::nChw16c}, fmt::nChw16c, {{4, 25, 5, 5}, {4, 45, 5, 5}}, {4, 70,  5,  5}, true, mkldnn_unimplemented},
197     concat_test_params{engine::kind::cpu, 1, {fmt::nChw16c, fmt::nChw16c}, fmt::nchw,    {{4, 25, 5, 5}, {4, 45, 5, 5}}, {4, 70,  5,  5}},
198     concat_test_params{engine::kind::cpu, 1, {fmt::nChw8c,  fmt::nChw8c},  fmt::nchw,    {{4, 25, 5, 5}, {4, 45, 5, 5}}, {4, 70,  5,  5}},
199     concat_test_params{engine::kind::cpu, 1, {fmt::nChw16c, fmt::nChw8c},  fmt::nchw,    {{4, 25, 5, 5}, {4, 45, 5, 5}}, {4, 70,  5,  5}},
200     concat_test_params{engine::kind::cpu, 1, {fmt::nChw8c,  fmt::nChw16c}, fmt::nchw,    {{4, 25, 5, 5}, {4, 45, 5, 5}}, {4, 70,  5,  5}},
201     concat_test_params{engine::kind::cpu, 1, {fmt::nChw16c, fmt::nChw16c}, fmt::nChw16c, {{4,  4, 5, 5}, {4,  6, 5, 5}}, {4, 10,  5,  5}, true, mkldnn_unimplemented},
202     concat_test_params{engine::kind::cpu, 1, {fmt::nChw16c, fmt::nChw16c}, fmt::nchw,    {{4,  4, 5, 5}, {4,  6, 5, 5}}, {4, 10,  5,  5}},
203     concat_test_params{engine::kind::cpu, 1, {fmt::nchw,    fmt::nChw16c}, fmt::nChw16c, {{4, 25, 5, 5}, {4, 45, 5, 5}}, {4, 70,  5,  5}, true, mkldnn_unimplemented},
204     concat_test_params{engine::kind::cpu, 1, {fmt::nchw,    fmt::nChw16c}, fmt::nchw,    {{4, 25, 5, 5}, {4, 45, 5, 5}}, {4, 70,  5,  5}},
205     // right border
206     concat_test_params{engine::kind::cpu, 1, {fmt::nChw16c, fmt::nChw16c}, fmt::nChw16c, {{4, 16, 5, 5}, {4,  3, 5, 5}}, {4, 19,  5,  5}},
207     concat_test_params{engine::kind::cpu, 1, {fmt::nChw16c, fmt::nChw16c}, fmt::nChw8c, {{4, 16, 5, 5}, {4,  3, 5, 5}}, {4, 19,  5,  5}},
208     concat_test_params{engine::kind::cpu, 1, {fmt::nChw8c,  fmt::nChw8c},  fmt::nChw8c, {{4, 8, 5, 5}, {4,  3, 5, 5}}, {4, 11,  5,  5}},
209     concat_test_params{engine::kind::cpu, 1, {fmt::nChw8c,  fmt::nChw16c}, fmt::nChw16c, {{4, 8, 5, 5}, {4,  3, 5, 5}}, {4, 11,  5,  5}},
210     // not over channels
211     concat_test_params{engine::kind::cpu, 2, {fmt::nChw16c, fmt::nChw16c}, fmt::nchw,    {{4, 25, 5, 5}, {4, 25, 5, 5}}, {4, 25, 10,  5}},
212     concat_test_params{engine::kind::cpu, 2, {fmt::nChw8c,  fmt::nChw8c},  fmt::nchw,    {{4, 25, 5, 5}, {4, 25, 5, 5}}, {4, 25, 10,  5}},
213     concat_test_params{engine::kind::cpu, 2, {fmt::nChw8c,  fmt::nChw16c}, fmt::nchw,    {{4, 25, 5, 5}, {4, 25, 5, 5}}, {4, 25, 10,  5}}
214 ));
215
216 INSTANTIATE_TEST_CASE_P(TestConcat3D, concat_test_float, ::testing::Values(
217     concat_test_params{engine::kind::cpu, 0,
218     {memory::format::ncdhw, memory::format::ncdhw}, memory::format::ncdhw,
219     {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {4, 8, 3, 4, 5}},
220     concat_test_params{engine::kind::cpu, 1,
221     {memory::format::ncdhw, memory::format::ncdhw}, memory::format::ncdhw,
222     {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 16, 3, 4, 5}},
223     concat_test_params{engine::kind::cpu, 2,
224     {memory::format::ncdhw, memory::format::ncdhw}, memory::format::ncdhw,
225     {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 8, 6, 4, 5}},
226     concat_test_params{engine::kind::cpu, 3,
227     {memory::format::ncdhw, memory::format::ncdhw}, memory::format::ncdhw,
228     {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 8, 3, 8, 5}},
229     concat_test_params{engine::kind::cpu, 4,
230     {memory::format::ncdhw, memory::format::ncdhw}, memory::format::ncdhw,
231     {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 8, 3, 4, 10}},
232     concat_test_params{engine::kind::cpu, 0,
233     {memory::format::nCdhw8c, memory::format::nCdhw8c}, memory::format::nCdhw8c,
234     {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {4, 8, 3, 4, 5}},
235     concat_test_params{engine::kind::cpu, 1,
236     {memory::format::nCdhw8c, memory::format::nCdhw8c}, memory::format::nCdhw8c,
237     {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 16, 3, 4, 5}},
238     concat_test_params{engine::kind::cpu, 1,
239     {memory::format::nCdhw8c, memory::format::ncdhw}, memory::format::nCdhw8c,
240     {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 16, 3, 4, 5}},
241     concat_test_params{engine::kind::cpu, 1,
242     {memory::format::ncdhw, memory::format::ncdhw}, memory::format::nCdhw8c,
243     {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 16, 3, 4, 5}},
244     concat_test_params{engine::kind::cpu, 2,
245     {memory::format::nCdhw8c, memory::format::nCdhw8c}, memory::format::nCdhw8c,
246     {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 8, 6, 4, 5}},
247     concat_test_params{engine::kind::cpu, 3,
248     {memory::format::nCdhw8c, memory::format::nCdhw8c}, memory::format::nCdhw8c,
249     {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 8, 3, 8, 5}},
250     concat_test_params{engine::kind::cpu, 4,
251     {memory::format::nCdhw8c, memory::format::nCdhw8c}, memory::format::nCdhw8c,
252     {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 8, 3, 4, 10}}
253 ));
254
255 INSTANTIATE_TEST_CASE_P(TestConcat, concat_test_float, ::testing::Values(
256     concat_test_params{engine::kind::cpu, 1,
257     {memory::format::nchw, memory::format::nchw}, memory::format::nchw,
258     {{2, 8, 3, 4}, {2, 8, 3, 4}}, {2, 16, 3, 4}},
259     concat_test_params{engine::kind::cpu, 1,
260     {memory::format::nChw8c, memory::format::nChw8c}, memory::format::nChw8c,
261     {{2, 16, 1, 1}, {2, 16, 1, 1}}, {2, 32, 1, 1}},
262     concat_test_params{engine::kind::cpu, 1,
263     {memory::format::nchw, memory::format::nchw}, memory::format::nChw8c,
264     {{2, 16, 1, 1}, {2, 16, 1, 1}}, {2, 32, 1, 1}},
265     concat_test_params{engine::kind::cpu, 1,
266     {memory::format::nhwc, memory::format::nhwc}, memory::format::nhwc,
267     {{2, 16, 1, 1}, {2, 16, 1, 1}}, {2, 32, 1, 1}},
268     concat_test_params{engine::kind::cpu, 1,
269     {memory::format::nChw8c, memory::format::nChw8c}, memory::format::nchw,
270     {{2, 16, 1, 1}, {2, 16, 1, 1}}, {2, 32, 1, 1}},
271
272     concat_test_params{engine::kind::cpu, 0,
273     {memory::format::nchw, memory::format::nchw}, memory::format::nchw,
274     {{2, 8, 3, 4}, {2, 8, 3, 4}}, {4, 8, 3, 4}},
275     concat_test_params{engine::kind::cpu, 0,
276     {memory::format::nChw8c, memory::format::nChw8c}, memory::format::nChw8c,
277     {{2, 16, 1, 1}, {2, 16, 1, 1}}, {4, 16, 1, 1}},
278     concat_test_params{engine::kind::cpu, 0,
279     {memory::format::nchw, memory::format::nchw}, memory::format::nChw8c,
280     {{2, 16, 1, 1}, {2, 16, 1, 1}}, {4, 16, 1, 1}},
281     concat_test_params{engine::kind::cpu, 0,
282     {memory::format::nChw8c, memory::format::nChw8c}, memory::format::nchw,
283     {{2, 16, 1, 1}, {2, 16, 1, 1}}, {4, 16, 1, 1}},
284
285     concat_test_params{engine::kind::cpu, 1,
286     {memory::format::nChw8c, memory::format::nChw8c}, memory::format::nChw8c,
287     {{2, 8, 1, 1}, {2, 8, 1, 1}}, {2, 16, 1, 1}},
288
289     concat_test_params{engine::kind::cpu, 1,
290     {memory::format::nChw8c, memory::format::nChw16c}, memory::format::nChw8c,
291     {{2, 8, 1, 1}, {2, 16, 1, 1}}, {2, 24, 1, 1}}
292 ));
293
294 INSTANTIATE_TEST_CASE_P(TestConcat, concat_test_s8, ::testing::Values(
295     concat_test_params{engine::kind::cpu, 1,
296     {memory::format::nhwc, memory::format::nhwc}, memory::format::nhwc,
297     {{2, 8, 3, 4}, {2, 8, 3, 4}}, {2, 16, 3, 4}},
298     concat_test_params{engine::kind::cpu, 1,
299     {memory::format::nchw, memory::format::nchw}, memory::format::nchw,
300     {{2, 8, 3, 4}, {2, 8, 3, 4}}, {2, 16, 3, 4}}
301     ));
302
303 }