1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "mkldnn_test_common.hpp"
18 #include "gtest/gtest.h"
// Parameters for one concat test case, consumed by concat_test<data_t>.
// Cases are written as aggregate initializers in the tables at the end of
// the file; trailing fields that a case omits are value-initialized.
// NOTE(review): this excerpt does not show every member (the fixture also
// reads p.dst_cds and p.expect_to_fail) — confirm against the full file.
24 struct concat_test_params {
// Engine to run on; Test() asserts it is engine::kind::cpu.
25 const engine::kind engine_kind;
// Logical axis the sources are joined along (0 = N, 1 = C, then spatial).
26 size_t concat_dimension;
// Memory format of each source, parallel to srcs_cds.
27 std::vector<memory::format> srcs_format;
// Requested destination memory format.
28 memory::format dst_format;
// Logical dimensions of each source tensor.
29 std::vector<memory::dims> srcs_cds;
// Status a failing case is expected to produce. Presumably checked by
// catch_expected_failures when expect_to_fail is set — confirm in the helpers.
32 mkldnn_status_t expected_status;
// Value-parameterized concat test. SetUp() runs Test() through
// catch_expected_failures. Test() does the following:
// - validates the parameters;
// - builds one filled memory per source;
// - executes the concat primitive into a dst memory;
// - compares every source element with its expected position in dst.
35 template <typename data_t>
36 class concat_test: public ::testing::TestWithParam<concat_test_params> {
// Verifies that each source in `srcs` was copied into `dst` at the right
// offset along axis `concat_dim`.
// For every element, a logical linear index is computed in n-c-[d]-h-w order,
// with the channel extent taken from padding_dims. It is then passed through
// map_index() (a test helper) to reach the element in memory. Presumably
// map_index() converts that logical index into the physical offset for the
// descriptor's format — confirm in mkldnn_test_common.hpp.
37 void check_data(const std::vector<memory> &srcs, const memory &dst,
39 const data_t *dst_data = (const data_t *)dst.get_data_handle();
40 const auto &dst_d = dst.get_primitive_desc().desc();
41 const auto dst_dims = dst_d.data.dims;
// Padded extents. The padded channel count (index 1) is used as the N stride.
42 const ptrdiff_t* dst_pdims = dst_d.data.layout_desc.blocking.padding_dims;
// Offset in dst, along the concat axis, at which the current source starts.
44 int acc_concat_dim = 0;
// 4 for n-c-h-w tensors, 5 for n-c-d-h-w tensors (D is used only when 5).
45 const auto ndims = dst_d.data.ndims;
47 for (size_t num = 0; num < srcs.size(); num++) {
48 const data_t *src_data = (const data_t *)srcs[num].get_data_handle();
49 const auto &src_d = srcs[num].get_primitive_desc().desc();
50 const ptrdiff_t* src_dims = src_d.data.dims;
51 const ptrdiff_t* src_pdims = src_d.data.layout_desc.blocking.padding_dims;
// NOTE(review): the loop bounds N and C are defined on lines not shown in this
// excerpt; presumably they are src_dims[0] and src_dims[1].
// Source extents. Depth exists only for 5D tensors; 4D tensors use D = 1.
55 auto C_PADDED = src_pdims[1];
56 auto D = (ndims == 5) ? src_dims[2] : 1;
57 auto H = src_dims[ndims-2];
58 auto W = src_dims[ndims-1];
// Destination extents, used as strides when locating the element in dst.
60 auto DST_C_PADDED = dst_pdims[1];
61 auto DST_D = (ndims == 5) ? dst_dims[2] : 1;
62 auto DST_H = dst_dims[ndims-2];
63 auto DST_W = dst_dims[ndims-1];
// Visit every logical element of this source. Padded channels are skipped
// because c only runs up to C.
65 for (auto n = 0; n < N; n++)
66 for (auto c = 0; c < C; c++)
67 for (auto d = 0; d < D; d++)
68 for (auto h = 0; h < H; h++)
69 for (auto w = 0; w < W; w++) {
70 auto src_idx = w + W*h + H*W*d + D*H*W*c + C_PADDED*D*H*W*n;
// Shifts a coordinate by the extent of earlier sources, but only when `dim`
// is the concat axis.
72 auto adj_dst_dim = [&](int dim, int dim_sz) {
73 if (concat_dim == dim) return dim_sz + acc_concat_dim;
// Same n-c-[d]-h-w linearization as src_idx, using dst strides and
// concat-shifted coordinates.
76 auto dst_idx = adj_dst_dim(ndims-1, w)
77 + DST_W*adj_dst_dim(ndims-2, h)
78 + DST_D*DST_H*DST_W*adj_dst_dim(1, c)
79 + DST_C_PADDED*DST_D*DST_H*DST_W*adj_dst_dim(0, n);
// The depth term only exists for 5D tensors.
80 if (ndims == 5) dst_idx += DST_H*DST_W*adj_dst_dim(2, d);
81 EXPECT_NEAR(src_data[map_index(src_d, src_idx)],
82 dst_data[map_index(dst_d, dst_idx)],
// The next source starts where this one ends along the concat axis.
86 acc_concat_dim += src_dims[concat_dim];
// Runs Test() via catch_expected_failures. Cases flagged expect_to_fail are
// presumably checked against expected_status instead of failing the test —
// confirm in mkldnn_test_common.hpp.
91 virtual void SetUp() {
93 = ::testing::TestWithParam<decltype(p)>::GetParam();
94 catch_expected_failures([=](){Test();}, p.expect_to_fail,
// Builds the sources and dst, runs concat, and checks the result.
100 = ::testing::TestWithParam<concat_test_params>::GetParam();
// Parameter sanity checks. Extents along the concat axis are summed. Every
// other extent must equal dst's, but only for cases that are meant to succeed.
103 for (size_t i = 0; i < p.srcs_cds.size(); i++) {
104 for (size_t dim = 0; dim < p.dst_cds.size(); dim++) {
105 if (dim == p.concat_dimension)
106 src_dim_sum += p.srcs_cds[i][dim];
107 else if (p.expect_to_fail == false) {
108 ASSERT_TRUE(p.srcs_cds[i][dim] == p.dst_cds[dim]);
// For successful cases, the summed source extents must equal dst's extent on
// the concat axis.
113 if (p.expect_to_fail == false) {
114 ASSERT_TRUE(src_dim_sum == p.dst_cds[p.concat_dimension]);
// Only the CPU engine is exercised.
117 ASSERT_TRUE(p.engine_kind == engine::kind::cpu);
118 auto eng = engine(p.engine_kind, 0);
// Element type follows the fixture's template argument.
119 memory::data_type data_type = data_traits<data_t>::data_type;
// Create and fill one memory object per source, in the requested format.
121 std::vector<memory::primitive_desc> srcs_pd;
122 std::vector<memory> srcs;
123 for (size_t i = 0; i < p.srcs_cds.size(); i++) {
124 auto desc = memory::desc(p.srcs_cds[i], data_type, p.srcs_format[i]);
125 auto mpd = memory::primitive_desc(desc, eng);
126 auto src_memory = memory(mpd);
// Element count of the whole allocation, which includes any padding.
127 const size_t sz = src_memory.get_primitive_desc().get_size() / sizeof(data_t);
128 fill_data<data_t>(sz, (data_t *)src_memory.get_data_handle());
// NOTE(review): with mode 1, check_zero_tail presumably zeroes the padded
// tail that fill_data wrote — confirm in mkldnn_test_common.hpp.
129 check_zero_tail<data_t>(1, src_memory);
130 srcs_pd.push_back(mpd);
131 srcs.push_back(src_memory);
// Build the concat primitive descriptor. dst memory is allocated from the
// descriptor concat_pd reports, not directly from dst_desc.
134 auto dst_desc = memory::desc(p.dst_cds, data_type, p.dst_format);
135 auto concat_pd = concat::primitive_desc(dst_desc, static_cast<int>(p.concat_dimension), srcs_pd);
136 auto dst = memory(concat_pd.dst_primitive_desc());
// Pre-fill dst with data, presumably so that elements concat fails to write
// are unlikely to match the sources in check_data.
137 fill_data<data_t>(dst.get_primitive_desc().get_size() / sizeof(data_t),
138 (data_t *)dst.get_data_handle());
139 check_zero_tail<data_t>(1, dst);
141 std::vector<primitive::at> inputs;
142 for (size_t i = 0; i < p.srcs_cds.size(); i++) {
143 inputs.push_back(srcs[i]);
145 auto c = concat(concat_pd, inputs, dst);
// The chosen dst descriptor must keep the requested format and rank.
147 ASSERT_EQ(concat_pd.dst_primitive_desc().desc().data.format,
148 dst_desc.data.format);
149 ASSERT_EQ(concat_pd.dst_primitive_desc().desc().data.ndims,
150 dst_desc.data.ndims);
// Execute on an eager stream and block until it completes.
152 std::vector<primitive> pipeline;
153 pipeline.push_back(c);
154 auto s = stream(stream::kind::eager);
155 s.submit(pipeline).wait();
157 check_data(srcs, dst, static_cast<int>(p.concat_dimension));
// NOTE(review): with mode 0, check_zero_tail presumably verifies that dst's
// padded tail is still zero — confirm in mkldnn_test_common.hpp.
158 check_zero_tail<data_t>(0, dst);
// Typed fixtures: the same concat_test logic run over float and int8_t data.
162 using concat_test_float = concat_test<float>;
163 using concat_test_s8 = concat_test<int8_t>;
// The test bodies are intentionally empty. The fixture's SetUp() runs the
// whole test for each parameter set.
165 TEST_P(concat_test_float, TestsConcat) {}
166 TEST_P(concat_test_s8, TestsConcat) {}
// Shorthand for the memory format enum used in the parameter tables.
168 using fmt = memory::format;
// Zero-sized shapes.
// - The first four cases give one source a zero extent along the concat
//   axis (C).
// - The remaining cases make N or H zero in every tensor.
// None of them pass expect_to_fail, so all are expected to succeed.
170 INSTANTIATE_TEST_CASE_P(TestConcat_ZeroDim, concat_test_float, ::testing::Values(
171 concat_test_params{engine::kind::cpu, 1, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw, {{4, 0, 5, 5}, {4, 5, 5, 5}}, {4, 5, 5, 5}},
172 concat_test_params{engine::kind::cpu, 1, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw, {{4, 4, 5, 5}, {4, 0, 5, 5}}, {4, 4, 5, 5}},
173 concat_test_params{engine::kind::cpu, 1, {fmt::nChw8c, fmt::nChw8c}, fmt::nChw8c, {{4, 0, 5, 5}, {4, 5, 5, 5}}, {4, 5, 5, 5}},
174 concat_test_params{engine::kind::cpu, 1, {fmt::nChw8c, fmt::nChw8c}, fmt::nChw8c, {{4, 4, 5, 5}, {4, 0, 5, 5}}, {4, 4, 5, 5}},
175 concat_test_params{engine::kind::cpu, 1, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw, {{0, 4, 5, 5}, {0, 2, 5, 5}}, {0, 6, 5, 5}},
176 concat_test_params{engine::kind::cpu, 1, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw, {{2, 4, 0, 5}, {2, 2, 0, 5}}, {2, 6, 0, 5}},
177 concat_test_params{engine::kind::cpu, 1, {fmt::nhwc, fmt::nhwc}, fmt::nhwc, {{0, 4, 5, 5}, {0, 2, 5, 5}}, {0, 6, 5, 5}},
178 concat_test_params{engine::kind::cpu, 1, {fmt::nchw, fmt::nchw}, fmt::nchw, {{0, 4, 5, 5}, {0, 2, 5, 5}}, {0, 6, 5, 5}},
179 concat_test_params{engine::kind::cpu, 1, {fmt::nhwc, fmt::nhwc}, fmt::nhwc, {{2, 4, 0, 5}, {2, 2, 0, 5}}, {2, 6, 0, 5}},
180 concat_test_params{engine::kind::cpu, 1, {fmt::nchw, fmt::nchw}, fmt::nchw, {{2, 4, 0, 5}, {2, 2, 0, 5}}, {2, 6, 0, 5}}
// Invalid configurations. Every case sets expect_to_fail and expects
// mkldnn_invalid_arguments. The trailing comment on each case names the
// inconsistency.
183 INSTANTIATE_TEST_CASE_P(TestConcat_EF, concat_test_float, ::testing::Values(
184 concat_test_params{engine::kind::cpu, 1, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw, {{4, 2, 5, 5}, {4, 5, 5, 5}}, {4, 5, 5, 5}, true, mkldnn_invalid_arguments}, // concat C: 2 + 5 != 5
185 concat_test_params{engine::kind::cpu, 2, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw, {{4, 2, 5, 5}, {4, 3, 5, 5}}, {4, 5, 5, 5}, true, mkldnn_invalid_arguments}, // concat H: 5 + 5 != 5; C also differs (2 vs 3)
186 concat_test_params{engine::kind::cpu, 5, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw, {{4, 4, 5, 5}, {4, 0, 5, 5}}, {4, 4, 5, 5}, true, mkldnn_invalid_arguments}, // concat axis 5 is out of range for 4D tensors
187 concat_test_params{engine::kind::cpu, 1, {fmt::nChw8c, fmt::nChw8c}, fmt::nChw8c, {{4, -1, 5, 5}, {4, 5, 5, 5}}, {4, 5, 5, 5}, true, mkldnn_invalid_arguments}, // negative channel extent
188 concat_test_params{engine::kind::cpu, 1, {fmt::nChw8c, fmt::nChw8c}, fmt::nChw8c, {{4, 4, 5, 5}, {4, 4, 5, 5}}, {4, 4, 5, 5}, true, mkldnn_invalid_arguments}, // concat C: 4 + 4 != 4
189 concat_test_params{engine::kind::cpu, 1, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw, {{0, 4, 5, 5}, {0, 4, 5, 5}}, {0, 6, 5, 5}, true, mkldnn_invalid_arguments}, // concat C: 4 + 4 != 6
190 concat_test_params{engine::kind::cpu, 1, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw, {{2, 4, 2, 5}, {2, 2, 1, 5}}, {2, 6, 2, 5}, true, mkldnn_invalid_arguments}, // H differs between sources (2 vs 1)
191 concat_test_params{engine::kind::cpu, 1, {fmt::nhwc, fmt::nhwc}, fmt::nhwc, {{1, 4, 5, 5}, {1, 2, 5, 5}}, {1, 7, 5, 5}, true, mkldnn_invalid_arguments}, // concat C: 4 + 2 != 7
192 concat_test_params{engine::kind::cpu, 1, {fmt::nchw, fmt::nchw}, fmt::nchw, {{1, 4, 5, 5}, {1, 2, 5, 5}}, {1, 6, 6, 5}, true, mkldnn_invalid_arguments} // dst H (6) != source H (5)
// Channel counts that are not multiples of the 8/16 channel block, so
// blocked layouts (nChw8c / nChw16c) carry channel padding.
// - Some cases with a blocked dst are expected to fail with
//   mkldnn_unimplemented; all others must succeed.
// - The last group concatenates along H (axis 2) while C stays padded.
195 INSTANTIATE_TEST_CASE_P(TestConcat_padded, concat_test_float, ::testing::Values(
196 concat_test_params{engine::kind::cpu, 1, {fmt::nChw16c, fmt::nChw16c}, fmt::nChw16c, {{4, 25, 5, 5}, {4, 45, 5, 5}}, {4, 70, 5, 5}, true, mkldnn_unimplemented},
197 concat_test_params{engine::kind::cpu, 1, {fmt::nChw16c, fmt::nChw16c}, fmt::nchw, {{4, 25, 5, 5}, {4, 45, 5, 5}}, {4, 70, 5, 5}},
198 concat_test_params{engine::kind::cpu, 1, {fmt::nChw8c, fmt::nChw8c}, fmt::nchw, {{4, 25, 5, 5}, {4, 45, 5, 5}}, {4, 70, 5, 5}},
199 concat_test_params{engine::kind::cpu, 1, {fmt::nChw16c, fmt::nChw8c}, fmt::nchw, {{4, 25, 5, 5}, {4, 45, 5, 5}}, {4, 70, 5, 5}},
200 concat_test_params{engine::kind::cpu, 1, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw, {{4, 25, 5, 5}, {4, 45, 5, 5}}, {4, 70, 5, 5}},
201 concat_test_params{engine::kind::cpu, 1, {fmt::nChw16c, fmt::nChw16c}, fmt::nChw16c, {{4, 4, 5, 5}, {4, 6, 5, 5}}, {4, 10, 5, 5}, true, mkldnn_unimplemented},
202 concat_test_params{engine::kind::cpu, 1, {fmt::nChw16c, fmt::nChw16c}, fmt::nchw, {{4, 4, 5, 5}, {4, 6, 5, 5}}, {4, 10, 5, 5}},
203 concat_test_params{engine::kind::cpu, 1, {fmt::nchw, fmt::nChw16c}, fmt::nChw16c, {{4, 25, 5, 5}, {4, 45, 5, 5}}, {4, 70, 5, 5}, true, mkldnn_unimplemented},
204 concat_test_params{engine::kind::cpu, 1, {fmt::nchw, fmt::nChw16c}, fmt::nchw, {{4, 25, 5, 5}, {4, 45, 5, 5}}, {4, 70, 5, 5}},
// Only the second source is padded (the first is one full block): expected
// to succeed even with a blocked dst.
206 concat_test_params{engine::kind::cpu, 1, {fmt::nChw16c, fmt::nChw16c}, fmt::nChw16c, {{4, 16, 5, 5}, {4, 3, 5, 5}}, {4, 19, 5, 5}},
207 concat_test_params{engine::kind::cpu, 1, {fmt::nChw16c, fmt::nChw16c}, fmt::nChw8c, {{4, 16, 5, 5}, {4, 3, 5, 5}}, {4, 19, 5, 5}},
208 concat_test_params{engine::kind::cpu, 1, {fmt::nChw8c, fmt::nChw8c}, fmt::nChw8c, {{4, 8, 5, 5}, {4, 3, 5, 5}}, {4, 11, 5, 5}},
209 concat_test_params{engine::kind::cpu, 1, {fmt::nChw8c, fmt::nChw16c}, fmt::nChw16c, {{4, 8, 5, 5}, {4, 3, 5, 5}}, {4, 11, 5, 5}},
// Concat along H (axis 2) with padded channels.
211 concat_test_params{engine::kind::cpu, 2, {fmt::nChw16c, fmt::nChw16c}, fmt::nchw, {{4, 25, 5, 5}, {4, 25, 5, 5}}, {4, 25, 10, 5}},
212 concat_test_params{engine::kind::cpu, 2, {fmt::nChw8c, fmt::nChw8c}, fmt::nchw, {{4, 25, 5, 5}, {4, 25, 5, 5}}, {4, 25, 10, 5}},
213 concat_test_params{engine::kind::cpu, 2, {fmt::nChw8c, fmt::nChw16c}, fmt::nchw, {{4, 25, 5, 5}, {4, 25, 5, 5}}, {4, 25, 10, 5}}
// 5D tensors: two 2x8x3x4x5 sources concatenated along each axis 0..4.
// - The first group uses plain ncdhw throughout.
// - The second group uses 8-channel-blocked nCdhw8c, plus two C-axis cases
//   that mix plain and blocked formats.
216 INSTANTIATE_TEST_CASE_P(TestConcat3D, concat_test_float, ::testing::Values(
217 concat_test_params{engine::kind::cpu, 0,
218 {memory::format::ncdhw, memory::format::ncdhw}, memory::format::ncdhw,
219 {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {4, 8, 3, 4, 5}},
220 concat_test_params{engine::kind::cpu, 1,
221 {memory::format::ncdhw, memory::format::ncdhw}, memory::format::ncdhw,
222 {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 16, 3, 4, 5}},
223 concat_test_params{engine::kind::cpu, 2,
224 {memory::format::ncdhw, memory::format::ncdhw}, memory::format::ncdhw,
225 {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 8, 6, 4, 5}},
226 concat_test_params{engine::kind::cpu, 3,
227 {memory::format::ncdhw, memory::format::ncdhw}, memory::format::ncdhw,
228 {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 8, 3, 8, 5}},
229 concat_test_params{engine::kind::cpu, 4,
230 {memory::format::ncdhw, memory::format::ncdhw}, memory::format::ncdhw,
231 {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 8, 3, 4, 10}},
232 concat_test_params{engine::kind::cpu, 0,
233 {memory::format::nCdhw8c, memory::format::nCdhw8c}, memory::format::nCdhw8c,
234 {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {4, 8, 3, 4, 5}},
235 concat_test_params{engine::kind::cpu, 1,
236 {memory::format::nCdhw8c, memory::format::nCdhw8c}, memory::format::nCdhw8c,
237 {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 16, 3, 4, 5}},
// Mixed blocked + plain sources into a blocked dst.
238 concat_test_params{engine::kind::cpu, 1,
239 {memory::format::nCdhw8c, memory::format::ncdhw}, memory::format::nCdhw8c,
240 {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 16, 3, 4, 5}},
// Plain sources into a blocked dst.
241 concat_test_params{engine::kind::cpu, 1,
242 {memory::format::ncdhw, memory::format::ncdhw}, memory::format::nCdhw8c,
243 {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 16, 3, 4, 5}},
244 concat_test_params{engine::kind::cpu, 2,
245 {memory::format::nCdhw8c, memory::format::nCdhw8c}, memory::format::nCdhw8c,
246 {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 8, 6, 4, 5}},
247 concat_test_params{engine::kind::cpu, 3,
248 {memory::format::nCdhw8c, memory::format::nCdhw8c}, memory::format::nCdhw8c,
249 {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 8, 3, 8, 5}},
250 concat_test_params{engine::kind::cpu, 4,
251 {memory::format::nCdhw8c, memory::format::nCdhw8c}, memory::format::nCdhw8c,
252 {{2, 8, 3, 4, 5}, {2, 8, 3, 4, 5}}, {2, 8, 3, 4, 10}}
// 4D f32 cases: concat along C (axis 1) and N (axis 0) across plain
// (nchw, nhwc) and blocked (nChw8c) layouts, including plain<->blocked
// format changes between sources and dst.
255 INSTANTIATE_TEST_CASE_P(TestConcat, concat_test_float, ::testing::Values(
// Concat along C (axis 1).
256 concat_test_params{engine::kind::cpu, 1,
257 {memory::format::nchw, memory::format::nchw}, memory::format::nchw,
258 {{2, 8, 3, 4}, {2, 8, 3, 4}}, {2, 16, 3, 4}},
259 concat_test_params{engine::kind::cpu, 1,
260 {memory::format::nChw8c, memory::format::nChw8c}, memory::format::nChw8c,
261 {{2, 16, 1, 1}, {2, 16, 1, 1}}, {2, 32, 1, 1}},
262 concat_test_params{engine::kind::cpu, 1,
263 {memory::format::nchw, memory::format::nchw}, memory::format::nChw8c,
264 {{2, 16, 1, 1}, {2, 16, 1, 1}}, {2, 32, 1, 1}},
265 concat_test_params{engine::kind::cpu, 1,
266 {memory::format::nhwc, memory::format::nhwc}, memory::format::nhwc,
267 {{2, 16, 1, 1}, {2, 16, 1, 1}}, {2, 32, 1, 1}},
268 concat_test_params{engine::kind::cpu, 1,
269 {memory::format::nChw8c, memory::format::nChw8c}, memory::format::nchw,
270 {{2, 16, 1, 1}, {2, 16, 1, 1}}, {2, 32, 1, 1}},
// Concat along N (axis 0).
272 concat_test_params{engine::kind::cpu, 0,
273 {memory::format::nchw, memory::format::nchw}, memory::format::nchw,
274 {{2, 8, 3, 4}, {2, 8, 3, 4}}, {4, 8, 3, 4}},
275 concat_test_params{engine::kind::cpu, 0,
276 {memory::format::nChw8c, memory::format::nChw8c}, memory::format::nChw8c,
277 {{2, 16, 1, 1}, {2, 16, 1, 1}}, {4, 16, 1, 1}},
278 concat_test_params{engine::kind::cpu, 0,
279 {memory::format::nchw, memory::format::nchw}, memory::format::nChw8c,
280 {{2, 16, 1, 1}, {2, 16, 1, 1}}, {4, 16, 1, 1}},
281 concat_test_params{engine::kind::cpu, 0,
282 {memory::format::nChw8c, memory::format::nChw8c}, memory::format::nchw,
283 {{2, 16, 1, 1}, {2, 16, 1, 1}}, {4, 16, 1, 1}},
// Each source is exactly one 8-channel block.
285 concat_test_params{engine::kind::cpu, 1,
286 {memory::format::nChw8c, memory::format::nChw8c}, memory::format::nChw8c,
287 {{2, 8, 1, 1}, {2, 8, 1, 1}}, {2, 16, 1, 1}},
// Sources with different block sizes (8 and 16) into an nChw8c dst.
289 concat_test_params{engine::kind::cpu, 1,
290 {memory::format::nChw8c, memory::format::nChw16c}, memory::format::nChw8c,
291 {{2, 8, 1, 1}, {2, 16, 1, 1}}, {2, 24, 1, 1}}
// int8 data: concat along C (axis 1) in nhwc and in nchw.
294 INSTANTIATE_TEST_CASE_P(TestConcat, concat_test_s8, ::testing::Values(
295 concat_test_params{engine::kind::cpu, 1,
296 {memory::format::nhwc, memory::format::nhwc}, memory::format::nhwc,
297 {{2, 8, 3, 4}, {2, 8, 3, 4}}, {2, 16, 3, 4}},
298 concat_test_params{engine::kind::cpu, 1,
299 {memory::format::nchw, memory::format::nchw}, memory::format::nchw,
300 {{2, 8, 3, 4}, {2, 8, 3, 4}}, {2, 16, 3, 4}}