// Publishing 2019 R1 content
// [platform/upstream/dldt.git] / inference-engine / tests / unit / engines / mkldnn / graph / layers / extensions / gather_tests.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include "mkldnn_plugin/mkldnn_graph.h"
8
9 #include "test_graph.hpp"
10
11 #include "single_layer_common.hpp"
12 #include <mkldnn_plugin/mkldnn_extension_utils.h>
13 #include <extension/ext_list.hpp>
14 #include "tests_common.hpp"
15
16
17 using namespace ::testing;
18 using namespace std;
19 using namespace mkldnn;
20
21
22 struct gather_test_params {
23     std::string inIdxPrecision;
24     InferenceEngine::SizeVector inIdx;
25     InferenceEngine::SizeVector inDict;
26     int axis;
27     InferenceEngine::SizeVector out;
28
29     size_t num_prim_desc;
30     int selectedType;
31
32     std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
33 };
34
35 template <typename data_t>
36 void ref_gather(InferenceEngine::TBlob<data_t> &srcIdx, InferenceEngine::TBlob<float> &srcDct, InferenceEngine::TBlob<float> &dst, size_t axis) {
37     size_t i, j;
38     const data_t *src_dataIdx = srcIdx.data();
39     float* src_dataDict = srcDct.data();
40     float *dst_data = dst.data();
41     size_t src_size = srcIdx.size();
42
43     std::vector<size_t> dims = srcDct.getTensorDesc().getDims();
44     std::vector<size_t> dims_actual;
45
46     //  Remove redundant dimensions
47     for (size_t i = 0; i < dims.size(); i++) {
48         if (dims[i] > 1) {
49             for (size_t j = i; j < dims.size(); j++)
50                 dims_actual.push_back(dims[j]);
51             break;
52         }
53     }
54
55     //  Find number of dictionaries, index range and data length
56     size_t numDictionaries = 1;
57     for (i = 0; i < axis; i++)
58         numDictionaries *= dims_actual[i];
59     size_t indexRange = dims_actual[axis];
60     size_t dataLength = 1;
61     for (i = axis + 1; i < dims_actual.size(); i++)
62         dataLength *= dims_actual[i];
63
64     //  The gathering process
65     for (i = 0; i < src_size; i++) {
66         unsigned int idx = static_cast<unsigned int>(src_dataIdx[i]);
67
68         //  Index clipping
69         if (idx < indexRange)
70         {
71             //  Copying data to destination from Dictionary
72             for (j = 0; j < numDictionaries; j++) {
73                 memcpy(&dst_data[dataLength * (i + j * src_size)],
74                        &src_dataDict[dataLength * (idx + j * indexRange)], sizeof(float) * dataLength);
75             }
76         } else {
77             for (j = 0; j < numDictionaries; j++) {
78                 std::fill_n(&dst_data[dataLength * (i + j * src_size)], dataLength, 0.0f);
79             }
80         }
81     }
82 }
83
84 class MKLDNNCPUExtGatherTests: public TestsCommon, public WithParamInterface<gather_test_params> {
85     std::string model_t = R"V0G0N(
86 <net Name="Gather_net" version="2" precision="FP32" batch="1">
87     <layers>
88         <layer name="InputText" type="Input" precision="_IIDXP_" id="1">
89             <output>
90                 <port id="1">
91                     _IIDX_
92                 </port>
93             </output>
94         </layer>
95         <layer name="InputDictionary" type="Input" precision="FP32" id="2">
96             <output>
97                 <port id="2">
98                     _IDICT_
99                 </port>
100             </output>
101         </layer>
102         <layer name="gather" id="3" type="Gather" precision="FP32">
103             <data axis="_AX_"/>
104             <input>
105                 <port id="1">
106                     _IDICT_
107                 </port>
108                 <port id="2">
109                     _IIDX_
110                 </port>
111             </input>
112             <output>
113                 <port id="3">
114                     _OUT_
115                 </port>
116             </output>
117         </layer>
118     </layers>
119     <edges>
120         <edge from-layer="1" from-port="1" to-layer="3" to-port="2"/>
121         <edge from-layer="2" from-port="2" to-layer="3" to-port="1"/>
122     </edges>
123 </net>
124 )V0G0N";
125
126     std::string getModel(gather_test_params p) {
127         std::string model = model_t;
128         std::string inIdx;
129         std::string inDict;
130         std::string out;
131
132         for (auto& idx : p.inIdx) {
133             inIdx += "<dim>";
134             inIdx += std::to_string(idx) + "</dim>\n";
135         }
136
137         for (auto& dct : p.inDict) {
138             inDict += "<dim>";
139             inDict += std::to_string(dct) + "</dim>\n";
140         }
141
142         for (auto& dst : p.out) {
143             out += "<dim>";
144             out += std::to_string(dst) + "</dim>\n";
145         }
146
147         REPLACE_WITH_STR(model, "_IIDXP_", p.inIdxPrecision);
148         REPLACE_WITH_STR(model, "_IIDX_", inIdx);
149         REPLACE_WITH_STR(model, "_IDICT_", inDict);
150         REPLACE_WITH_NUM(model, "_AX_", p.axis);
151         REPLACE_WITH_STR(model, "_OUT_", out);
152
153         return model;
154     }
155
156     template <typename data_t>
157     static void fill_data_dbgval(data_t *data, size_t size) {
158         for (size_t i = 0; i < size; i++) {
159             data[i] = static_cast<data_t>(i & (sizeof(data_t) * 8 - 1));
160         }
161     }
162 protected:
163     virtual void TearDown() {
164     }
165
166     virtual void SetUp() {
167         try {
168             TestsCommon::SetUp();
169             gather_test_params p = ::testing::WithParamInterface<gather_test_params>::GetParam();
170             std::string model = getModel(p);
171
172             InferenceEngine::CNNNetReader net_reader;
173             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
174
175             InferenceEngine::Extension cpuExt(make_so_name("cpu_extension"));
176             MKLDNNPlugin::MKLDNNExtensionManager::Ptr extMgr(new MKLDNNPlugin::MKLDNNExtensionManager());
177             extMgr->AddExtension(InferenceEngine::IExtensionPtr(&cpuExt, [](InferenceEngine::IExtension*){}));
178
179             MKLDNNGraphTestClass graph;
180             graph.CreateGraph(net_reader.getNetwork(), extMgr);
181
182             auto& nodes = graph.getNodes();
183             nodes = graph.getNodes();
184
185             for (auto &node : nodes) {
186                 if (node->getName() == "gather") {
187                     ASSERT_EQ(p.num_prim_desc, node->getSupportedPrimitiveDescriptors().size());
188                     for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
189                         p.comp.at(j)(node->getSupportedPrimitiveDescriptors().at(j));
190                     }
191                     ASSERT_NE(nullptr, node->getSelectedPrimitiveDescriptor());
192                     ASSERT_EQ(p.selectedType,
193                               node->getSelectedPrimitiveDescriptor()->getImplementationType() & p.selectedType);
194                 }
195             }
196             ASSERT_EQ(4, nodes.size());
197
198             // Input Dictionary
199             InferenceEngine::Blob::Ptr srcDict = InferenceEngine::make_shared_blob<float>({ InferenceEngine::Precision::FP32, p.inDict, InferenceEngine::TensorDesc::getLayoutByDims(p.inDict) });
200             srcDict->allocate();
201             fill_data(srcDict->buffer(), srcDict->size());
202             auto * srcDictPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(srcDict.get());
203             if (srcDictPtr == nullptr)
204                 FAIL() << "Cannot cast blob to TBlob<float>.";
205
206             // Output Data
207             InferenceEngine::OutputsDataMap out;
208             out = net_reader.getNetwork().getOutputsInfo();
209             InferenceEngine::BlobMap outputBlobs;
210
211             std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
212
213             InferenceEngine::TBlob<float>::Ptr output;
214             output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
215             output->allocate();
216             outputBlobs[item.first] = output;
217
218             // Output Reference
219             InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
220             dst_ref.allocate();
221
222             // Input Indexes
223             InferenceEngine::Blob::Ptr srcIdx;
224             if (p.inIdxPrecision == "I32") {
225                 srcIdx = InferenceEngine::make_shared_blob<int32_t>({ InferenceEngine::Precision::I32, p.inIdx, InferenceEngine::TensorDesc::getLayoutByDims(p.inIdx) });
226                 srcIdx->allocate();
227                 fill_data_dbgval(static_cast<int32_t*>(srcIdx->buffer()), srcIdx->size());
228                 auto * srcIdxPtr = dynamic_cast<InferenceEngine::TBlob<int32_t>*>(srcIdx.get());
229                 if (srcIdxPtr == nullptr)
230                     FAIL() << "Cannot cast blob to TBlob<int32_t>.";
231
232                 // Check results
233                 ref_gather(*srcIdxPtr, *srcDictPtr, dst_ref, p.axis);
234             }
235             else if (p.inIdxPrecision == "FP32") {
236                 srcIdx = InferenceEngine::make_shared_blob<float>({ InferenceEngine::Precision::FP32, p.inIdx, InferenceEngine::TensorDesc::getLayoutByDims(p.inIdx) });
237                 srcIdx->allocate();
238                 fill_data(srcIdx->buffer(), srcIdx->size());
239                 auto * srcIdxPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(srcIdx.get());
240                 if (srcIdxPtr == nullptr)
241                     FAIL() << "Cannot cast blob to TBlob<float>.";
242
243                 // Check results
244                 ref_gather(*srcIdxPtr, *srcDictPtr, dst_ref, p.axis);
245             }
246             else if (p.inIdxPrecision == "U16") {
247                 srcIdx = InferenceEngine::make_shared_blob<uint16_t>({ InferenceEngine::Precision::U16, p.inIdx, InferenceEngine::TensorDesc::getLayoutByDims(p.inIdx) });
248                 srcIdx->allocate();
249                 fill_data_dbgval(static_cast<uint16_t*>(srcIdx->buffer()), srcIdx->size());
250                 auto * srcIdxPtr = dynamic_cast<InferenceEngine::TBlob<uint16_t>*>(srcIdx.get());
251                 if (srcIdxPtr == nullptr)
252                     FAIL() << "Cannot cast blob to TBlob<uint16_t>.";
253
254                 // Check results
255                 ref_gather(*srcIdxPtr, *srcDictPtr, dst_ref, p.axis);
256             }
257             else if (p.inIdxPrecision == "I16") {
258                 srcIdx = InferenceEngine::make_shared_blob<int16_t>({ InferenceEngine::Precision::I16, p.inIdx, InferenceEngine::TensorDesc::getLayoutByDims(p.inIdx) });
259                 srcIdx->allocate();
260                 fill_data_dbgval(static_cast<int16_t*>(srcIdx->buffer()), srcIdx->size());
261                 auto * srcIdxPtr = dynamic_cast<InferenceEngine::TBlob<int16_t>*>(srcIdx.get());
262                 if (srcIdxPtr == nullptr)
263                     FAIL() << "Cannot cast blob to TBlob<int16_t>.";
264
265                 // Check results
266                 ref_gather(*srcIdxPtr, *srcDictPtr, dst_ref, p.axis);
267             }
268             else if (p.inIdxPrecision == "U8") {
269                 srcIdx = InferenceEngine::make_shared_blob<uint8_t>({ InferenceEngine::Precision::U8, p.inIdx, InferenceEngine::TensorDesc::getLayoutByDims(p.inIdx) });
270                 srcIdx->allocate();
271                 fill_data_dbgval(static_cast<uint8_t*>(srcIdx->buffer()), srcIdx->size());
272                 auto * srcIdxPtr = dynamic_cast<InferenceEngine::TBlob<uint8_t>*>(srcIdx.get());
273                 if (srcIdxPtr == nullptr)
274                     FAIL() << "Cannot cast blob to TBlob<uint8_t>.";
275
276                 // Check results
277                 ref_gather(*srcIdxPtr, *srcDictPtr, dst_ref, p.axis);
278             }
279             else if (p.inIdxPrecision == "I8") {
280                 srcIdx = InferenceEngine::make_shared_blob<int8_t>({ InferenceEngine::Precision::I8, p.inIdx, InferenceEngine::TensorDesc::getLayoutByDims(p.inIdx) });
281                 srcIdx->allocate();
282                 fill_data_dbgval(static_cast<int8_t*>(srcIdx->buffer()), srcIdx->size());
283                 auto * srcIdxPtr = dynamic_cast<InferenceEngine::TBlob<int8_t>*>(srcIdx.get());
284                 if (srcIdxPtr == nullptr)
285                     FAIL() << "Cannot cast blob to TBlob<int8_t>.";
286
287                 // Check results
288                 ref_gather(*srcIdxPtr, *srcDictPtr, dst_ref, p.axis);
289             }
290             else {
291                 return;
292             }
293
294             InferenceEngine::BlobMap srcs;
295             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("InputDictionary", srcDict));
296             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("InputText", srcIdx));
297
298             // Infer
299             graph.Infer(srcs, outputBlobs);
300             compare(*output, dst_ref);
301         } catch (const InferenceEngine::details::InferenceEngineException &e) {
302             FAIL() << e.what();
303         }
304     }
305 };
306
307 TEST_P(MKLDNNCPUExtGatherTests, TestsGather) {}
308
309 INSTANTIATE_TEST_CASE_P(
310         TestsGather, MKLDNNCPUExtGatherTests,
311             ::testing::Values(
312                 gather_test_params{ "FP32", {1, 1, 12, 256}, {1, 1, 71, 16}, 0, {1, 12, 256, 16}, 1, MKLDNNPlugin::impl_desc_type::unknown },
313                 gather_test_params{  "I32", {1, 1, 12, 256}, {1, 1, 71, 16}, 0, {1, 12, 256, 16}, 1, MKLDNNPlugin::impl_desc_type::unknown },
314                 gather_test_params{  "I32", {12, 256}, {71, 16}, 0, {12, 256, 16}, 1, MKLDNNPlugin::impl_desc_type::unknown },
315                 gather_test_params{  "I32", {3, 4}, {2, 5, 6}, 0, {3, 4, 5, 6}, 1, MKLDNNPlugin::impl_desc_type::unknown },
316                 gather_test_params{  "I32", {3, 4}, {5, 1}, 0, {3, 4, 1}, 1, MKLDNNPlugin::impl_desc_type::unknown },
317                 gather_test_params{ "FP32", {1, 1, 12, 256}, {1, 1, 71, 16}, 1, {1, 71, 12, 256}, 1, MKLDNNPlugin::impl_desc_type::unknown },
318                 gather_test_params{  "I32", {1, 1, 3, 4}, {1, 2, 5, 6}, 1, {2, 3, 4, 6}, 1, MKLDNNPlugin::impl_desc_type::unknown },
319                 gather_test_params{  "I32", {1, 1, 3, 4}, {1, 2, 5, 6}, 2, {2, 5, 3, 4}, 1, MKLDNNPlugin::impl_desc_type::unknown },
320                 gather_test_params{  "I32", {12, 4, 9, 8}, {6, 13, 10, 3}, 1, {6, 12, 4, 9, 8, 10, 3}, 1, MKLDNNPlugin::impl_desc_type::unknown }
321             ));
322
323
324
325
326 struct gatherTF_test_params {
327     InferenceEngine::SizeVector in_dim;
328     std::vector<int32_t> in;
329
330     InferenceEngine::SizeVector dct_dim;
331     std::vector<float> dct;
332
333     int axis;
334
335     InferenceEngine::SizeVector ref_dim;
336     std::vector<float> ref;
337
338     std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
339 };
340
341 class MKLDNNCPUExtGatherTFTests : public TestsCommon, public WithParamInterface<gatherTF_test_params> {
342     std::string model_t = R"V0G0N(
343 <net Name="Gather_net" version="2" precision="FP32" batch="1">
344     <layers>
345         <layer name="InputText" type="Input" precision="I32" id="1">
346             <output>
347                 <port id="1">
348                     _IIDX_
349                 </port>
350             </output>
351         </layer>
352         <layer name="InputDictionary" type="Input" precision="FP32" id="2">
353             <output>
354                 <port id="2">
355                     _IDICT_
356                 </port>
357             </output>
358         </layer>
359         <layer name="gather" id="3" type="Gather" precision="FP32">
360             <data axis="_AX_"/>
361             <input>
362                 <port id="1">
363                     _IDICT_
364                 </port>
365                 <port id="2">
366                     _IIDX_
367                 </port>
368             </input>
369             <output>
370                 <port id="3">
371                     _OUT_
372                 </port>
373             </output>
374         </layer>
375     </layers>
376     <edges>
377         <edge from-layer="1" from-port="1" to-layer="3" to-port="2"/>
378         <edge from-layer="2" from-port="2" to-layer="3" to-port="1"/>
379     </edges>
380 </net>
381 )V0G0N";
382
383     std::string getModel(gatherTF_test_params p) {
384         std::string model = model_t;
385         std::string inIdx;
386         std::string inDict;
387         std::string out;
388
389         for (auto& idx : p.in_dim) {
390             inIdx += "<dim>";
391             inIdx += std::to_string(idx) + "</dim>\n";
392         }
393
394         for (auto& dct : p.dct_dim) {
395             inDict += "<dim>";
396             inDict += std::to_string(dct) + "</dim>\n";
397         }
398
399         for (auto& dst : p.ref_dim) {
400             out += "<dim>";
401             out += std::to_string(dst) + "</dim>\n";
402         }
403
404         REPLACE_WITH_STR(model, "_IIDX_", inIdx);
405         REPLACE_WITH_STR(model, "_IDICT_", inDict);
406         REPLACE_WITH_NUM(model, "_AX_", p.axis);
407         REPLACE_WITH_STR(model, "_OUT_", out);
408
409         return model;
410     }
411
412 protected:
413     virtual void TearDown() {
414     }
415
416     virtual void SetUp() {
417         try {
418             TestsCommon::SetUp();
419             gatherTF_test_params p = ::testing::WithParamInterface<gatherTF_test_params>::GetParam();
420             std::string model = getModel(p);
421
422             InferenceEngine::CNNNetReader net_reader;
423             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
424
425             InferenceEngine::Extension cpuExt(make_so_name("cpu_extension"));
426             MKLDNNPlugin::MKLDNNExtensionManager::Ptr extMgr(new MKLDNNPlugin::MKLDNNExtensionManager());
427             extMgr->AddExtension(InferenceEngine::IExtensionPtr(&cpuExt, [](InferenceEngine::IExtension*){}));
428
429             MKLDNNGraphTestClass graph;
430             graph.CreateGraph(net_reader.getNetwork(), extMgr);
431
432             // Input Indexes
433             InferenceEngine::Blob::Ptr srcIdx;
434             srcIdx = InferenceEngine::make_shared_blob<int32_t>({ InferenceEngine::Precision::I32, p.in_dim, InferenceEngine::TensorDesc::getLayoutByDims(p.in_dim) });
435             srcIdx->allocate();
436             memcpy(static_cast<int32_t*>(srcIdx->buffer()), &p.in[0], sizeof(int32_t)*p.in.size());
437             auto * srcIdxPtr = dynamic_cast<InferenceEngine::TBlob<int32_t>*>(srcIdx.get());
438             if (srcIdxPtr == nullptr)
439                 FAIL() << "Cannot cast blob to TBlob<int32_t>.";
440
441             //  Input Dictionary
442             InferenceEngine::Blob::Ptr srcDict = InferenceEngine::make_shared_blob<float>({ InferenceEngine::Precision::FP32, p.dct_dim, InferenceEngine::TensorDesc::getLayoutByDims(p.dct_dim) });
443             srcDict->allocate();
444             memcpy(srcDict->buffer(), &p.dct[0], sizeof(float)*p.dct.size());
445             auto * srcDictPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(srcDict.get());
446             if (srcDictPtr == nullptr)
447                 FAIL() << "Cannot cast blob to TBlob<float>.";
448
449             //  Output Data
450             InferenceEngine::OutputsDataMap out;
451             out = net_reader.getNetwork().getOutputsInfo();
452             InferenceEngine::BlobMap outputBlobs;
453             std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
454             InferenceEngine::TBlob<float>::Ptr output;
455             output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
456             output->allocate();
457             outputBlobs[item.first] = output;
458
459             //  Infer
460             InferenceEngine::BlobMap srcs;
461             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("InputDictionary", srcDict));
462             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("InputText", srcIdx));
463             graph.Infer(srcs, outputBlobs);
464
465             //  Check results
466             if (memcmp((*output).data(), &p.ref[0], p.ref.size()) != 0)
467                 FAIL() << "Wrong result with compare TF reference!";
468         } catch (const InferenceEngine::details::InferenceEngineException &e) {
469             FAIL() << e.what();
470         }
471     }
472 };
473
474 TEST_P(MKLDNNCPUExtGatherTFTests, TestsGather) {}
475
//  Test data vectors.
//  `static const`: internal linkage keeps these generic names (dict, in0, ...) from
//  clashing with globals of other translation units in the same test binary, and
//  const prevents one test case from mutating data shared with the others.
static const std::vector<int32_t> in0 = { 0, 1, 1, 0 };
static const std::vector<int32_t> in1 = { 0, 1, 2, 1 };
static const std::vector<float> dict = { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f };
//  Naming: ref_<indices>_a<axis>_d<dictionary dims>; trailing comment gives the output dims.
static const std::vector<float> ref_in0_a0_d223 = { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f }; // 2x2x2x3
static const std::vector<float> ref_in0_a2_d232 = { 1.f, 2.f, 2.f, 1.f, 3.f, 4.f, 4.f, 3.f, 5.f, 6.f, 6.f, 5.f, 7.f, 8.f, 8.f, 7.f, 9.f, 10.f, 10.f, 9.f, 11.f, 12.f, 12.f, 11.f }; // 2x3x2x2
static const std::vector<float> ref_in1_a0_d322 = { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 5.f, 6.f, 7.f, 8.f }; // 2x2x2x2
static const std::vector<float> ref_in1_a1_d232 = { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 3.f, 4.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 9.f, 10.f }; // 2x2x2x2
static const std::vector<float> ref_in1_a2_d223 = { 1.f, 2.f, 3.f, 2.f, 4.f, 5.f, 6.f, 5.f, 7.f, 8.f, 9.f, 8.f, 10.f, 11.f, 12.f, 11.f }; // 2x2x2x2
485
486 INSTANTIATE_TEST_CASE_P(
487         TestsGather, MKLDNNCPUExtGatherTFTests,
488         ::testing::Values(
489         gatherTF_test_params{ { 2, 2 }, in0,{ 2, 2, 3 }, dict, 0, { 2, 2, 2, 3 }, ref_in0_a0_d223 },
490         gatherTF_test_params{ { 2, 2 }, in0,{ 2, 2, 3 }, dict,-3, { 2, 2, 2, 3 }, ref_in0_a0_d223 },
491         gatherTF_test_params{ { 2, 2 }, in0,{ 2, 3, 2 }, dict, 2, { 2, 3, 2, 2 }, ref_in0_a2_d232 },
492         gatherTF_test_params{ { 2, 2 }, in0,{ 2, 3, 2 }, dict,-1, { 2, 3, 2, 2 }, ref_in0_a2_d232 },
493         gatherTF_test_params{ { 2, 2 }, in1,{ 3, 2, 2 }, dict, 0, { 2, 2, 2, 2 }, ref_in1_a0_d322 },
494         gatherTF_test_params{ { 2, 2 }, in1,{ 3, 2, 2 }, dict,-3, { 2, 2, 2, 2 }, ref_in1_a0_d322 },
495         gatherTF_test_params{ { 2, 2 }, in1,{ 2, 3, 2 }, dict, 1, { 2, 2, 2, 2 }, ref_in1_a1_d232 },
496         gatherTF_test_params{ { 2, 2 }, in1,{ 2, 3, 2 }, dict,-2, { 2, 2, 2, 2 }, ref_in1_a1_d232 },
497         gatherTF_test_params{ { 2, 2 }, in1,{ 2, 2, 3 }, dict, 2, { 2, 2, 2, 2 }, ref_in1_a2_d223 },
498         gatherTF_test_params{ { 2, 2 }, in1,{ 2, 2, 3 }, dict,-1, { 2, 2, 2, 2 }, ref_in1_a2_d223 }));
499
500
501 class MKLDNNCPUExtGatherHolesTests : public TestsCommon, public WithParamInterface<gatherTF_test_params> {
502     std::string model_t = R"V0G0N(
503 <net Name="Gather_net" version="2" precision="FP32" batch="1">
504     <layers>
505         <layer name="InputText" type="Input" precision="I32" id="1">
506             <output>
507                 <port id="1">
508                     <dim>2</dim>
509                     <dim>2</dim>
510                 </port>
511             </output>
512         </layer>
513         <layer name="InputDictionary" type="Input" precision="FP32" id="2">
514             <output>
515                 <port id="2">
516                     <dim>3</dim>
517                     <dim>2</dim>
518                     <dim>2</dim>
519                 </port>
520             </output>
521         </layer>
522         <layer name="Input3" type="Input" precision="FP32" id="3">
523             <output>
524                 <port id="3">
525                     <dim>2</dim>
526                     <dim>5</dim>
527                     <dim>2</dim>
528                     <dim>2</dim>
529                 </port>
530             </output>
531         </layer>
532         <layer name="gather" id="4" type="Gather" precision="FP32">
533             <data axis="0"/>
534             <input>
535                 <port id="1">
536                     <dim>3</dim>
537                     <dim>2</dim>
538                     <dim>2</dim>
539                 </port>
540                 <port id="2">
541                     <dim>2</dim>
542                     <dim>2</dim>
543                 </port>
544             </input>
545             <output>
546                 <port id="3">
547                     <dim>2</dim>
548                     <dim>2</dim>
549                     <dim>2</dim>
550                     <dim>2</dim>
551                 </port>
552             </output>
553         </layer>
554         <layer name="con" id="5" type="Concat" precision="FP32">
555             <concat_data axis="1"/>
556             <input>
557                 <port id="1">
558                     <dim>2</dim>
559                     <dim>2</dim>
560                     <dim>2</dim>
561                     <dim>2</dim>
562                 </port>
563                 <port id="2">
564                     <dim>2</dim>
565                     <dim>5</dim>
566                     <dim>2</dim>
567                     <dim>2</dim>
568                 </port>
569             </input>
570             <output>
571                 <port id="3">
572                     <dim>2</dim>
573                     <dim>7</dim>
574                     <dim>2</dim>
575                     <dim>2</dim>
576                 </port>
577             </output>
578         </layer>
579     </layers>
580     <edges>
581         <edge from-layer="1" from-port="1" to-layer="4" to-port="2"/>
582         <edge from-layer="2" from-port="2" to-layer="4" to-port="1"/>
583         <edge from-layer="4" from-port="3" to-layer="5" to-port="1"/>
584         <edge from-layer="3" from-port="3" to-layer="5" to-port="2"/>
585     </edges>
586 </net>
587 )V0G0N";
588
589     std::string getModel(gatherTF_test_params p) {
590         std::string model = model_t;
591         std::string inIdx;
592         std::string inDict;
593         std::string out;
594
595         for (auto& idx : p.in_dim) {
596             inIdx += "<dim>";
597             inIdx += std::to_string(idx) + "</dim>\n";
598         }
599
600         for (auto& dct : p.dct_dim) {
601             inDict += "<dim>";
602             inDict += std::to_string(dct) + "</dim>\n";
603         }
604
605         for (auto& dst : p.ref_dim) {
606             out += "<dim>";
607             out += std::to_string(dst) + "</dim>\n";
608         }
609
610         REPLACE_WITH_STR(model, "_OUTC_", inIdx);
611         REPLACE_WITH_STR(model, "_IDICT_", inDict);
612         REPLACE_WITH_NUM(model, "_AX_", p.axis);
613         REPLACE_WITH_STR(model, "_OUT_", out);
614
615         return model;
616     }
617
618 protected:
619     virtual void TearDown() {
620     }
621
622     virtual void SetUp() {
623         try {
624             TestsCommon::SetUp();
625             gatherTF_test_params p = ::testing::WithParamInterface<gatherTF_test_params>::GetParam();
626             std::string model = getModel(p);
627
628             InferenceEngine::CNNNetReader net_reader;
629             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
630
631             InferenceEngine::Extension cpuExt(make_so_name("cpu_extension"));
632             MKLDNNPlugin::MKLDNNExtensionManager::Ptr extMgr(new MKLDNNPlugin::MKLDNNExtensionManager());
633             extMgr->AddExtension(InferenceEngine::IExtensionPtr(&cpuExt, [](InferenceEngine::IExtension*){}));
634
635             MKLDNNGraphTestClass graph;
636             graph.CreateGraph(net_reader.getNetwork(), extMgr);
637
638             // Input Indexes
639             InferenceEngine::Blob::Ptr srcIdx;
640             int32_t in_size = 4;
641             InferenceEngine::SizeVector in_dim = {2, 2};
642             srcIdx = InferenceEngine::make_shared_blob<int32_t>({ InferenceEngine::Precision::I32, in_dim, InferenceEngine::TensorDesc::getLayoutByDims(in_dim) });
643             srcIdx->allocate();
644             memcpy(static_cast<int32_t*>(srcIdx->buffer()), &p.in[0], sizeof(int32_t)*in_size);
645             auto * srcIdxPtr = dynamic_cast<InferenceEngine::TBlob<int32_t>*>(srcIdx.get());
646             if (srcIdxPtr == nullptr)
647                 FAIL() << "Cannot cast blob to TBlob<int32_t>.";
648
649             //  Input Dictionary
650             InferenceEngine::Blob::Ptr srcDict = InferenceEngine::make_shared_blob<float>({ InferenceEngine::Precision::FP32, p.dct_dim, InferenceEngine::TensorDesc::getLayoutByDims(p.dct_dim) });
651             srcDict->allocate();
652             memcpy(srcDict->buffer(), &p.dct[0], sizeof(float)*p.dct.size());
653             auto * srcDictPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(srcDict.get());
654             if (srcDictPtr == nullptr)
655                 FAIL() << "Cannot cast blob to TBlob<float>.";
656
657             //  Output Data
658             InferenceEngine::OutputsDataMap out;
659             out = net_reader.getNetwork().getOutputsInfo();
660             InferenceEngine::BlobMap outputBlobs;
661             std::pair<std::string, InferenceEngine::DataPtr> item = *out.begin();
662             InferenceEngine::TBlob<float>::Ptr output;
663             output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
664             output->allocate();
665             outputBlobs[item.first] = output;
666
667             //  Infer
668             InferenceEngine::BlobMap srcs;
669             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("InputDictionary", srcDict));
670             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("InputText", srcIdx));
671             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("Input3", srcIdx));
672             graph.Infer(srcs, outputBlobs);
673
674             //  Check results
675             if (memcmp((*output).data(), &p.ref[0], p.ref.size()) != 0)
676                 FAIL() << "Wrong result with compare TF reference!";
677         }
678         catch (const InferenceEngine::details::InferenceEngineException &e) {
679             FAIL() << e.what();
680         }
681     }
682 };
683
684 TEST_P(MKLDNNCPUExtGatherHolesTests, TestsGather) {}
685
686 INSTANTIATE_TEST_CASE_P(
687     TestsGather, MKLDNNCPUExtGatherHolesTests,
688     ::testing::Values(
689         gatherTF_test_params{ { 1, 5, 2, 2 }, in1,{ 1, 3, 2, 2 }, dict, 1,{ 2, 2, 2, 2 }, ref_in1_a0_d322 }));
690