Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / tests / unit / engines / mkldnn / graph / layers / internal / graph_split_test.cpp
1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include "mkldnn_plugin/mkldnn_graph.h"
8
9 #include "test_graph.hpp"
10
11 #include "single_layer_common.hpp"
12 #include <mkldnn_plugin/mkldnn_extension_utils.h>
13 #include <inference_engine/cnn_network_impl.hpp>
14 #include "tests_common.hpp"
15
16 using namespace ::testing;
17 using namespace std;
18 using namespace mkldnn;
19
20 struct split_test_params {
21     // Formats: NCHW, NCDHW
22     vector<size_t> dims;
23     std::vector<vector<size_t>> outs;
24
25     int axis;
26
27     size_t num_prim_desc;
28
29     MKLDNNPlugin::impl_desc_type selectedType;
30     std::vector<MKLDNNPlugin::impl_desc_type> preferTypes;
31
32     std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
33 };
34
35 template <typename data_t>
36 void ref_split(InferenceEngine::TBlob<data_t> &src, std::vector<InferenceEngine::TBlob<data_t>>& dsts, split_test_params& prm) {
37     const float * srcData = src.readOnly();
38
39     int outerSize = 1;
40     for (int i = 0; i < prm.axis; i++)
41         outerSize *= src.dims()[i];
42
43     for (size_t osIdx = 0; osIdx < outerSize; osIdx++) {
44         for (size_t dstIdx = 0; dstIdx < dsts.size(); dstIdx++) {
45             float* dstData = dsts[dstIdx].data();
46             int innerSize = dsts[dstIdx].size() / outerSize;
47
48             for (size_t j = 0; j < innerSize; j++, srcData++) {
49                 dstData[osIdx*innerSize + j] = *srcData;
50             }
51         }
52     }
53 }
54
55 class MKLDNNGraphSplitTests: public TestsCommon,
56                               public WithParamInterface<split_test_params> {
57     std::string model_t = R"V0G0N(
58 <net name="ConcatOnly" version="3" precision="FP32" batch="1">
59     <layers>
60         <layer name="in1" type="Input" precision="FP32" id="1">
61             <output>
62                 <port id="1">
63                     <dim>_IN_</dim>
64                     <dim>_IC_</dim>
65                     <dim>_ID_</dim>
66                     <dim>_IH_</dim>
67                     <dim>_IW_</dim>
68                 </port>
69             </output>
70         </layer>
71         <layer name="split" id="2" type="Split" precision="FP32">
72             <split_data axis="_AXIS_" PrimitivesPriority="_IMPLS_"/>
73             <input>
74                 <port id="1">
75                     <dim>_IN_</dim>
76                     <dim>_IC_</dim>
77                     <dim>_ID_</dim>
78                     <dim>_IH_</dim>
79                     <dim>_IW_</dim>
80                 </port>
81             </input>
82             <output>
83                 _OP_
84             </output>
85         </layer>
86     </layers>
87     <edges>
88         <edge from-layer="1" from-port="1" to-layer="2" to-port="1"/>
89     </edges>
90 </net>
91 )V0G0N";
92
93     std::string port_t = R"V0G0N(
94 <port id="_ID_">
95     <dim>_N_</dim>
96     <dim>_C_</dim>
97     <dim>_D_</dim>
98     <dim>_H_</dim>
99     <dim>_W_</dim>
100 </port>
101 )V0G0N";
102
103 protected:
104     std::string getModel(split_test_params p) {
105         std::string model = model_t;
106         auto dims_size = p.dims.size();
107
108         switch (dims_size) {
109             case 3:
110                 REMOVE_LINE(model, "<dim>_IH_</dim>");
111             case 4:
112                 REMOVE_LINE(model, "<dim>_ID_</dim>");
113         }
114         REPLACE_WITH_NUM(model, "_IN_", p.dims[0]);
115         REPLACE_WITH_NUM(model, "_IC_", p.dims[1]);
116         REPLACE_WITH_NUM(model, "_IW_", p.dims[dims_size - 1]);
117         switch (dims_size) {
118             case 5:
119                 REPLACE_WITH_NUM(model, "_ID_", p.dims[dims_size - 3]);
120             case 4:
121                 REPLACE_WITH_NUM(model, "_IH_", p.dims[dims_size - 2]);
122         }
123
124         std::string outPorts;
125         for (int idx = 0; idx < p.outs.size(); idx++) {
126             std::string outPort = port_t;
127             switch (dims_size) {
128                 case 3:
129                     REMOVE_LINE(outPort, "<dim>_H_</dim>");
130                 case 4:
131                     REMOVE_LINE(outPort, "<dim>_D_</dim>");
132             }
133             REPLACE_WITH_NUM(outPort, "_ID_", idx);
134             REPLACE_WITH_NUM(outPort, "_N_", p.outs[idx][0]);
135             REPLACE_WITH_NUM(outPort, "_C_", p.outs[idx][1]);
136             REPLACE_WITH_NUM(outPort, "_W_", p.outs[idx][dims_size - 1]);
137             switch (dims_size) {
138                 case 5:
139                     REPLACE_WITH_NUM(outPort, "_D_", p.outs[idx][dims_size - 3]);
140                 case 4:
141                     REPLACE_WITH_NUM(outPort, "_H_", p.outs[idx][dims_size - 2]);
142             }
143
144             outPorts += outPort;
145         }
146         REPLACE_WITH_STR(model, "_OP_", outPorts);
147
148         REPLACE_WITH_NUM(model, "_AXIS_", p.axis);
149
150         std::string impls;
151         for (const auto& preferType : p.preferTypes) {
152             if (!impls.empty())
153                 impls += ",";
154             impls += "cpu:" + MKLDNNGraphTestClass::getStrPrimitiveDescriptorType(preferType);
155         }
156         REPLACE_WITH_STR(model, "_IMPLS_", impls);
157         return model;
158     }
159
160     virtual void TearDown() {
161     }
162
163     virtual void SetUp() {
164         try {
165             split_test_params p = ::testing::WithParamInterface<split_test_params>::GetParam();
166             std::string model = getModel(p);
167
168             InferenceEngine::CNNNetReader net_reader;
169             net_reader.ReadNetwork(model.data(), model.length());
170
171             MKLDNNGraphTestClass graph;
172             graph.CreateGraph(net_reader.getNetwork());
173             auto& nodes = graph.getNodes();
174             for (int i = 0; i < nodes.size(); i++) {
175                 if (nodes[i]->getType() == MKLDNNPlugin::Split) {
176                     ASSERT_EQ(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
177                     for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
178                         p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
179                     }
180                     ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
181                     ASSERT_EQ(p.selectedType, nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType());
182                 }
183             }
184             ASSERT_LE(3, nodes.size());
185
186             InferenceEngine::SizeVector dims_src = p.dims;
187             InferenceEngine::Layout layout = InferenceEngine::ANY;
188             switch (p.dims.size()) {
189                 case 4:
190                     layout = InferenceEngine::NCHW;
191                     break;
192                 case 5:
193                     layout = InferenceEngine::NCDHW;
194                     break;
195             }
196
197             InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, layout, dims_src);
198             src->allocate();
199             fill_data(src->buffer(), src->size());
200
201             InferenceEngine::BlobMap srcs;
202             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
203
204             auto srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
205
206             if (srcPtr == nullptr)
207                 FAIL() << "Cannot cast blob to TBlob<float>.";
208
209             InferenceEngine::OutputsDataMap out;
210             out = net_reader.getNetwork().getOutputsInfo();
211             InferenceEngine::BlobMap outputBlobs;
212             std::vector<InferenceEngine::TBlob<float>> dst_refs;
213             for (auto& item : out) {
214                 InferenceEngine::TBlob<float>::Ptr output;
215                 output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
216                 output->allocate();
217                 outputBlobs[item.first] = output;
218
219                 InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
220                 dst_ref.allocate();
221                 dst_refs.push_back(dst_ref);
222             }
223
224             graph.Infer(srcs, outputBlobs);
225
226             ref_split(*srcPtr, dst_refs, p);
227
228             int ref_idx = 0;
229             for (auto& output : outputBlobs) {
230                 compare(*output.second, dst_refs[ref_idx++], 0.0005f);
231             }
232         } catch (const InferenceEngine::details::InferenceEngineException &e) {
233             FAIL() << e.what();
234         }
235     }
236 };
237
238 TEST_P(MKLDNNGraphSplitTests, TestsSplit) {}
239
240 INSTANTIATE_TEST_CASE_P(
241         TestsSplit, MKLDNNGraphSplitTests,
242         ::testing::Values(
243                 split_test_params {
244                         {1, 24, 2, 5},
245                         {{1, 16, 2, 5}, {1, 8, 2, 5}},
246                         1, 3, MKLDNNPlugin::impl_desc_type::unknown, {}, {
247                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
248                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::ref, impl.getImplementationType());
249                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
250                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
251                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().inConfs.at(0).desc.getLayout());
252                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(0).desc.getLayout());
253                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(1).desc.getLayout());
254                                 },
255                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
256                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
257                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
258                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
259                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
260                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
261                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(1).desc.getLayout());
262                                 },
263                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
264                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
265                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
266                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
267                                     ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().inConfs.at(0).desc.getLayout());
268                                     ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().outConfs.at(0).desc.getLayout());
269                                     ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().outConfs.at(1).desc.getLayout());
270                                 }
271                         }
272                 },
273                 split_test_params {
274                         {1, 20, 2, 5},
275                         {{1, 13, 2, 5}, {1, 7, 2, 5}},
276                         1, 2, MKLDNNPlugin::impl_desc_type::unknown, {}, {
277                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
278                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::ref, impl.getImplementationType());
279                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
280                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
281                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().inConfs.at(0).desc.getLayout());
282                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(0).desc.getLayout());
283                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(1).desc.getLayout());
284                                 },
285                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
286                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
287                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
288                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
289                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
290                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
291                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(1).desc.getLayout());
292                                 }
293                         }
294                 },
295                 split_test_params {
296                         {1, 20, 2, 5},
297                         {{1, 10, 2, 5}, {1, 10, 2, 5}},
298                         1, 2, MKLDNNPlugin::impl_desc_type::unknown, {}, {
299                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
300                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::ref, impl.getImplementationType());
301                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
302                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
303                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().inConfs.at(0).desc.getLayout());
304                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(0).desc.getLayout());
305                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(1).desc.getLayout());
306                                 },
307                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
308                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
309                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
310                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
311                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
312                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
313                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(1).desc.getLayout());
314                                 }
315                         }
316                 },
317                 split_test_params {
318                         {2, 20, 2, 5},
319                         {{2, 10, 2, 5}, {2, 10, 2, 5}},
320                         1, 2, MKLDNNPlugin::impl_desc_type::unknown, {}, {
321                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
322                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::ref, impl.getImplementationType());
323                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
324                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
325                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().inConfs.at(0).desc.getLayout());
326                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(0).desc.getLayout());
327                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(1).desc.getLayout());
328                                 },
329                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
330                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
331                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
332                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
333                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
334                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
335                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(1).desc.getLayout());
336                                 }
337                         }
338                 },
339                 split_test_params {
340                         {1, 24, 2, 5},
341                         {{1, 16, 2, 5}, {1, 8, 2, 5}},
342                         1, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
343                 },
344                 split_test_params {
345                         {1, 20, 2, 5},
346                         {{1, 13, 2, 5}, {1, 7, 2, 5}},
347                         1, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
348                 },
349                 split_test_params {
350                         {1, 20, 2, 5},
351                         {{1, 10, 2, 5}, {1, 10, 2, 5}},
352                         1, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
353                 },
354                 split_test_params {
355                         {2, 20, 2, 5},
356                         {{2, 10, 2, 5}, {2, 10, 2, 5}},
357                         1, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
358                 },
359                 split_test_params {
360                         {2, 20, 2, 5},
361                         {{2, 15, 2, 5}, {2,  5, 2, 5}},
362                         1, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
363                 },
364                 split_test_params {
365                         {9, 11, 7, 5},
366                         {{3, 11, 7, 5}, {6, 11, 7, 5}},
367                         0, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
368                 },
369                 split_test_params {
370                         {3, 11, 7, 5},
371                         {{3, 11, 4, 5}, {3, 11, 3, 5}},
372                         2, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
373                 },
374                 split_test_params {
375                         {3, 11, 7, 5},
376                         {{3, 11, 7, 1}, {3, 11, 7, 4}},
377                         3, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
378                 },
379                 split_test_params {
380                         {5, 6, 7, 15},
381                         {{1, 6, 7, 15}, {2, 6, 7, 15}, {1, 6, 7, 15}, {1, 6, 7, 15}},
382                         0, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
383                 },
384                 split_test_params {
385                         {5, 6, 7, 15},
386                         {{5, 1, 7, 15}, {5, 2, 7, 15}, {5, 1, 7, 15}, {5, 2, 7, 15}},
387                         1, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
388                 },
389                 split_test_params {
390                         {5, 6, 7, 15},
391                         {{5, 6, 3, 15}, {5, 6, 1, 15}, {5, 6, 2, 15}, {5, 6, 1, 15}},
392                         2, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
393                 },
394                 split_test_params {
395                         {5, 6, 7, 15},
396                         {{5, 6, 7, 5}, {5, 6, 7, 3}, {5, 6, 7, 4}, {5, 6, 7, 3}},
397                         3, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
398                 },
399                 split_test_params {
400                         {5, 6, 7, 15},
401                         {{5, 6, 7, 15}},
402                         1, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}},
403                 split_test_params {
404                         {1, 32, 16, 16, 16},
405                         {{1, 8, 16, 16, 16}, {1, 8, 16, 16, 16}, {1, 8, 16, 16, 16}, {1, 8, 16, 16, 16}},
406                         1, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}},
407                 split_test_params {
408                         {1, 32, 16, 16, 16},
409                         {{1, 8, 16, 16, 16}, {1, 8, 16, 16, 16}, {1, 8, 16, 16, 16}, {1, 8, 16, 16, 16}},
410                         1, 3, MKLDNNPlugin::impl_desc_type::unknown, {}}));
411
412 class MKLDNNGraphDynBatchSplitTests: public MKLDNNGraphSplitTests {
413 protected:
414     virtual void SetUp() {
415         try {
416             split_test_params p = ::testing::WithParamInterface<split_test_params>::GetParam();
417             std::string model = getModel(p);
418             size_t MB = p.dims[0];
419             if (MB < 2)
420                 MB = 2;
421
422             InferenceEngine::CNNNetReader net_reader;
423             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
424             InferenceEngine::CNNNetwork network = net_reader.getNetwork();
425             auto implNet = dynamic_cast<InferenceEngine::details::CNNNetworkImpl *>(&((InferenceEngine::ICNNNetwork&)network));
426             ASSERT_NE(nullptr, implNet) << "Failed to cast ICNNNetwork to CNNNetworkImpl";
427             InferenceEngine::ResponseDesc resp;
428             InferenceEngine::StatusCode sts  = implNet->setBatchSizeReshape(MB, &resp);
429             ASSERT_EQ((int)InferenceEngine::StatusCode::OK, sts) << resp.msg;
430
431             MKLDNNGraphTestClass graph;
432             graph.setProperty({{InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_ENABLED, InferenceEngine::PluginConfigParams::YES}});
433             graph.CreateGraph(net_reader.getNetwork());
434
435             InferenceEngine::SizeVector dims_src = p.dims;
436             InferenceEngine::Layout layout = InferenceEngine::ANY;
437             switch (p.dims.size()) {
438                 case 4:
439                     layout = InferenceEngine::NCHW;
440                     break;
441                 case 5:
442                     layout = InferenceEngine::NCDHW;
443                     break;
444             }
445
446             InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, layout, dims_src);
447             src->allocate();
448             fill_data(src->buffer(), src->size());
449
450             InferenceEngine::BlobMap srcs;
451             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
452
453             InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
454
455             if (srcPtr == nullptr)
456                 FAIL() << "Cannot cast blob to TBlob<float>.";
457
458             InferenceEngine::OutputsDataMap out;
459             out = net_reader.getNetwork().getOutputsInfo();
460             InferenceEngine::BlobMap outputBlobs;
461             auto it = out.begin();
462
463             std::pair<std::string, InferenceEngine::DataPtr> item = *it;
464
465             InferenceEngine::TBlob<float>::Ptr output1;
466             output1 = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
467             output1->allocate();
468             outputBlobs[item.first] = output1;
469
470             item = *(++it);
471             InferenceEngine::TBlob<float>::Ptr output2;
472             output2 = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
473             output2->allocate();
474             outputBlobs[item.first] = output2;
475
476             auto checkSplit = [](const MKLDNNPlugin::MKLDNNNodePtr& node) {
477                 return node->getType() == MKLDNNPlugin::Split;
478             };
479
480             graph.checkDynBatch(srcs, outputBlobs, MB, MB, checkSplit);
481             graph.checkDynBatch(srcs, outputBlobs, 1, MB, checkSplit);
482         } catch (const InferenceEngine::details::InferenceEngineException &e) {
483             FAIL() << e.what();
484         }
485     }
486 };
487
488 TEST_P(MKLDNNGraphDynBatchSplitTests, TestsDynBatchSplit) {}
489
490 INSTANTIATE_TEST_CASE_P(
491         TestsDynBatchSplit, MKLDNNGraphDynBatchSplitTests,
492         ::testing::Values(
493                 split_test_params {
494                         {1, 24, 2, 5},
495                         {{1, 16, 2, 5}, {1, 8, 2, 5}},
496                         1, 3, MKLDNNPlugin::impl_desc_type::unknown, {}, {
497                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
498                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::ref, impl.getImplementationType());
499                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
500                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
501                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().inConfs.at(0).desc.getLayout());
502                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(0).desc.getLayout());
503                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(1).desc.getLayout());
504                                 },
505                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
506                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
507                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
508                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
509                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
510                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
511                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(1).desc.getLayout());
512                                 },
513                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
514                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
515                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
516                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
517                                     ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().inConfs.at(0).desc.getLayout());
518                                     ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().outConfs.at(0).desc.getLayout());
519                                     ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().outConfs.at(1).desc.getLayout());
520                                 }
521                         }
522                 },
523                 split_test_params {
524                         {1, 20, 2, 5},
525                         {{1, 13, 2, 5}, {1, 7, 2, 5}},
526                         1, 2, MKLDNNPlugin::impl_desc_type::unknown, {}, {
527                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
528                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::ref, impl.getImplementationType());
529                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
530                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
531                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().inConfs.at(0).desc.getLayout());
532                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(0).desc.getLayout());
533                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(1).desc.getLayout());
534                                 },
535                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
536                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
537                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
538                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
539                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
540                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
541                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(1).desc.getLayout());
542                                 }
543                         }
544                 },
545                 split_test_params {
546                         {1, 20, 2, 5},
547                         {{1, 10, 2, 5}, {1, 10, 2, 5}},
548                         1, 2, MKLDNNPlugin::impl_desc_type::unknown, {}, {
549                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
550                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::ref, impl.getImplementationType());
551                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
552                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
553                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().inConfs.at(0).desc.getLayout());
554                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(0).desc.getLayout());
555                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(1).desc.getLayout());
556                                 },
557                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
558                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
559                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
560                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
561                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
562                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
563                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(1).desc.getLayout());
564                                 }
565                         }
566                 },
567                 split_test_params {
568                         {2, 20, 2, 5},
569                         {{2, 10, 2, 5}, {2, 10, 2, 5}},
570                         1, 2, MKLDNNPlugin::impl_desc_type::unknown, {}, {
571                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
572                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::ref, impl.getImplementationType());
573                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
574                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
575                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().inConfs.at(0).desc.getLayout());
576                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(0).desc.getLayout());
577                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(1).desc.getLayout());
578                                 },
579                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
580                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
581                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
582                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
583                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
584                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
585                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(1).desc.getLayout());
586                                 }
587                         }
588                 },
589                 split_test_params {
590                         {2, 24, 2, 5},
591                         {{2, 16, 2, 5}, {2, 8, 2, 5}},
592                         1, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
593                 },
594                 split_test_params {
595                         {1, 20, 2, 5},
596                         {{1, 13, 2, 5}, {1, 7, 2, 5}},
597                         1, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
598                 },
599                 split_test_params {
600                         {1, 20, 2, 5},
601                         {{1, 10, 2, 5}, {1, 10, 2, 5}},
602                         1, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
603                 },
604                 split_test_params {
605                         {2, 20, 2, 5},
606                         {{2, 10, 2, 5}, {2, 10, 2, 5}},
607                         1, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
608                 },
609                 split_test_params {
610                         {2, 20, 2, 5},
611                         {{2, 15, 2, 5}, {2,  5, 2, 5}},
612                         1, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
613                 },
614                 split_test_params {
615                         {3, 11, 7, 5},
616                         {{3, 11, 4, 5}, {3, 11, 3, 5}},
617                         2, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
618                 },
619                 split_test_params {
620                         {3, 11, 7, 5},
621                         {{3, 11, 7, 1}, {3, 11, 7, 4}},
622                         3, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
623                 },
624                 split_test_params {
625                         {5, 6, 7, 15},
626                         {{5, 1, 7, 15}, {5, 2, 7, 15}, {5, 1, 7, 15}, {5, 2, 7, 15}},
627                         1, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
628                 },
629                 split_test_params {
630                         {5, 6, 7, 15},
631                         {{5, 6, 3, 15}, {5, 6, 1, 15}, {5, 6, 2, 15}, {5, 6, 1, 15}},
632                         2, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
633                 },
634                 split_test_params {
635                         {5, 6, 7, 15},
636                         {{5, 6, 7, 5}, {5, 6, 7, 3}, {5, 6, 7, 4}, {5, 6, 7, 3}},
637                         3, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}}));