Publishing R3
[platform/upstream/dldt.git] / inference-engine / tests / unit / engines / mkldnn / graph / layers / internal / graph_concat_test.cpp
1 // Copyright (C) 2018 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 //
5
6 #include <gtest/gtest.h>
7 #include <gmock/gmock-spec-builders.h>
8 #include "mkldnn_plugin/mkldnn_graph.h"
9 #include "mock_mkldnn_primitive.hpp"
10
11 #include "test_graph.hpp"
12
13 #include "single_layer_common.hpp"
14 #include <mkldnn_plugin/mkldnn_extension_utils.h>
15 #include <unordered_set>
16 #include <inference_engine/cnn_network_impl.hpp>
17 #include "tests_common.hpp"
18
19 using namespace ::testing;
20 using namespace std;
21 using namespace mkldnn;
22
// A 4-D extent whose fields are used by the tests below in N, C, H, W order
// (they are packed into {n, c, h, w} SizeVectors for NCHW blobs).
struct dim4 {
    size_t n, c, h, w;
};
29
30 struct concat_test_params {
31     dim4 in1;
32
33     dim4 in2;
34
35     size_t axis;
36
37     size_t num_prim_desc;
38
39     MKLDNNPlugin::impl_desc_type selectedType;
40
41     std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
42 };
43
44 class MKLDNNGraphConcatTests: public TestsCommon,
45                               public WithParamInterface<concat_test_params> {
46     std::string model_t = R"V0G0N(
47 <net name="ConcatOnly" version="2" precision="FP32" batch="1">
48     <layers>
49         <layer name="in1" type="Input" precision="FP32" id="1">
50             <output>
51                 <port id="1">
52                     <dim>_IN1_</dim>
53                     <dim>_IC1_</dim>
54                     <dim>_IH1_</dim>
55                     <dim>_IW1_</dim>
56                 </port>
57             </output>
58         </layer>
59         <layer name="in2" type="Input" precision="FP32" id="2">
60             <output>
61                 <port id="2">
62                     <dim>_IN2_</dim>
63                     <dim>_IC2_</dim>
64                     <dim>_IH2_</dim>
65                     <dim>_IW2_</dim>
66                 </port>
67             </output>
68         </layer>
69         <layer name="con" id="3" type="Concat" precision="FP32">
70             <concat_data axis="_AXIS_"/>
71             <input>
72                 <port id="1">
73                     <dim>_IN1_</dim>
74                     <dim>_IC1_</dim>
75                     <dim>_IH1_</dim>
76                     <dim>_IW1_</dim>
77                 </port>
78                 <port id="2">
79                     <dim>_IN2_</dim>
80                     <dim>_IC2_</dim>
81                     <dim>_IH2_</dim>
82                     <dim>_IW2_</dim>
83                 </port>
84             </input>
85             <output>
86                 <port id="3">
87                     <dim>_ON_</dim>
88                     <dim>_OC_</dim>
89                     <dim>_OH_</dim>
90                     <dim>_OW_</dim>
91                 </port>
92             </output>
93         </layer>
94     </layers>
95     <edges>
96         <edge from-layer="1" from-port="1" to-layer="3" to-port="1"/>
97         <edge from-layer="2" from-port="2" to-layer="3" to-port="2"/>
98     </edges>
99 </net>
100 )V0G0N";
101
102     std::string getModel(concat_test_params p) {
103         std::string model = model_t;
104         REPLACE_WITH_NUM(model, "_IN1_", p.in1.n);
105         REPLACE_WITH_NUM(model, "_IC1_", p.in1.c);
106         REPLACE_WITH_NUM(model, "_IW1_", p.in1.w);
107         REPLACE_WITH_NUM(model, "_IH1_", p.in1.h);
108
109         REPLACE_WITH_NUM(model, "_IN2_", p.in2.n);
110         REPLACE_WITH_NUM(model, "_IC2_", p.in2.c);
111         REPLACE_WITH_NUM(model, "_IW2_", p.in2.w);
112         REPLACE_WITH_NUM(model, "_IH2_", p.in2.h);
113
114         REPLACE_WITH_NUM(model, "_ON_", p.axis == 0 ? p.in1.n + p.in2.n : p.in1.n);
115         REPLACE_WITH_NUM(model, "_OC_", p.axis == 1 ? p.in1.c + p.in2.c : p.in1.c);
116         REPLACE_WITH_NUM(model, "_OH_", p.axis == 2 ? p.in1.h + p.in2.h : p.in1.h);
117         REPLACE_WITH_NUM(model, "_OW_", p.axis == 3 ? p.in1.w + p.in2.w : p.in1.w);
118
119         REPLACE_WITH_NUM(model, "_AXIS_", p.axis);
120         return model;
121     }
122
123 protected:
124     virtual void TearDown() {
125     }
126
127     virtual void SetUp() {
128         try {
129             TestsCommon::SetUp();
130             concat_test_params p = ::testing::WithParamInterface<concat_test_params>::GetParam();
131             std::string model = getModel(p);
132
133             InferenceEngine::CNNNetReader net_reader;
134             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
135
136             MKLDNNGraphTestClass graph;
137             graph.CreateGraph(net_reader.getNetwork());
138             auto& nodes = graph.getNodes();
139             for (int i = 0; i < nodes.size(); i++) {
140                 if (nodes[i]->getType() == MKLDNNPlugin::Concatenation) {
141                     ASSERT_EQ(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
142                     for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
143                         p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
144                     }
145                     ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
146                     ASSERT_EQ(p.selectedType, nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType());
147                 }
148             }
149             ASSERT_LE(3, nodes.size());
150
151             InferenceEngine::SizeVector dims_src1 = {p.in1.n, p.in1.c, p.in1.h, p.in1.w};
152             InferenceEngine::SizeVector dims_src2 = {p.in2.n, p.in2.c, p.in2.h, p.in2.w};
153
154             InferenceEngine::Blob::Ptr src1 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src1);
155             src1->allocate();
156
157             fill_data(src1->buffer(), src1->size());
158             InferenceEngine::Blob::Ptr src2 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src2);
159             src2->allocate();
160             fill_data(src2->buffer(), src2->size());
161             InferenceEngine::BlobMap srcs;
162             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src1));
163             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in2", src2));
164
165             InferenceEngine::OutputsDataMap out;
166             out = net_reader.getNetwork().getOutputsInfo();
167             InferenceEngine::BlobMap outputBlobs;
168
169             std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
170
171             InferenceEngine::TBlob<float>::Ptr output;
172             output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
173             output->allocate();
174             outputBlobs[item.first] = output;
175
176             graph.Infer(srcs, outputBlobs);
177
178             // Compare
179             float *src1_ptr = src1->buffer();
180             size_t src1_size = src1->size();
181             float *src2_ptr = src2->buffer();
182             size_t src2_size = src2->size();
183             float *dst_ptr = output->buffer();
184             size_t dst_size = output->size();
185
186             int len1 = 1, len2 = 1, cycles;
187             for (int dim = p.axis; dim < output->dims().size(); dim++) {
188                 len1 *= src1->dims()[dim];
189                 len2 *= src2->dims()[dim];
190             }
191             cycles = p.axis;
192
193
194             int index1 = 0, index2 = 0, index = 0;
195             for (int cycle = 0; cycle < cycles; cycle ++) {
196                 for (int i1 = 0; i1 < len1; i1++) {
197                     if (src1_ptr[index1] != dst_ptr[index])
198                     {
199                         FAIL() << "index: " << index << " src: " << src1_ptr[index1] << ", dst: " << dst_ptr[index];
200                     }
201                     index1++; index++;
202                 }
203                 for (int i2 = 0; i2 < len2; i2++) {
204                     if (src2_ptr[index2] != dst_ptr[index])
205                     {
206                         FAIL() << "index: " << index << " src: " << src2_ptr[index2] << ", dst: " << dst_ptr[index];
207                     }
208                     index2++; index++;
209                 }
210             }
211         } catch (const InferenceEngine::details::InferenceEngineException &e) {
212             FAIL() << e.what();
213         }
214     }
215 };
216
217 TEST_P(MKLDNNGraphConcatTests, TestsConcat) {}
218
219 class MKLDNNGraphDynBatchConcatTests: public TestsCommon, public WithParamInterface<concat_test_params> {
220     std::string model_t = R"V0G0N(
221 <net name="ConcatOnly" version="2" precision="FP32" batch="1">
222     <layers>
223         <layer name="in1" type="Input" precision="FP32" id="1">
224             <output>
225                 <port id="1">
226                     <dim>1</dim>
227                     <dim>_IC1_</dim>
228                     <dim>_IH1_</dim>
229                     <dim>_IW1_</dim>
230                 </port>
231             </output>
232         </layer>
233         <layer name="in2" type="Input" precision="FP32" id="2">
234             <output>
235                 <port id="2">
236                     <dim>1</dim>
237                     <dim>_IC2_</dim>
238                     <dim>_IH2_</dim>
239                     <dim>_IW2_</dim>
240                 </port>
241             </output>
242         </layer>
243         <layer name="con" id="3" type="Concat" precision="FP32">
244             <concat_data axis="_AXIS_"/>
245             <input>
246                 <port id="1">
247                     <dim>1</dim>
248                     <dim>_IC1_</dim>
249                     <dim>_IH1_</dim>
250                     <dim>_IW1_</dim>
251                 </port>
252                 <port id="2">
253                     <dim>1</dim>
254                     <dim>_IC2_</dim>
255                     <dim>_IH2_</dim>
256                     <dim>_IW2_</dim>
257                 </port>
258             </input>
259             <output>
260                 <port id="3">
261                     <dim>1</dim>
262                     <dim>_OC_</dim>
263                     <dim>_OH_</dim>
264                     <dim>_OW_</dim>
265                 </port>
266             </output>
267         </layer>
268     </layers>
269     <edges>
270         <edge from-layer="1" from-port="1" to-layer="3" to-port="1"/>
271         <edge from-layer="2" from-port="2" to-layer="3" to-port="2"/>
272     </edges>
273 </net>
274 )V0G0N";
275
276     std::string getModel(concat_test_params p) {
277         std::string model = model_t;
278         REPLACE_WITH_NUM(model, "_IN1_", p.in1.n);
279         REPLACE_WITH_NUM(model, "_IC1_", p.in1.c);
280         REPLACE_WITH_NUM(model, "_IW1_", p.in1.w);
281         REPLACE_WITH_NUM(model, "_IH1_", p.in1.h);
282
283         REPLACE_WITH_NUM(model, "_IN2_", p.in2.n);
284         REPLACE_WITH_NUM(model, "_IC2_", p.in2.c);
285         REPLACE_WITH_NUM(model, "_IW2_", p.in2.w);
286         REPLACE_WITH_NUM(model, "_IH2_", p.in2.h);
287
288         REPLACE_WITH_NUM(model, "_ON_", p.axis == 0 ? p.in1.n + p.in2.n : p.in1.n);
289         REPLACE_WITH_NUM(model, "_OC_", p.axis == 1 ? p.in1.c + p.in2.c : p.in1.c);
290         REPLACE_WITH_NUM(model, "_OH_", p.axis == 2 ? p.in1.h + p.in2.h : p.in1.h);
291         REPLACE_WITH_NUM(model, "_OW_", p.axis == 3 ? p.in1.w + p.in2.w : p.in1.w);
292
293         REPLACE_WITH_NUM(model, "_AXIS_", p.axis);
294         return model;
295     }
296
297 protected:
298     virtual void TearDown() {
299     }
300
301     virtual void SetUp() {
302         try {
303             TestsCommon::SetUp();
304             concat_test_params p = ::testing::WithParamInterface<concat_test_params>::GetParam();
305             std::string model = getModel(p);
306             size_t MB = p.in1.n;
307             if (MB < 2)
308                 MB = 2;
309
310             InferenceEngine::CNNNetReader net_reader;
311             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
312             InferenceEngine::CNNNetwork network = net_reader.getNetwork();
313             auto implNet = dynamic_cast<InferenceEngine::details::CNNNetworkImpl *>(&((InferenceEngine::ICNNNetwork&)network));
314             ASSERT_NE(nullptr, implNet) << "Failed to cast ICNNNetwork to CNNNetworkImpl";
315             InferenceEngine::ResponseDesc resp;
316             InferenceEngine::StatusCode sts  = implNet->setBatchSizeReshape(MB, &resp);
317             ASSERT_EQ((int)InferenceEngine::StatusCode::OK, sts) << resp.msg;
318
319             MKLDNNGraphTestClass graph;
320             graph.setProperty({{InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_ENABLED, InferenceEngine::PluginConfigParams::YES}});
321             graph.CreateGraph(net_reader.getNetwork());
322
323             InferenceEngine::SizeVector dims_src1 = {MB, p.in1.c, p.in1.h, p.in1.w};
324             InferenceEngine::SizeVector dims_src2 = {MB, p.in2.c, p.in2.h, p.in2.w};
325
326             InferenceEngine::Blob::Ptr src1 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src1);
327             src1->allocate();
328
329             fill_data(src1->buffer(), src1->size());
330             InferenceEngine::Blob::Ptr src2 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src2);
331             src2->allocate();
332             fill_data(src2->buffer(), src2->size());
333             InferenceEngine::BlobMap srcs;
334             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src1));
335             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in2", src2));
336
337             InferenceEngine::OutputsDataMap out;
338             out = net_reader.getNetwork().getOutputsInfo();
339             InferenceEngine::BlobMap outputBlobs;
340
341             std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
342
343             InferenceEngine::TBlob<float>::Ptr output;
344             output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
345             output->allocate();
346             outputBlobs[item.first] = output;
347
348
349             auto checkConcat = [](const MKLDNNPlugin::MKLDNNNodePtr& node) {
350                 return node->getType() == MKLDNNPlugin::Concatenation;
351             };
352
353             MKLDNNGraphTestClass::CheckDynBatchType checkType = MKLDNNGraphTestClass::CheckDynBatchType::Both;
354             if (p.selectedType == MKLDNNPlugin::impl_desc_type::unknown)
355                 checkType = MKLDNNGraphTestClass::CheckDynBatchType::Child;
356
357             graph.checkDynBatch(srcs, outputBlobs, MB, MB, checkConcat, checkType);
358             graph.checkDynBatch(srcs, outputBlobs, 1, MB, checkConcat, checkType);
359         } catch (const InferenceEngine::details::InferenceEngineException &e) {
360             FAIL() << e.what();
361         }
362     }
363 };
364
365 TEST_P(MKLDNNGraphDynBatchConcatTests, TestsDynBatchConcat) {}
366
367
368 INSTANTIATE_TEST_CASE_P(
369         TestsDynBatchConcat, MKLDNNGraphDynBatchConcatTests,
370         ::testing::Values(
371                 concat_test_params {
372                         {1, 7, 2, 5},
373                         {1, 7, 2, 5},
374                         2, 1, MKLDNNPlugin::impl_desc_type::ref
375                 },
376                 concat_test_params {
377                         {1, 7, 2, 5},
378                         {1, 13, 2, 5},
379                         1, 2, MKLDNNPlugin::impl_desc_type::unknown
380                 },
381                 concat_test_params {
382                         {3, 7, 2, 5},
383                         {3, 13, 2, 5},
384                         1, 2, MKLDNNPlugin::impl_desc_type::unknown
385                 },
386                 concat_test_params {
387                         {1, 7, 2, 13},
388                         {1, 7, 2, 17},
389                         3, 1, MKLDNNPlugin::impl_desc_type::ref
390                 },
391                 concat_test_params {
392                         {1, 8, 8, 16},
393                         {1, 16, 8, 16},
394                         1, 4, MKLDNNPlugin::impl_desc_type::unknown
395                 },
396                 concat_test_params {
397                         {2, 2, 3, 3},
398                         {2, 3, 3, 3},
399                         1, 2, MKLDNNPlugin::impl_desc_type::unknown
400                 }));
401
// Description of one Concat layer in the two-Concat topology.
struct concat_param {
    std::string name;       // layer name; also the key of its output blob
    size_t axis;            // concatenation axis (0 = N, 1 = C, 2 = H, 3 = W)
    size_t input1, input2;  // input selectors: 1 -> in2, 2 -> in3, anything else -> in1
};
408
409 struct two_concat_test_params {
410     dim4 in1;
411     dim4 in2;
412     dim4 in3;
413
414     concat_param concat1;
415     concat_param concat2;
416 };
417
418 class MKLDNNGraphTwoConcatTests: public TestsCommon,
419                                  public WithParamInterface<two_concat_test_params>  {
420     std::string model_t = R"V0G0N(
421 <net name="TwoConcatsDiffFwd" version="2" precision="FP32" batch="1">
422     <layers>
423         <layer name="in1" type="Input" precision="FP32" id="1">
424             <output>
425                 <port id="1">
426                     <dim>_IN1_</dim>
427                     <dim>_IC1_</dim>
428                     <dim>_IH1_</dim>
429                     <dim>_IW1_</dim>
430                 </port>
431             </output>
432         </layer>
433         <layer name="in2" type="Input" precision="FP32" id="2">
434             <output>
435                 <port id="1">
436                     <dim>_IN2_</dim>
437                     <dim>_IC2_</dim>
438                     <dim>_IH2_</dim>
439                     <dim>_IW2_</dim>
440                 </port>
441             </output>
442         </layer>
443         <layer name="in3" type="Input" precision="FP32" id="3">
444             <output>
445                 <port id="1">
446                     <dim>_IN3_</dim>
447                     <dim>_IC3_</dim>
448                     <dim>_IH3_</dim>
449                     <dim>_IW3_</dim>
450                 </port>
451             </output>
452         </layer>
453         <layer name="_CONCAT1_NAME_" id="4" type="Concat" precision="FP32">
454             <concat_data axis="_CONCAT1_AXIS_"/>
455             <input>
456                 <port id="1">
457                     <dim>_CI41N_</dim>
458                     <dim>_CI41C_</dim>
459                     <dim>_CI41H_</dim>
460                     <dim>_CI41W_</dim>
461                 </port>
462                 <port id="2">
463                     <dim>_CI42N_</dim>
464                     <dim>_CI42C_</dim>
465                     <dim>_CI42H_</dim>
466                     <dim>_CI42W_</dim>
467                 </port>
468             </input>
469             <output>
470                 <port id="3">
471                     <dim>_CON1_</dim>
472                     <dim>_COC1_</dim>
473                     <dim>_COH1_</dim>
474                     <dim>_COW1_</dim>
475                 </port>
476             </output>
477         </layer>
478         <layer name="_CONCAT2_NAME_" id="5" type="Concat" precision="FP32">
479             <concat_data axis="_CONCAT2_AXIS_"/>
480             <input>
481                 <port id="1">
482                     <dim>_CI51N_</dim>
483                     <dim>_CI51C_</dim>
484                     <dim>_CI51H_</dim>
485                     <dim>_CI51W_</dim>
486                 </port>
487                 <port id="2">
488                     <dim>_CI52N_</dim>
489                     <dim>_CI52C_</dim>
490                     <dim>_CI52H_</dim>
491                     <dim>_CI52W_</dim>
492                 </port>
493             </input>
494             <output>
495                 <port id="3">
496                     <dim>_CON2_</dim>
497                     <dim>_COC2_</dim>
498                     <dim>_COH2_</dim>
499                     <dim>_COW2_</dim>
500                 </port>
501             </output>
502         </layer>
503     </layers>
504     <edges>
505         <edge from-layer="1" from-port="1" to-layer="_FL11_" to-port="_FP11_"/>
506         <edge from-layer="2" from-port="1" to-layer="_FL21_" to-port="_FP21_"/>
507         <edge from-layer="3" from-port="1" to-layer="_FL31_" to-port="_FP31_"/>
508         <edge from-layer="_FSL_" from-port="_FSP_" to-layer="_FSLTL_" to-port="_FSLTP_"/>
509     </edges>
510 </net>
511 )V0G0N";
512     void changeEdgeToLayer(std::string& model, int f_l, int f_p, int t_l, int t_p, dim4 dims) {
513         std::string TL = "_FL" + std::to_string(f_l) + std::to_string(f_p) + "_";
514         std::string TP = "_FP" + std::to_string(f_l) + std::to_string(f_p) + "_";
515         if (!FIND_STR(model, TL) || !FIND_STR(model, TP)) {
516             if (!FIND_STR(model, "_FSL_") || !FIND_STR(model, "_FSP_") ||
517                     !FIND_STR(model, "_FSLTL_") || !FIND_STR(model, "_FSLTP_")) {
518                 THROW_IE_EXCEPTION << "Incorrect configuration!";
519             }
520             REPLACE_WITH_NUM(model, "_FSL_", f_l);
521             REPLACE_WITH_NUM(model, "_FSP_", f_p);
522             REPLACE_WITH_NUM(model, "_FSLTL_", t_l);
523             REPLACE_WITH_NUM(model, "_FSLTP_", t_p);
524         } else {
525             REPLACE_WITH_NUM(model, TL, t_l);
526             REPLACE_WITH_NUM(model, TP, t_p);
527         }
528
529         std::string CI = "_CI" + std::to_string(t_l) + std::to_string(t_p);
530         REPLACE_WITH_NUM(model, CI + "N_", dims.n);
531         REPLACE_WITH_NUM(model, CI + "C_", dims.c);
532         REPLACE_WITH_NUM(model, CI + "H_", dims.h);
533         REPLACE_WITH_NUM(model, CI + "W_", dims.w);
534     }
535
536
537     std::string getModel(two_concat_test_params p) {
538         std::string model = model_t;
539         REPLACE_WITH_NUM(model, "_IN1_", p.in1.n);
540         REPLACE_WITH_NUM(model, "_IC1_", p.in1.c);
541         REPLACE_WITH_NUM(model, "_IW1_", p.in1.w);
542         REPLACE_WITH_NUM(model, "_IH1_", p.in1.h);
543
544         REPLACE_WITH_NUM(model, "_IN2_", p.in2.n);
545         REPLACE_WITH_NUM(model, "_IC2_", p.in2.c);
546         REPLACE_WITH_NUM(model, "_IW2_", p.in2.w);
547         REPLACE_WITH_NUM(model, "_IH2_", p.in2.h);
548
549         REPLACE_WITH_NUM(model, "_IN3_", p.in3.n);
550         REPLACE_WITH_NUM(model, "_IC3_", p.in3.c);
551         REPLACE_WITH_NUM(model, "_IW3_", p.in3.w);
552         REPLACE_WITH_NUM(model, "_IH3_", p.in3.h);
553
554         dim4 concat11;
555         switch (p.concat1.input1) {
556             case 1:
557                 changeEdgeToLayer(model, 2, 1, 4, 1, p.in2);
558                 concat11 = p.in2;
559                 break;
560             case 2:
561                 changeEdgeToLayer(model, 3, 1, 4, 1, p.in3);
562                 concat11 = p.in3;
563                 break;
564             default:
565                 changeEdgeToLayer(model, 1, 1, 4, 1, p.in1);
566                 concat11 = p.in1;
567         }
568
569         dim4 concat12;
570         switch (p.concat1.input2) {
571             case 1:
572                 changeEdgeToLayer(model, 2, 1, 4, 2, p.in2);
573                 concat12 = p.in2;
574                 break;
575             case 2:
576                 changeEdgeToLayer(model, 3, 1, 4, 2, p.in3);
577                 concat12 = p.in3;
578                 break;
579             default:
580                 changeEdgeToLayer(model, 1, 1, 4, 2, p.in1);
581                 concat12 = p.in1;
582         }
583
584         dim4 concat21;
585         switch (p.concat2.input1) {
586             case 1:
587                 changeEdgeToLayer(model, 2, 1, 5, 1, p.in2);
588                 concat21 = p.in2;
589                 break;
590             case 2:
591                 changeEdgeToLayer(model, 3, 1, 5, 1, p.in3);
592                 concat21 = p.in3;
593                 break;
594             default:
595                 changeEdgeToLayer(model, 1, 1, 5, 1, p.in1);
596                 concat21 = p.in1;
597         }
598
599         dim4 concat22;
600         switch (p.concat2.input2) {
601             case 1:
602                 changeEdgeToLayer(model, 2, 1, 5, 2, p.in2);
603                 concat22 = p.in2;
604                 break;
605             case 2:
606                 changeEdgeToLayer(model, 3, 1, 5, 2, p.in3);
607                 concat22 = p.in3;
608                 break;
609             default:
610                 changeEdgeToLayer(model, 1, 1, 5, 2, p.in1);
611                 concat22 = p.in1;
612         }
613
614         REPLACE_WITH_NUM(model, "_CON1_", p.concat1.axis == 0 ? concat11.n + concat12.n : concat21.n);
615         REPLACE_WITH_NUM(model, "_COC1_", p.concat1.axis == 1 ? concat11.c + concat12.c : concat21.c);
616         REPLACE_WITH_NUM(model, "_COH1_", p.concat1.axis == 2 ? concat11.h + concat12.h : concat21.h);
617         REPLACE_WITH_NUM(model, "_COW1_", p.concat1.axis == 3 ? concat11.w + concat12.w : concat21.w);
618         REPLACE_WITH_NUM(model, "_CONCAT1_AXIS_", p.concat1.axis);
619         REPLACE_WITH_STR(model, "_CONCAT1_NAME_", p.concat1.name);
620
621         REPLACE_WITH_NUM(model, "_CON2_", p.concat2.axis == 0 ? concat21.n + concat22.n : concat21.n);
622         REPLACE_WITH_NUM(model, "_COC2_", p.concat2.axis == 1 ? concat21.c + concat22.c : concat21.c);
623         REPLACE_WITH_NUM(model, "_COH2_", p.concat2.axis == 2 ? concat21.h + concat22.h : concat21.h);
624         REPLACE_WITH_NUM(model, "_COW2_", p.concat2.axis == 3 ? concat21.w + concat22.w : concat21.w);
625         REPLACE_WITH_NUM(model, "_CONCAT2_AXIS_", p.concat2.axis);
626         REPLACE_WITH_STR(model, "_CONCAT2_NAME_", p.concat2.name);
627         return model;
628     }
629
630 protected:
631     virtual void TearDown() {
632     }
633
634     virtual void SetUp() {
635         try {
636             TestsCommon::SetUp();
637             two_concat_test_params p = ::testing::WithParamInterface<two_concat_test_params>::GetParam();
638             std::string model = getModel(p);
639
640             InferenceEngine::CNNNetReader net_reader;
641             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
642
643             MKLDNNGraphTestClass graph;
644             graph.CreateGraph(net_reader.getNetwork());
645
646             InferenceEngine::SizeVector dims_src1 = {p.in1.n, p.in1.c, p.in1.h, p.in1.w};
647             InferenceEngine::SizeVector dims_src2 = {p.in2.n, p.in2.c, p.in2.h, p.in2.w};
648             InferenceEngine::SizeVector dims_src3 = {p.in3.n, p.in3.c, p.in3.h, p.in3.w};
649
650             InferenceEngine::Blob::Ptr src1 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src1);
651             src1->allocate();
652             fill_data(src1->buffer(), src1->size());
653
654             InferenceEngine::Blob::Ptr src2 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src2);
655             src2->allocate();
656             fill_data(src2->buffer(), src2->size());
657
658             InferenceEngine::Blob::Ptr src3 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src3);
659             src3->allocate();
660             fill_data(src3->buffer(), src3->size());
661
662             InferenceEngine::BlobMap srcs;
663             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src1));
664             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in2", src2));
665             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in3", src3));
666
667             InferenceEngine::OutputsDataMap out;
668             out = net_reader.getNetwork().getOutputsInfo();
669             InferenceEngine::BlobMap outputBlobs;
670
671             for (auto it = out.begin(); it != out.end(); it++) {
672                 std::pair<std::string, InferenceEngine::DataPtr> item = *it;
673                 InferenceEngine::TBlob<float>::Ptr output;
674                 output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
675                 output->allocate();
676                 outputBlobs[item.first] = output;
677             }
678
679             graph.Infer(srcs, outputBlobs);
680
681             for (auto concat : {p.concat1, p.concat2}) {
682                 float *src1_ptr;
683                 size_t src1_size;
684                 float *src2_ptr;
685                 size_t src2_size;
686                 InferenceEngine::Blob::Ptr src1_c;
687                 InferenceEngine::Blob::Ptr src2_c;
688
689                 switch (concat.input1) {
690                     case 1:
691                         src1_ptr = src2->buffer();
692                         src1_size = src2->size();
693                         src1_c = src2;
694                         break;
695                     case 2:
696                         src1_ptr = src3->buffer();
697                         src1_size = src3->size();
698                         src1_c = src3;
699                         break;
700                     default:
701                         src1_ptr = src1->buffer();
702                         src1_size = src1->size();
703                         src1_c = src1;
704                 }
705
706                 switch (concat.input2) {
707                     case 1:
708                         src2_ptr = src2->buffer();
709                         src2_size = src2->size();
710                         src2_c = src2;
711                         break;
712                     case 2:
713                         src2_ptr = src3->buffer();
714                         src2_size = src3->size();
715                         src2_c = src3;
716                         break;
717                     default:
718                         src2_ptr = src1->buffer();
719                         src2_size = src1->size();
720                         src2_c = src1;
721                 }
722
723                 float *dst_ptr = outputBlobs[concat.name]->buffer();
724                 size_t dst_size = outputBlobs[concat.name]->size();
725
726                 int len1 = 1, len2 = 1, cycles;
727                 for (int dim = concat.axis; dim < outputBlobs[concat.name]->dims().size(); dim++) {
728                     len1 *= src1_c->dims()[dim];
729                     len2 *= src2_c->dims()[dim];
730                 }
731                 cycles = concat.axis;
732
733                 int index1 = 0, index2 = 0, index = 0;
734                 for (int cycle = 0; cycle < cycles; cycle ++) {
735                     for (int i1 = 0; i1 < len1; i1++) {
736                         if (src1_ptr[index1] != dst_ptr[index])
737                         {
738                             FAIL() << concat.name << " index: " << index << " src: "
739                                    << src1_ptr[index1] << ", dst: " << dst_ptr[index];
740                         }
741                         index1++; index++;
742                     }
743                     for (int i2 = 0; i2 < len2; i2++) {
744                         if (src2_ptr[index2] != dst_ptr[index])
745                         {
746                             FAIL() << concat.name << " index: " << index << " src: "
747                                    << src2_ptr[index2] << ", dst: " << dst_ptr[index];
748                         }
749                         index2++; index++;
750                     }
751                 }
752             }
753         } catch (const InferenceEngine::details::InferenceEngineException &e) {
754             FAIL() << e.what();
755         }
756     }
757 };
758
759 TEST_P(MKLDNNGraphTwoConcatTests, TestsTwoConcat) {}
760
761 INSTANTIATE_TEST_CASE_P(
762         TestsTwoConcat, MKLDNNGraphTwoConcatTests,
763         ::testing::Values(
764                 two_concat_test_params {
765                         {1, 5, 2, 5},
766                         {3, 5, 2, 5},
767                         {1, 5, 2, 5},
768                         {"concat1", 0, 0, 1},
769                         {"concat2", 0, 1, 2}
770                 },
771                 two_concat_test_params {
772                         {1, 2, 2, 5},
773                         {1, 5, 2, 5},
774                         {3, 5, 2, 5},
775                         {"concat1", 1, 0, 1},
776                         {"concat2", 0, 1, 2}
777                 },
778                 two_concat_test_params {
779                         {1, 2, 2, 2},
780                         {1, 1, 2, 2},
781                         {1, 3, 2, 2},
782                         {"concat1", 1, 0, 1},
783                         {"concat2", 1, 1, 2}
784                 },
785                 two_concat_test_params {
786                         {1, 5, 2, 5},
787                         {3, 5, 2, 5},
788                         {1, 5, 2, 5},
789                         {"concat1", 0, 0, 1},
790                         {"concat2", 0, 2, 1}
791                 },
792                 two_concat_test_params {
793                         {1, 2, 2, 5},
794                         {1, 5, 2, 5},
795                         {3, 5, 2, 5},
796                         {"concat1", 1, 0, 1},
797                         {"concat2", 0, 2, 1}
798                 },
799                 two_concat_test_params {
800                         {1, 2, 2, 2},
801                         {1, 1, 2, 2},
802                         {1, 3, 2, 2},
803                         {"concat1", 1, 0, 1},
804                         {"concat2", 1, 2, 1}
805                 }));
806
807
808 class MKLDNNGraphTwoInputInConcatTests: public TestsCommon {
809     std::string model_t = R"V0G0N(
810 <net name="TwoConcatsDiffFwd" version="2" precision="FP32" batch="1">
811     <layers>
812         <layer name="in1" type="Input" precision="FP32" id="1">
813             <output>
814                 <port id="1">
815                     <dim>1</dim>
816                     <dim>3</dim>
817                     <dim>2</dim>
818                     <dim>2</dim>
819                 </port>
820             </output>
821         </layer>
822         <layer name="in2" type="Input" precision="FP32" id="2">
823             <output>
824                 <port id="1">
825                     <dim>1</dim>
826                     <dim>2</dim>
827                     <dim>2</dim>
828                     <dim>2</dim>
829                 </port>
830             </output>
831         </layer>
832         <layer name="norm" id="3" type="ReLU" precision="FP32">
833             <input>
834                 <port id="1">
835                     <dim>1</dim>
836                     <dim>3</dim>
837                     <dim>2</dim>
838                     <dim>2</dim>
839                 </port>
840             </input>
841             <output>
842                 <port id="2">
843                     <dim>1</dim>
844                     <dim>3</dim>
845                     <dim>2</dim>
846                     <dim>2</dim>
847                 </port>
848             </output>
849         </layer>
850         <layer name="power" id="4" type="Power" precision="FP32">
851             <power_data power="-1" scale="-1" shift="0"/>
852             <input>
853                 <port id="1">
854                     <dim>1</dim>
855                     <dim>3</dim>
856                     <dim>2</dim>
857                     <dim>2</dim>
858                 </port>
859             </input>
860             <output>
861                 <port id="2">
862                     <dim>1</dim>
863                     <dim>3</dim>
864                     <dim>2</dim>
865                     <dim>2</dim>
866                 </port>
867             </output>
868         </layer>
869         <layer name="o_concat" id="5" type="Concat" precision="FP32">
870             <concat_data axis="1"/>
871             <input>
872                 <port id="1">
873                     <dim>1</dim>
874                     <dim>2</dim>
875                     <dim>2</dim>
876                     <dim>2</dim>
877                 </port>
878                 <port id="2">
879                     <dim>1</dim>
880                     <dim>3</dim>
881                     <dim>2</dim>
882                     <dim>2</dim>
883                 </port>
884             </input>
885             <output>
886                 <port id="3">
887                     <dim>1</dim>
888                     <dim>5</dim>
889                     <dim>2</dim>
890                     <dim>2</dim>
891                 </port>
892             </output>
893         </layer>
894     </layers>
895     <edges>
896         <edge from-layer="1" from-port="1" to-layer="3" to-port="1"/>
897         <edge from-layer="1" from-port="1" to-layer="5" to-port="2"/>
898         <edge from-layer="1" from-port="1" to-layer="4" to-port="1"/>
899         <edge from-layer="2" from-port="1" to-layer="5" to-port="1"/>
900     </edges>
901 </net>
902 )V0G0N";
903
904 protected:
905     virtual void TearDown() {
906     }
907
908     virtual void SetUp() {
909         try {
910             TestsCommon::SetUp();
911             std::string model = model_t;
912
913             InferenceEngine::CNNNetReader net_reader;
914             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
915
916             MKLDNNGraphTestClass graph;
917             graph.CreateGraph(net_reader.getNetwork());
918
919             InferenceEngine::SizeVector dims_src1 = {1, 3, 2, 2};
920             InferenceEngine::SizeVector dims_src2 = {1, 2, 2, 2};
921
922             InferenceEngine::Blob::Ptr src1 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src1);
923             src1->allocate();
924             float *src1_data = src1->buffer();
925             for (size_t i = 0; i < src1->size(); i++) {
926                 src1_data[i] = i + 1;
927             }
928
929             InferenceEngine::Blob::Ptr src2 = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src2);
930             src2->allocate();
931             fill_data(src2->buffer(), src2->size());
932
933             InferenceEngine::BlobMap srcs;
934             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src1));
935             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in2", src2));
936
937             InferenceEngine::OutputsDataMap out;
938             out = net_reader.getNetwork().getOutputsInfo();
939             InferenceEngine::BlobMap outputBlobs;
940
941             for (auto it = out.begin(); it != out.end(); it++) {
942                 std::pair<std::string, InferenceEngine::DataPtr> item = *it;
943                 InferenceEngine::TBlob<float>::Ptr output;
944                 output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
945                 output->allocate();
946                 outputBlobs[item.first] = output;
947             }
948
949             graph.Infer(srcs, outputBlobs);
950
951             float *src1_ptr = src2->buffer();
952             size_t src1_size = src2->size();
953             float *src2_ptr = src1->buffer();
954             size_t src2_size = src1->size();
955
956             float *dst_ptr = outputBlobs["o_concat"]->buffer();
957             size_t dst_size = outputBlobs["o_concat"]->size();
958
959             int len1 = 1, len2 = 1, cycles;
960             for (int dim = 1; dim < outputBlobs["o_concat"]->dims().size(); dim++) {
961                 len1 *= src2->dims()[dim];
962                 len2 *= src1->dims()[dim];
963             }
964             cycles = 1;
965
966             int index1 = 0, index2 = 0, index = 0;
967             for (int cycle = 0; cycle < cycles; cycle ++) {
968                 for (int i1 = 0; i1 < len1; i1++) {
969                     if (src1_ptr[index1] != dst_ptr[index])
970                     {
971                         FAIL() << "concat index: " << index << " src: "
972                                << src1_ptr[index1] << ", dst: " << dst_ptr[index];
973                     }
974                     index1++; index++;
975                 }
976                 for (int i2 = 0; i2 < len2; i2++) {
977                     if (src2_ptr[index2] != dst_ptr[index])
978                     {
979                         FAIL() << "concat index: " << index << " src: "
980                                << src2_ptr[index2] << ", dst: " << dst_ptr[index];
981                     }
982                     index2++; index++;
983                 }
984             }
985         } catch (const InferenceEngine::details::InferenceEngineException &e) {
986             FAIL() << e.what();
987         }
988     }
989 };
990
991 TEST_F(MKLDNNGraphTwoInputInConcatTests, TestSecondInputToConcat) {}
992
993 class MKLDNNGraphIncorrectConcatTests: public TestsCommon,
994                               public WithParamInterface<concat_test_params> {
995     std::string model_t = R"V0G0N(
996 <net name="ConcatOnly" version="2" precision="FP32" batch="1">
997     <layers>
998         <layer name="in1" type="Input" precision="FP32" id="1">
999             <output>
1000                 <port id="1">
1001                     <dim>_IN1_</dim>
1002                     <dim>_IC1_</dim>
1003                     <dim>_IH1_</dim>
1004                     <dim>_IW1_</dim>
1005                 </port>
1006             </output>
1007         </layer>
1008         <layer name="in2" type="Input" precision="FP32" id="2">
1009             <output>
1010                 <port id="2">
1011                     <dim>_IN2_</dim>
1012                     <dim>_IC2_</dim>
1013                     <dim>_IH2_</dim>
1014                     <dim>_IW2_</dim>
1015                 </port>
1016             </output>
1017         </layer>
1018         <layer name="con" id="3" type="Concat" precision="FP32">
1019             <concat_data axis="_AXIS_"/>
1020             <input>
1021                 <port id="1">
1022                     <dim>_IN1_</dim>
1023                     <dim>_IC1_</dim>
1024                     <dim>_IH1_</dim>
1025                     <dim>_IW1_</dim>
1026                 </port>
1027                 <port id="2">
1028                     <dim>_IN2_</dim>
1029                     <dim>_IC2_</dim>
1030                     <dim>_IH2_</dim>
1031                     <dim>_IW2_</dim>
1032                 </port>
1033             </input>
1034             <output>
1035                 <port id="3">
1036                     <dim>_ON_</dim>
1037                     <dim>_OC_</dim>
1038                     <dim>_OH_</dim>
1039                     <dim>_OW_</dim>
1040                 </port>
1041             </output>
1042         </layer>
1043     </layers>
1044     <edges>
1045         <edge from-layer="1" from-port="1" to-layer="3" to-port="1"/>
1046         <edge from-layer="2" from-port="2" to-layer="3" to-port="2"/>
1047     </edges>
1048 </net>
1049 )V0G0N";
1050
1051     std::string getModel(concat_test_params p) {
1052         std::string model = model_t;
1053         REPLACE_WITH_NUM(model, "_IN1_", p.in1.n);
1054         REPLACE_WITH_NUM(model, "_IC1_", p.in1.c);
1055         REPLACE_WITH_NUM(model, "_IW1_", p.in1.w);
1056         REPLACE_WITH_NUM(model, "_IH1_", p.in1.h);
1057
1058         REPLACE_WITH_NUM(model, "_IN2_", p.in2.n);
1059         REPLACE_WITH_NUM(model, "_IC2_", p.in2.c);
1060         REPLACE_WITH_NUM(model, "_IW2_", p.in2.w);
1061         REPLACE_WITH_NUM(model, "_IH2_", p.in2.h);
1062
1063         REPLACE_WITH_NUM(model, "_ON_", p.axis == 0 ? p.in1.n + p.in2.n : p.in1.n);
1064         REPLACE_WITH_NUM(model, "_OC_", p.axis == 1 ? p.in1.c + p.in2.c : p.in1.c);
1065         REPLACE_WITH_NUM(model, "_OH_", p.axis == 2 ? p.in1.h + p.in2.h : p.in1.h);
1066         REPLACE_WITH_NUM(model, "_OW_", p.axis == 3 ? p.in1.w + p.in2.w : p.in1.w);
1067
1068         REPLACE_WITH_NUM(model, "_AXIS_", p.axis);
1069         return model;
1070     }
1071
1072 protected:
1073     virtual void TearDown() {
1074     }
1075
1076     virtual void SetUp() {
1077         try {
1078             TestsCommon::SetUp();
1079             concat_test_params p = ::testing::WithParamInterface<concat_test_params>::GetParam();
1080             std::string model = getModel(p);
1081
1082             InferenceEngine::CNNNetReader net_reader;
1083             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
1084
1085             MKLDNNGraphTestClass graph;
1086             ASSERT_THROW(graph.CreateGraph(net_reader.getNetwork()), InferenceEngine::details::InferenceEngineException);
1087         } catch (const InferenceEngine::details::InferenceEngineException &e) {
1088             FAIL() << e.what();
1089         }
1090     }
1091 };
1092
1093 TEST_P(MKLDNNGraphIncorrectConcatTests, TestsIncorrectConcat) {}
1094
1095
1096 INSTANTIATE_TEST_CASE_P(
1097         TestsIncorrectConcat, MKLDNNGraphIncorrectConcatTests,
1098         ::testing::Values(
1099                 concat_test_params {
1100                         {1, 7, 2, 5},
1101                         {1, 7, 3, 5},
1102                         1
1103                 },
1104                 concat_test_params {
1105                         {1, 7, 2, 5},
1106                         {1, 7, 4, 4},
1107                         2
1108                 }));