inference-engine/tests/unit/engines/mkldnn/graph/layers/internal/graph_concat_test.cpp
// Copyright (C) 2018-2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>
#include <gmock/gmock-spec-builders.h>
#include "mkldnn_plugin/mkldnn_graph.h"

#include "test_graph.hpp"

#include "single_layer_common.hpp"
#include <mkldnn_plugin/mkldnn_extension_utils.h>
#include <unordered_set>
#include <inference_engine/cnn_network_impl.hpp>
#include "tests_common.hpp"

using namespace ::testing;
using namespace std;
using namespace mkldnn;

struct concat_test_params {
    // Input dims; formats: NCHW, NCDHW
    vector<size_t> in1;
    vector<size_t> in2;

    // Axis along which the two inputs are concatenated
    size_t axis;

    // Expected number of supported primitive descriptors for the Concat node
    size_t num_prim_desc;

    // Expected implementation type of the selected primitive descriptor
    MKLDNNPlugin::impl_desc_type selectedType;

    // Optional per-descriptor checks applied to the supported primitive descriptors
    std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
};

class MKLDNNGraphConcatTests: public TestsCommon,
                              public WithParamInterface<concat_test_params> {
    std::string model_t = R"V0G0N(
<net name="ConcatOnly" version="3" precision="FP32" batch="1">
    <layers>
        <layer name="in1" type="Input" precision="FP32" id="1">
            <output>
                <port id="1">__SRC_DIMS_1__
                </port>
            </output>
        </layer>
        <layer name="in2" type="Input" precision="FP32" id="2">
            <output>
                <port id="2">__SRC_DIMS_2__
                </port>
            </output>
        </layer>
        <layer name="con" id="3" type="Concat" precision="FP32">
            <concat_data axis="_AXIS_"/>
            <input>
                <port id="1">__SRC_DIMS_1__
                </port>
                <port id="2">__SRC_DIMS_2__
                </port>
            </input>
            <output>
                <port id="3">__DST_DIMS__
                </port>
            </output>
        </layer>
    </layers>
    <edges>
        <edge from-layer="1" from-port="1" to-layer="3" to-port="1"/>
        <edge from-layer="2" from-port="2" to-layer="3" to-port="2"/>
    </edges>
</net>
)V0G0N";

    std::string getModel(concat_test_params p) {
        std::string model = model_t;
        std::string s_dims;
        for (auto& dim : p.in1) {
            s_dims += "\n                    <dim>";
            s_dims += std::to_string(dim) + "</dim>";
        }
        REPLACE_WITH_STR(model, "__SRC_DIMS_1__", s_dims);

        s_dims = "";
        for (auto& dim : p.in2) {
            s_dims += "\n                    <dim>";
            s_dims += std::to_string(dim) + "</dim>";
        }
        REPLACE_WITH_STR(model, "__SRC_DIMS_2__", s_dims);

        s_dims = "";
        for (size_t i = 0; i < p.in1.size(); i++) {
            size_t dim = p.axis == i ? p.in1[i] + p.in2[i] : p.in1[i];
            s_dims += "\n                    <dim>";
            s_dims += std::to_string(dim) + "</dim>";
        }
        REPLACE_WITH_STR(model, "__DST_DIMS__", s_dims);

        REPLACE_WITH_NUM(model, "_AXIS_", p.axis);
        return model;
    }
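    // For example, with in1 = {1, 3, 3, 5}, in2 = {1, 3, 3, 5} and axis = 1,
    // the placeholders expand to two 1x3x3x5 input ports and a 1x6x3x5 Concat
    // output port: only the axis dimension is summed, the rest are taken from in1.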

protected:
    virtual void TearDown() {
    }

    virtual void SetUp() {
        try {
            TestsCommon::SetUp();
            concat_test_params p = ::testing::WithParamInterface<concat_test_params>::GetParam();
            std::string model = getModel(p);

            InferenceEngine::CNNNetReader net_reader;
            ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));

            MKLDNNGraphTestClass graph;
            graph.CreateGraph(net_reader.getNetwork());
            auto& nodes = graph.getNodes();
            for (int i = 0; i < nodes.size(); i++) {
                if (nodes[i]->getType() == MKLDNNPlugin::Concatenation) {
                    ASSERT_EQ(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
                    for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
                        p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
                    }
                    ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
                    ASSERT_EQ(p.selectedType, nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType());
                }
            }
            ASSERT_LE(3, nodes.size());

            InferenceEngine::SizeVector dims_src1 = p.in1;
            InferenceEngine::SizeVector dims_src2 = p.in2;
            InferenceEngine::Layout layout = InferenceEngine::ANY;
            switch (p.in1.size()) {
                case 4:
                    layout = InferenceEngine::NCHW;
                    break;
                case 5:
                    layout = InferenceEngine::NCDHW;
                    break;
            }

            InferenceEngine::Blob::Ptr src1 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, layout, dims_src1);
            src1->allocate();

            fill_data(src1->buffer(), src1->size());
            InferenceEngine::Blob::Ptr src2 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, layout, dims_src2);
            src2->allocate();
            fill_data(src2->buffer(), src2->size());
            InferenceEngine::BlobMap srcs;
            srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src1));
            srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in2", src2));

            InferenceEngine::OutputsDataMap out;
            out = net_reader.getNetwork().getOutputsInfo();
            InferenceEngine::BlobMap outputBlobs;

            std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();

            InferenceEngine::TBlob<float>::Ptr output;
            output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
            output->allocate();
            outputBlobs[item.first] = output;

            graph.Infer(srcs, outputBlobs);

            // Compare
            float *src1_ptr = src1->buffer();
            size_t src1_size = src1->size();
            float *src2_ptr = src2->buffer();
            size_t src2_size = src2->size();
            float *dst_ptr = output->buffer();
            size_t dst_size = output->size();

            int len1 = 1, len2 = 1, cycles;
            for (int dim = p.axis; dim < output->dims().size(); dim++) {
                len1 *= src1->dims()[dim];
                len2 *= src2->dims()[dim];
            }
            cycles = p.axis;

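            // Reference check: a concat along `axis` lays the output out as repeated
            // blocks of len1 contiguous elements from src1 followed by len2 contiguous
            // elements from src2, where len1/len2 are the products of the dims from the
            // concat axis to the innermost one. The loop below verifies the first
            // `cycles` such blocks element by element.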
            int index1 = 0, index2 = 0, index = 0;
            for (int cycle = 0; cycle < cycles; cycle ++) {
                for (int i1 = 0; i1 < len1; i1++) {
                    if (src1_ptr[index1] != dst_ptr[index])
                    {
                        FAIL() << "index: " << index << " src: " << src1_ptr[index1] << ", dst: " << dst_ptr[index];
                    }
                    index1++; index++;
                }
                for (int i2 = 0; i2 < len2; i2++) {
                    if (src2_ptr[index2] != dst_ptr[index])
                    {
                        FAIL() << "index: " << index << " src: " << src2_ptr[index2] << ", dst: " << dst_ptr[index];
                    }
                    index2++; index++;
                }
            }
        } catch (const InferenceEngine::details::InferenceEngineException &e) {
            FAIL() << e.what();
        }
    }
};

TEST_P(MKLDNNGraphConcatTests, TestsConcat) {}

INSTANTIATE_TEST_CASE_P(
        TestsConcat, MKLDNNGraphConcatTests,
        ::testing::Values(
                concat_test_params {
                        {1, 3, 3, 5},
                        {1, 3, 3, 5},
                        1, 2
                },
                concat_test_params {
                        {1, 7, 1, 5},
                        {1, 7, 9, 5},
                        2, 1, MKLDNNPlugin::impl_desc_type::ref
                },
                concat_test_params {
                        {1, 2, 3, 5, 3},
                        {1, 5, 3, 5, 3},
                        1, 2
                },
                concat_test_params {
                        {1, 32, 3, 4, 5},
                        {1, 32, 3, 4, 5},
                        1, 6, MKLDNNPlugin::impl_desc_type::unknown
                },
                concat_test_params {
                        {1, 64, 16, 16, 16, 1},
                        {1, 64, 16, 16, 16, 1},
                        5, 1, MKLDNNPlugin::impl_desc_type::ref
                }));

class MKLDNNGraphDynBatchConcatTests: public TestsCommon, public WithParamInterface<concat_test_params> {
    std::string model_t = R"V0G0N(
<net name="ConcatOnly" version="2" precision="FP32" batch="1">
    <layers>
        <layer name="in1" type="Input" precision="FP32" id="1">
            <output>
                <port id="1">
                    <dim>1</dim>__SRC_DIMS_1__
                </port>
            </output>
        </layer>
        <layer name="in2" type="Input" precision="FP32" id="2">
            <output>
                <port id="2">
                    <dim>1</dim>__SRC_DIMS_2__
                </port>
            </output>
        </layer>
        <layer name="con" id="3" type="Concat" precision="FP32">
            <concat_data axis="_AXIS_"/>
            <input>
                <port id="1">
                    <dim>1</dim>__SRC_DIMS_1__
                </port>
                <port id="2">
                    <dim>1</dim>__SRC_DIMS_2__
                </port>
            </input>
            <output>
                <port id="3">
                    <dim>1</dim>__DST_DIMS__
                </port>
            </output>
        </layer>
    </layers>
    <edges>
        <edge from-layer="1" from-port="1" to-layer="3" to-port="1"/>
        <edge from-layer="2" from-port="2" to-layer="3" to-port="2"/>
    </edges>
</net>
)V0G0N";

    std::string getModel(concat_test_params p) {
        std::string model = model_t;
        std::string s_dims;
        for (size_t i = 1; i < p.in1.size(); i++) {
            s_dims += "\n                    <dim>";
            s_dims += std::to_string(p.in1[i]) + "</dim>";
        }
        REPLACE_WITH_STR(model, "__SRC_DIMS_1__", s_dims);

        s_dims = "";
        for (size_t i = 1; i < p.in2.size(); i++) {
            s_dims += "\n                    <dim>";
            s_dims += std::to_string(p.in2[i]) + "</dim>";
        }
        REPLACE_WITH_STR(model, "__SRC_DIMS_2__", s_dims);

        s_dims = "";
        for (size_t i = 1; i < p.in1.size(); i++) {
            size_t dim = p.axis == i ? p.in1[i] + p.in2[i] : p.in1[i];
            s_dims += "\n                    <dim>";
            s_dims += std::to_string(dim) + "</dim>";
        }
        REPLACE_WITH_STR(model, "__DST_DIMS__", s_dims);

        REPLACE_WITH_NUM(model, "_AXIS_", p.axis);
        return model;
    }

protected:
    virtual void TearDown() {
    }

    virtual void SetUp() {
        try {
            TestsCommon::SetUp();
            concat_test_params p = ::testing::WithParamInterface<concat_test_params>::GetParam();
            std::string model = getModel(p);
            size_t MB = p.in1[0];
            if (MB < 2)
                MB = 2;

            InferenceEngine::CNNNetReader net_reader;
            ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
            InferenceEngine::CNNNetwork network = net_reader.getNetwork();
            auto implNet = dynamic_cast<InferenceEngine::details::CNNNetworkImpl *>(&((InferenceEngine::ICNNNetwork&)network));
            ASSERT_NE(nullptr, implNet) << "Failed to cast ICNNNetwork to CNNNetworkImpl";
            InferenceEngine::ResponseDesc resp;
            InferenceEngine::StatusCode sts = implNet->setBatchSizeReshape(MB, &resp);
            ASSERT_EQ((int)InferenceEngine::StatusCode::OK, sts) << resp.msg;

            MKLDNNGraphTestClass graph;
            graph.setProperty({{InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_ENABLED, InferenceEngine::PluginConfigParams::YES}});
            graph.CreateGraph(net_reader.getNetwork());

            InferenceEngine::SizeVector dims_src1 = p.in1;
            InferenceEngine::SizeVector dims_src2 = p.in2;
            InferenceEngine::Layout layout = InferenceEngine::ANY;
            switch (p.in1.size()) {
                case 4:
                    layout = InferenceEngine::NCHW;
                    break;
                case 5:
                    layout = InferenceEngine::NCDHW;
                    break;
            }

            InferenceEngine::Blob::Ptr src1 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, layout, dims_src1);
            src1->allocate();

            fill_data(src1->buffer(), src1->size());
            InferenceEngine::Blob::Ptr src2 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, layout, dims_src2);
            src2->allocate();
            fill_data(src2->buffer(), src2->size());
            InferenceEngine::BlobMap srcs;
            srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src1));
            srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in2", src2));

            InferenceEngine::OutputsDataMap out;
            out = net_reader.getNetwork().getOutputsInfo();
            InferenceEngine::BlobMap outputBlobs;

            std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();

            InferenceEngine::TBlob<float>::Ptr output;
            output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
            output->allocate();
            outputBlobs[item.first] = output;

            auto checkConcat = [](const MKLDNNPlugin::MKLDNNNodePtr& node) {
                return node->getType() == MKLDNNPlugin::Concatenation;
            };

            MKLDNNGraphTestClass::CheckDynBatchType checkType = MKLDNNGraphTestClass::CheckDynBatchType::Both;
            if (p.selectedType == MKLDNNPlugin::impl_desc_type::unknown)
                checkType = MKLDNNGraphTestClass::CheckDynBatchType::Child;

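            // checkDynBatch (provided by the MKLDNNGraphTestClass helper from test_graph.hpp)
            // is expected to re-run inference with the requested dynamic batch and verify
            // either the Concatenation node itself or only its child, depending on checkType.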
            graph.checkDynBatch(srcs, outputBlobs, MB, MB, checkConcat, checkType);
            graph.checkDynBatch(srcs, outputBlobs, 1, MB, checkConcat, checkType);
        } catch (const InferenceEngine::details::InferenceEngineException &e) {
            FAIL() << e.what();
        }
    }
};

TEST_P(MKLDNNGraphDynBatchConcatTests, TestsDynBatchConcat) {}

INSTANTIATE_TEST_CASE_P(
        TestsDynBatchConcat, MKLDNNGraphDynBatchConcatTests,
        ::testing::Values(
                concat_test_params {
                        {1, 7, 2, 5},
                        {1, 7, 2, 5},
                        2, 1, MKLDNNPlugin::impl_desc_type::ref
                },
                concat_test_params {
                        {1, 7, 2, 5},
                        {1, 13, 2, 5},
                        1, 2, MKLDNNPlugin::impl_desc_type::unknown
                },
                concat_test_params {
                        {3, 7, 2, 5},
                        {3, 13, 2, 5},
                        1, 2, MKLDNNPlugin::impl_desc_type::unknown
                },
                concat_test_params {
                        {1, 7, 2, 13},
                        {1, 7, 2, 17},
                        3, 1, MKLDNNPlugin::impl_desc_type::ref
                },
                concat_test_params {
                        {1, 8, 8, 16},
                        {1, 16, 8, 16},
                        1, 4, MKLDNNPlugin::impl_desc_type::unknown
                },
                concat_test_params {
                        {2, 2, 3, 3},
                        {2, 3, 3, 3},
                        1, 2, MKLDNNPlugin::impl_desc_type::unknown
                },
                concat_test_params {
                        {2, 2, 3, 3, 3},
                        {2, 3, 3, 3, 3},
                        1, 2, MKLDNNPlugin::impl_desc_type::unknown
                }));

struct concat_param {
    std::string name;
    size_t axis;
    size_t input1;
    size_t input2;
};

struct two_concat_test_params {
    // Formats: NCHW, NCDHW
    vector<size_t> in1;
    vector<size_t> in2;
    vector<size_t> in3;

    concat_param concat1;
    concat_param concat2;
};

class MKLDNNGraphTwoConcatTests: public TestsCommon,
                                 public WithParamInterface<two_concat_test_params> {
    std::string model_t = R"V0G0N(
<net name="TwoConcatsDiffFwd" version="2" precision="FP32" batch="1">
    <layers>
        <layer name="in1" type="Input" precision="FP32" id="1">
            <output>
                <port id="1">__SRC_DIMS_1__
                </port>
            </output>
        </layer>
        <layer name="in2" type="Input" precision="FP32" id="2">
            <output>
                <port id="1">__SRC_DIMS_2__
                </port>
            </output>
        </layer>
        <layer name="in3" type="Input" precision="FP32" id="3">
            <output>
                <port id="1">__SRC_DIMS_3__
                </port>
            </output>
        </layer>
        <layer name="_CONCAT1_NAME_" id="4" type="Concat" precision="FP32">
            <concat_data axis="_CONCAT1_AXIS_"/>
            <input>
                <port id="1">
                    <dim>_CI41N_</dim>
                    <dim>_CI41C_</dim>
                    <dim>_CI41D_</dim>
                    <dim>_CI41H_</dim>
                    <dim>_CI41W_</dim>
                </port>
                <port id="2">
                    <dim>_CI42N_</dim>
                    <dim>_CI42C_</dim>
                    <dim>_CI42D_</dim>
                    <dim>_CI42H_</dim>
                    <dim>_CI42W_</dim>
                </port>
            </input>
            <output>
                <port id="3">__CO_DIMS_1__
                </port>
            </output>
        </layer>
        <layer name="_CONCAT2_NAME_" id="5" type="Concat" precision="FP32">
            <concat_data axis="_CONCAT2_AXIS_"/>
            <input>
                <port id="1">
                    <dim>_CI51N_</dim>
                    <dim>_CI51C_</dim>
                    <dim>_CI51D_</dim>
                    <dim>_CI51H_</dim>
                    <dim>_CI51W_</dim>
                </port>
                <port id="2">
                    <dim>_CI52N_</dim>
                    <dim>_CI52C_</dim>
                    <dim>_CI52D_</dim>
                    <dim>_CI52H_</dim>
                    <dim>_CI52W_</dim>
                </port>
            </input>
            <output>
                <port id="3">__CO_DIMS_2__
                </port>
            </output>
        </layer>
    </layers>
    <edges>
        <edge from-layer="1" from-port="1" to-layer="_FL11_" to-port="_FP11_"/>
        <edge from-layer="2" from-port="1" to-layer="_FL21_" to-port="_FP21_"/>
        <edge from-layer="3" from-port="1" to-layer="_FL31_" to-port="_FP31_"/>
        <edge from-layer="_FSL_" from-port="_FSP_" to-layer="_FSLTL_" to-port="_FSLTP_"/>
    </edges>
</net>
)V0G0N";
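    // Wires one of the _FLxy_/_FPxy_ edge placeholders (or the spare _FSL_/_FSP_ edge)
    // to the given target layer/port and fills in that Concat input port's _CI..._ dim
    // placeholders from `dims`; the depth dim line is dropped for 4D inputs.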
    void changeEdgeToLayer(std::string& model, int f_l, int f_p, int t_l, int t_p, vector<size_t> dims) {
        std::string TL = "_FL" + std::to_string(f_l) + std::to_string(f_p) + "_";
        std::string TP = "_FP" + std::to_string(f_l) + std::to_string(f_p) + "_";
        if (!FIND_STR(model, TL) || !FIND_STR(model, TP)) {
            if (!FIND_STR(model, "_FSL_") || !FIND_STR(model, "_FSP_") ||
                    !FIND_STR(model, "_FSLTL_") || !FIND_STR(model, "_FSLTP_")) {
                THROW_IE_EXCEPTION << "Incorrect configuration!";
            }
            REPLACE_WITH_NUM(model, "_FSL_", f_l);
            REPLACE_WITH_NUM(model, "_FSP_", f_p);
            REPLACE_WITH_NUM(model, "_FSLTL_", t_l);
            REPLACE_WITH_NUM(model, "_FSLTP_", t_p);
        } else {
            REPLACE_WITH_NUM(model, TL, t_l);
            REPLACE_WITH_NUM(model, TP, t_p);
        }

        std::string CI = "_CI" + std::to_string(t_l) + std::to_string(t_p);
        auto dims_size = dims.size();
        REPLACE_WITH_NUM(model, CI + "N_", dims[0]);
        REPLACE_WITH_NUM(model, CI + "C_", dims[1]);
        REPLACE_WITH_NUM(model, CI + "H_", dims[dims_size - 2]);
        REPLACE_WITH_NUM(model, CI + "W_", dims[dims_size - 1]);
        if (dims_size < 5) REMOVE_LINE(model, std::string("<dim>") + CI + std::string("D_") + "</dim>");
        else REPLACE_WITH_NUM(model, CI + "D_", dims[dims_size - 3]);
    }

    std::string getModel(two_concat_test_params p) {
        std::string model = model_t;
        std::string s_dims;
        for (size_t i = 0; i < p.in1.size(); i++) {
            s_dims += "\n                    <dim>";
            s_dims += std::to_string(p.in1[i]) + "</dim>";
        }
        REPLACE_WITH_STR(model, "__SRC_DIMS_1__", s_dims);

        s_dims = "";
        for (size_t i = 0; i < p.in2.size(); i++) {
            s_dims += "\n                    <dim>";
            s_dims += std::to_string(p.in2[i]) + "</dim>";
        }
        REPLACE_WITH_STR(model, "__SRC_DIMS_2__", s_dims);

        s_dims = "";
        for (size_t i = 0; i < p.in3.size(); i++) {
            s_dims += "\n                    <dim>";
            s_dims += std::to_string(p.in3[i]) + "</dim>";
        }
        REPLACE_WITH_STR(model, "__SRC_DIMS_3__", s_dims);

        vector<size_t> concat11;
        switch (p.concat1.input1) {
            case 1:
                changeEdgeToLayer(model, 2, 1, 4, 1, p.in2);
                concat11 = p.in2;
                break;
            case 2:
                changeEdgeToLayer(model, 3, 1, 4, 1, p.in3);
                concat11 = p.in3;
                break;
            default:
                changeEdgeToLayer(model, 1, 1, 4, 1, p.in1);
                concat11 = p.in1;
        }

        vector<size_t> concat12;
        switch (p.concat1.input2) {
            case 1:
                changeEdgeToLayer(model, 2, 1, 4, 2, p.in2);
                concat12 = p.in2;
                break;
            case 2:
                changeEdgeToLayer(model, 3, 1, 4, 2, p.in3);
                concat12 = p.in3;
                break;
            default:
                changeEdgeToLayer(model, 1, 1, 4, 2, p.in1);
                concat12 = p.in1;
        }

        vector<size_t> concat21;
        switch (p.concat2.input1) {
            case 1:
                changeEdgeToLayer(model, 2, 1, 5, 1, p.in2);
                concat21 = p.in2;
                break;
            case 2:
                changeEdgeToLayer(model, 3, 1, 5, 1, p.in3);
                concat21 = p.in3;
                break;
            default:
                changeEdgeToLayer(model, 1, 1, 5, 1, p.in1);
                concat21 = p.in1;
        }

        vector<size_t> concat22;
        switch (p.concat2.input2) {
            case 1:
                changeEdgeToLayer(model, 2, 1, 5, 2, p.in2);
                concat22 = p.in2;
                break;
            case 2:
                changeEdgeToLayer(model, 3, 1, 5, 2, p.in3);
                concat22 = p.in3;
                break;
            default:
                changeEdgeToLayer(model, 1, 1, 5, 2, p.in1);
                concat22 = p.in1;
        }

        s_dims = "";
        for (size_t i = 0; i < concat11.size(); i++) {
            size_t concat = p.concat1.axis == i ? concat11[i] + concat12[i] : concat11[i];
            s_dims += "\n                    <dim>";
            s_dims += std::to_string(concat) + "</dim>";
        }
        REPLACE_WITH_STR(model, "__CO_DIMS_1__", s_dims);

        REPLACE_WITH_NUM(model, "_CONCAT1_AXIS_", p.concat1.axis);
        REPLACE_WITH_STR(model, "_CONCAT1_NAME_", p.concat1.name);

        s_dims = "";
        for (size_t i = 0; i < concat21.size(); i++) {
            size_t concat = p.concat2.axis == i ? concat21[i] + concat22[i] : concat21[i];
            s_dims += "\n                    <dim>";
            s_dims += std::to_string(concat) + "</dim>";
        }
        REPLACE_WITH_STR(model, "__CO_DIMS_2__", s_dims);

        REPLACE_WITH_NUM(model, "_CONCAT2_AXIS_", p.concat2.axis);
        REPLACE_WITH_STR(model, "_CONCAT2_NAME_", p.concat2.name);
        return model;
    }

protected:
    virtual void TearDown() {
    }

    virtual void SetUp() {
        try {
            TestsCommon::SetUp();
            two_concat_test_params p = ::testing::WithParamInterface<two_concat_test_params>::GetParam();
            std::string model = getModel(p);

            InferenceEngine::CNNNetReader net_reader;
            ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));

            MKLDNNGraphTestClass graph;
            graph.CreateGraph(net_reader.getNetwork());

            InferenceEngine::SizeVector dims_src1 = p.in1;
            InferenceEngine::SizeVector dims_src2 = p.in2;
            InferenceEngine::SizeVector dims_src3 = p.in3;
            InferenceEngine::Layout layout = InferenceEngine::ANY;
            switch (p.in1.size()) {
                case 4:
                    layout = InferenceEngine::NCHW;
                    break;
                case 5:
                    layout = InferenceEngine::NCDHW;
                    break;
            }

            InferenceEngine::Blob::Ptr src1 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, layout, dims_src1);
            src1->allocate();
            fill_data(src1->buffer(), src1->size());

            InferenceEngine::Blob::Ptr src2 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, layout, dims_src2);
            src2->allocate();
            fill_data(src2->buffer(), src2->size());

            InferenceEngine::Blob::Ptr src3 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, layout, dims_src3);
            src3->allocate();
            fill_data(src3->buffer(), src3->size());

            InferenceEngine::BlobMap srcs;
            srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src1));
            srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in2", src2));
            srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in3", src3));

            InferenceEngine::OutputsDataMap out;
            out = net_reader.getNetwork().getOutputsInfo();
            InferenceEngine::BlobMap outputBlobs;

            for (auto it = out.begin(); it != out.end(); it++) {
                std::pair<std::string, InferenceEngine::DataPtr> item = *it;
                InferenceEngine::TBlob<float>::Ptr output;
                output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
                output->allocate();
                outputBlobs[item.first] = output;
            }

            graph.Infer(srcs, outputBlobs);

            for (auto concat : {p.concat1, p.concat2}) {
                float *src1_ptr;
                size_t src1_size;
                float *src2_ptr;
                size_t src2_size;
                InferenceEngine::Blob::Ptr src1_c;
                InferenceEngine::Blob::Ptr src2_c;

                switch (concat.input1) {
                    case 1:
                        src1_ptr = src2->buffer();
                        src1_size = src2->size();
                        src1_c = src2;
                        break;
                    case 2:
                        src1_ptr = src3->buffer();
                        src1_size = src3->size();
                        src1_c = src3;
                        break;
                    default:
                        src1_ptr = src1->buffer();
                        src1_size = src1->size();
                        src1_c = src1;
                }

                switch (concat.input2) {
                    case 1:
                        src2_ptr = src2->buffer();
                        src2_size = src2->size();
                        src2_c = src2;
                        break;
                    case 2:
                        src2_ptr = src3->buffer();
                        src2_size = src3->size();
                        src2_c = src3;
                        break;
                    default:
                        src2_ptr = src1->buffer();
                        src2_size = src1->size();
                        src2_c = src1;
                }

                float *dst_ptr = outputBlobs[concat.name]->buffer();
                size_t dst_size = outputBlobs[concat.name]->size();

                int len1 = 1, len2 = 1, cycles;
                for (int dim = concat.axis; dim < outputBlobs[concat.name]->dims().size(); dim++) {
                    len1 *= src1_c->dims()[dim];
                    len2 *= src2_c->dims()[dim];
                }
                cycles = concat.axis;

                int index1 = 0, index2 = 0, index = 0;
                for (int cycle = 0; cycle < cycles; cycle ++) {
                    for (int i1 = 0; i1 < len1; i1++) {
                        if (src1_ptr[index1] != dst_ptr[index])
                        {
                            FAIL() << concat.name << " index: " << index << " src: "
                                   << src1_ptr[index1] << ", dst: " << dst_ptr[index];
                        }
                        index1++; index++;
                    }
                    for (int i2 = 0; i2 < len2; i2++) {
                        if (src2_ptr[index2] != dst_ptr[index])
                        {
                            FAIL() << concat.name << " index: " << index << " src: "
                                   << src2_ptr[index2] << ", dst: " << dst_ptr[index];
                        }
                        index2++; index++;
                    }
                }
            }
        } catch (const InferenceEngine::details::InferenceEngineException &e) {
            FAIL() << e.what();
        }
    }
};

TEST_P(MKLDNNGraphTwoConcatTests, TestsTwoConcat) {}

INSTANTIATE_TEST_CASE_P(
        TestsTwoConcat, MKLDNNGraphTwoConcatTests,
        ::testing::Values(
                two_concat_test_params {
                        {1, 5, 2, 5},
                        {3, 5, 2, 5},
                        {1, 5, 2, 5},
                        {"concat1", 0, 0, 1},
                        {"concat2", 0, 1, 2}
                },
                two_concat_test_params {
                        {1, 2, 2, 5},
                        {1, 5, 2, 5},
                        {3, 5, 2, 5},
                        {"concat1", 1, 0, 1},
                        {"concat2", 0, 1, 2}
                },
                two_concat_test_params {
                        {1, 2, 2, 2},
                        {1, 1, 2, 2},
                        {1, 3, 2, 2},
                        {"concat1", 1, 0, 1},
                        {"concat2", 1, 1, 2}
                },
                two_concat_test_params {
                        {1, 5, 2, 5},
                        {3, 5, 2, 5},
                        {1, 5, 2, 5},
                        {"concat1", 0, 0, 1},
                        {"concat2", 0, 2, 1}
                },
                two_concat_test_params {
                        {1, 2, 2, 5},
                        {1, 5, 2, 5},
                        {3, 5, 2, 5},
                        {"concat1", 1, 0, 1},
                        {"concat2", 0, 2, 1}
                },
                two_concat_test_params {
                        {1, 2, 2, 2},
                        {1, 1, 2, 2},
                        {1, 3, 2, 2},
                        {"concat1", 1, 0, 1},
                        {"concat2", 1, 2, 1}
                }));

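// In the model below, in1 feeds the ReLU and Power layers as well as Concat port 2,
// while in2 feeds Concat port 1, so the test covers a Concat whose inputs arrive in
// the opposite order to the declared input layers.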
class MKLDNNGraphTwoInputInConcatTests: public TestsCommon {
    std::string model_t = R"V0G0N(
<net name="TwoConcatsDiffFwd" version="2" precision="FP32" batch="1">
    <layers>
        <layer name="in1" type="Input" precision="FP32" id="1">
            <output>
                <port id="1">
                    <dim>1</dim>
                    <dim>3</dim>
                    <dim>2</dim>
                    <dim>2</dim>
                </port>
            </output>
        </layer>
        <layer name="in2" type="Input" precision="FP32" id="2">
            <output>
                <port id="1">
                    <dim>1</dim>
                    <dim>2</dim>
                    <dim>2</dim>
                    <dim>2</dim>
                </port>
            </output>
        </layer>
        <layer name="norm" id="3" type="ReLU" precision="FP32">
            <input>
                <port id="1">
                    <dim>1</dim>
                    <dim>3</dim>
                    <dim>2</dim>
                    <dim>2</dim>
                </port>
            </input>
            <output>
                <port id="2">
                    <dim>1</dim>
                    <dim>3</dim>
                    <dim>2</dim>
                    <dim>2</dim>
                </port>
            </output>
        </layer>
        <layer name="power" id="4" type="Power" precision="FP32">
            <power_data power="-1" scale="-1" shift="0"/>
            <input>
                <port id="1">
                    <dim>1</dim>
                    <dim>3</dim>
                    <dim>2</dim>
                    <dim>2</dim>
                </port>
            </input>
            <output>
                <port id="2">
                    <dim>1</dim>
                    <dim>3</dim>
                    <dim>2</dim>
                    <dim>2</dim>
                </port>
            </output>
        </layer>
        <layer name="o_concat" id="5" type="Concat" precision="FP32">
            <concat_data axis="1"/>
            <input>
                <port id="1">
                    <dim>1</dim>
                    <dim>2</dim>
                    <dim>2</dim>
                    <dim>2</dim>
                </port>
                <port id="2">
                    <dim>1</dim>
                    <dim>3</dim>
                    <dim>2</dim>
                    <dim>2</dim>
                </port>
            </input>
            <output>
                <port id="3">
                    <dim>1</dim>
                    <dim>5</dim>
                    <dim>2</dim>
                    <dim>2</dim>
                </port>
            </output>
        </layer>
    </layers>
    <edges>
        <edge from-layer="1" from-port="1" to-layer="3" to-port="1"/>
        <edge from-layer="1" from-port="1" to-layer="5" to-port="2"/>
        <edge from-layer="1" from-port="1" to-layer="4" to-port="1"/>
        <edge from-layer="2" from-port="1" to-layer="5" to-port="1"/>
    </edges>
</net>
)V0G0N";

protected:
    virtual void TearDown() {
    }

    virtual void SetUp() {
        try {
            TestsCommon::SetUp();
            std::string model = model_t;

            InferenceEngine::CNNNetReader net_reader;
            ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));

            MKLDNNGraphTestClass graph;
            graph.CreateGraph(net_reader.getNetwork());

            InferenceEngine::SizeVector dims_src1 = {1, 3, 2, 2};
            InferenceEngine::SizeVector dims_src2 = {1, 2, 2, 2};

            InferenceEngine::Blob::Ptr src1 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src1);
            src1->allocate();
            float *src1_data = src1->buffer();
            for (size_t i = 0; i < src1->size(); i++) {
                src1_data[i] = i + 1;
            }

            InferenceEngine::Blob::Ptr src2 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src2);
            src2->allocate();
            fill_data(src2->buffer(), src2->size());

            InferenceEngine::BlobMap srcs;
            srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src1));
            srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in2", src2));

            InferenceEngine::OutputsDataMap out;
            out = net_reader.getNetwork().getOutputsInfo();
            InferenceEngine::BlobMap outputBlobs;

            for (auto it = out.begin(); it != out.end(); it++) {
                std::pair<std::string, InferenceEngine::DataPtr> item = *it;
                InferenceEngine::TBlob<float>::Ptr output;
                output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
                output->allocate();
                outputBlobs[item.first] = output;
            }

            graph.Infer(srcs, outputBlobs);

            float *src1_ptr = src2->buffer();
            size_t src1_size = src2->size();
            float *src2_ptr = src1->buffer();
            size_t src2_size = src1->size();

            float *dst_ptr = outputBlobs["o_concat"]->buffer();
            size_t dst_size = outputBlobs["o_concat"]->size();

            int len1 = 1, len2 = 1, cycles;
            for (int dim = 1; dim < outputBlobs["o_concat"]->dims().size(); dim++) {
                len1 *= src2->dims()[dim];
                len2 *= src1->dims()[dim];
            }
            cycles = 1;

            int index1 = 0, index2 = 0, index = 0;
            for (int cycle = 0; cycle < cycles; cycle ++) {
                for (int i1 = 0; i1 < len1; i1++) {
                    if (src1_ptr[index1] != dst_ptr[index])
                    {
                        FAIL() << "concat index: " << index << " src: "
                               << src1_ptr[index1] << ", dst: " << dst_ptr[index];
                    }
                    index1++; index++;
                }
                for (int i2 = 0; i2 < len2; i2++) {
                    if (src2_ptr[index2] != dst_ptr[index])
                    {
                        FAIL() << "concat index: " << index << " src: "
                               << src2_ptr[index2] << ", dst: " << dst_ptr[index];
                    }
                    index2++; index++;
                }
            }
        } catch (const InferenceEngine::details::InferenceEngineException &e) {
            FAIL() << e.what();
        }
    }
};

TEST_F(MKLDNNGraphTwoInputInConcatTests, TestSecondInputToConcat) {}

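// Negative cases: the two Concat inputs disagree on a non-concat dimension, so
// reading the network is expected to fail with an InferenceEngineException.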
class MKLDNNGraphIncorrectConcatTests: public TestsCommon,
                              public WithParamInterface<concat_test_params> {
    std::string model_t = R"V0G0N(
<net name="ConcatOnly" version="2" precision="FP32" batch="1">
    <layers>
        <layer name="in1" type="Input" precision="FP32" id="1">
            <output>
                <port id="1">__SRC_DIMS_1__
                </port>
            </output>
        </layer>
        <layer name="in2" type="Input" precision="FP32" id="2">
            <output>
                <port id="2">__SRC_DIMS_2__
                </port>
            </output>
        </layer>
        <layer name="con" id="3" type="Concat" precision="FP32">
            <concat_data axis="_AXIS_"/>
            <input>
                <port id="1">__SRC_DIMS_1__
                </port>
                <port id="2">__SRC_DIMS_2__
                </port>
            </input>
            <output>
                <port id="3">__DST_DIMS__
                </port>
            </output>
        </layer>
    </layers>
    <edges>
        <edge from-layer="1" from-port="1" to-layer="3" to-port="1"/>
        <edge from-layer="2" from-port="2" to-layer="3" to-port="2"/>
    </edges>
</net>
)V0G0N";

    std::string getModel(concat_test_params p) {
        std::string model = model_t;
        std::string s_dims;
        for (auto& dim : p.in1) {
            s_dims += "\n                    <dim>";
            s_dims += std::to_string(dim) + "</dim>";
        }
        REPLACE_WITH_STR(model, "__SRC_DIMS_1__", s_dims);

        s_dims = "";
        for (auto& dim : p.in2) {
            s_dims += "\n                    <dim>";
            s_dims += std::to_string(dim) + "</dim>";
        }
        REPLACE_WITH_STR(model, "__SRC_DIMS_2__", s_dims);

        s_dims = "";
        for (size_t i = 0; i < p.in1.size(); i++) {
            size_t dim = p.axis == i ? p.in1[i] + p.in2[i] : p.in1[i];
            s_dims += "\n                    <dim>";
            s_dims += std::to_string(dim) + "</dim>";
        }
        REPLACE_WITH_STR(model, "__DST_DIMS__", s_dims);

        REPLACE_WITH_NUM(model, "_AXIS_", p.axis);
        return model;
    }

protected:
    virtual void TearDown() {
    }

    virtual void SetUp() {
        try {
            TestsCommon::SetUp();
            concat_test_params p = ::testing::WithParamInterface<concat_test_params>::GetParam();
            std::string model = getModel(p);

            InferenceEngine::CNNNetReader net_reader;
            ASSERT_THROW(net_reader.ReadNetwork(model.data(), model.length()),
                         InferenceEngine::details::InferenceEngineException);
        } catch (const InferenceEngine::details::InferenceEngineException &e) {
            FAIL() << e.what();
        }
    }
};

TEST_P(MKLDNNGraphIncorrectConcatTests, TestsIncorrectConcat) {}

INSTANTIATE_TEST_CASE_P(
        TestsIncorrectConcat, MKLDNNGraphIncorrectConcatTests,
        ::testing::Values(
                concat_test_params {
                        {1, 7, 2, 5},
                        {1, 7, 3, 5},
                        1
                },
                concat_test_params {
                        {1, 7, 2, 5},
                        {1, 7, 4, 4},
                        2
                }));