Publishing R3
[platform/upstream/dldt.git] / inference-engine / tests / unit / engines / mkldnn / graph / layers / internal / graph_split_test.cpp
1 // Copyright (C) 2018 Intel Corporation
2 //
3 // SPDX-License-Identifier: Apache-2.0
4 //
5
6 #include <gtest/gtest.h>
7 #include <gmock/gmock-spec-builders.h>
8 #include "mkldnn_plugin/mkldnn_graph.h"
9 #include "mock_mkldnn_primitive.hpp"
10
11 #include "test_graph.hpp"
12
13 #include "single_layer_common.hpp"
14 #include <mkldnn_plugin/mkldnn_extension_utils.h>
15 #include <inference_engine/cnn_network_impl.hpp>
16 #include "tests_common.hpp"
17
18 using namespace ::testing;
19 using namespace std;
20 using namespace mkldnn;
21
22 struct split_test_params {
23     struct {
24         size_t n;
25         size_t c;
26         size_t h;
27         size_t w;
28     } in;
29
30     struct {
31         size_t n;
32         size_t c;
33         size_t h;
34         size_t w;
35     } out1;
36
37     struct {
38         size_t n;
39         size_t c;
40         size_t h;
41         size_t w;
42     } out2;
43
44     size_t num_prim_desc;
45
46     MKLDNNPlugin::impl_desc_type selectedType;
47     std::vector<MKLDNNPlugin::impl_desc_type> preferTypes;
48
49     std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
50 };
51
52 template <typename data_t>
53 void ref_split(InferenceEngine::TBlob<data_t> &src, InferenceEngine::TBlob<data_t> &dst1, InferenceEngine::TBlob<data_t> &dst2) {
54     const float * srcData = src.readOnly();
55
56     int MB = dst1.dims()[dst1.dims().size() - 1];
57
58     float * dstData1 = dst1.data();
59     int dstSize1 = dst1.size() / MB;
60
61     float *dstData2 = dst2.data();
62     int dstSize2 = dst2.size() / MB;
63
64     for (int b = 0; b < MB; b++) {
65         for (size_t j = 0; j < dstSize1; j++, srcData++) {
66             dstData1[b*dstSize1 + j] = *srcData;
67         }
68
69         for (size_t j = 0; j < dstSize2; j++, srcData++) {
70             dstData2[b*dstSize1 + j] = *srcData;
71         }
72     }
73 }
74
75 class MKLDNNGraphSplitTests: public TestsCommon,
76                               public WithParamInterface<split_test_params> {
77     // TODO: remove power layers from the test
78     std::string model_t = R"V0G0N(
79 <net name="ConcatOnly" version="2" precision="FP32" batch="1">
80     <layers>
81         <layer name="in1" type="Input" precision="FP32" id="1">
82             <output>
83                 <port id="1">
84                     <dim>_IN_</dim>
85                     <dim>_IC_</dim>
86                     <dim>_IH_</dim>
87                     <dim>_IW_</dim>
88                 </port>
89             </output>
90         </layer>
91         <layer name="split" id="2" type="Split" precision="FP32">
92             <split_data axis="1" PrimitivesPriority="_IMPLS_"/>
93             <input>
94                 <port id="1">
95                     <dim>_IN_</dim>
96                     <dim>_IC_</dim>
97                     <dim>_IH_</dim>
98                     <dim>_IW_</dim>
99                 </port>
100             </input>
101             <output>
102                 <port id="2">
103                     <dim>_ON1_</dim>
104                     <dim>_OC1_</dim>
105                     <dim>_OH1_</dim>
106                     <dim>_OW1_</dim>
107                 </port>
108                 <port id="3">
109                     <dim>_ON2_</dim>
110                     <dim>_OC2_</dim>
111                     <dim>_OH2_</dim>
112                     <dim>_OW2_</dim>
113                 </port>
114             </output>
115         </layer>
116         <layer name="power1" id="3" type="Power" precision="FP32">
117             <power_data power="1" scale="1" shift="0"/>
118             <input>
119                 <port id="1">
120                     <dim>_ON1_</dim>
121                     <dim>_OC1_</dim>
122                     <dim>_OH1_</dim>
123                     <dim>_OW1_</dim>
124                 </port>
125             </input>
126             <output>
127                 <port id="2">
128                     <dim>_ON1_</dim>
129                     <dim>_OC1_</dim>
130                     <dim>_OH1_</dim>
131                     <dim>_OW1_</dim>
132                 </port>
133             </output>
134         </layer>
135         <layer name="power2" id="4" type="Power" precision="FP32">
136             <power_data power="1" scale="1" shift="0"/>
137             <input>
138                 <port id="1">
139                     <dim>_ON2_</dim>
140                     <dim>_OC2_</dim>
141                     <dim>_OH2_</dim>
142                     <dim>_OW2_</dim>
143                 </port>
144             </input>
145             <output>
146                 <port id="2">
147                     <dim>_ON2_</dim>
148                     <dim>_OC2_</dim>
149                     <dim>_OH2_</dim>
150                     <dim>_OW2_</dim>
151                 </port>
152             </output>
153         </layer>
154     </layers>
155     <edges>
156         <edge from-layer="1" from-port="1" to-layer="2" to-port="1"/>
157         <edge from-layer="2" from-port="2" to-layer="3" to-port="1"/>
158         <edge from-layer="2" from-port="3" to-layer="4" to-port="1"/>
159     </edges>
160 </net>
161 )V0G0N";
162
163 protected:
164     std::string getModel(split_test_params p) {
165         std::string model = model_t;
166         REPLACE_WITH_NUM(model, "_IN_", p.in.n);
167         REPLACE_WITH_NUM(model, "_IC_", p.in.c);
168         REPLACE_WITH_NUM(model, "_IW_", p.in.w);
169         REPLACE_WITH_NUM(model, "_IH_", p.in.h);
170
171         REPLACE_WITH_NUM(model, "_ON1_", p.out1.n);
172         REPLACE_WITH_NUM(model, "_OC1_", p.out1.c);
173         REPLACE_WITH_NUM(model, "_OH1_", p.out1.h);
174         REPLACE_WITH_NUM(model, "_OW1_", p.out1.w);
175
176         REPLACE_WITH_NUM(model, "_ON2_", p.out2.n);
177         REPLACE_WITH_NUM(model, "_OC2_", p.out2.c);
178         REPLACE_WITH_NUM(model, "_OH2_", p.out2.h);
179         REPLACE_WITH_NUM(model, "_OW2_", p.out2.w);
180         std::string impls;
181         for (const auto& preferType : p.preferTypes) {
182             if (!impls.empty())
183                 impls += ",";
184             impls += "cpu:" + MKLDNNGraphTestClass::getStrPrimitiveDescriptorType(preferType);
185         }
186         REPLACE_WITH_STR(model, "_IMPLS_", impls);
187         return model;
188     }
189
190     virtual void TearDown() {
191     }
192
193     virtual void SetUp() {
194         try {
195             split_test_params p = ::testing::WithParamInterface<split_test_params>::GetParam();
196             std::string model = getModel(p);
197
198             InferenceEngine::CNNNetReader net_reader;
199             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
200
201             MKLDNNGraphTestClass graph;
202             graph.CreateGraph(net_reader.getNetwork());
203             auto& nodes = graph.getNodes();
204             for (int i = 0; i < nodes.size(); i++) {
205                 if (nodes[i]->getType() == MKLDNNPlugin::Split) {
206                     ASSERT_EQ(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
207                     for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
208                         p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
209                     }
210                     ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
211                     ASSERT_EQ(p.selectedType, nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType());
212                 }
213             }
214             ASSERT_LE(3, nodes.size());
215
216             InferenceEngine::SizeVector dims_src = {p.in.n, p.in.c, p.in.h, p.in.w};
217
218             InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
219             src->allocate();
220             fill_data(src->buffer(), src->size());
221
222             InferenceEngine::BlobMap srcs;
223             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
224
225             InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
226
227             if (srcPtr == nullptr)
228                 FAIL() << "Cannot cast blob to TBlob<float>.";
229
230             InferenceEngine::OutputsDataMap out;
231             out = net_reader.getNetwork().getOutputsInfo();
232             InferenceEngine::BlobMap outputBlobs;
233             auto it = out.begin();
234
235             std::pair<std::string, InferenceEngine::DataPtr> item = *it;
236
237             InferenceEngine::TBlob<float>::Ptr output1;
238             output1 = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
239             output1->allocate();
240             outputBlobs[item.first] = output1;
241
242             InferenceEngine::TBlob<float> dst_ref1(item.second->getTensorDesc());
243             dst_ref1.allocate();
244
245             item = *(++it);
246             InferenceEngine::TBlob<float>::Ptr output2;
247             output2 = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
248             output2->allocate();
249             outputBlobs[item.first] = output2;
250
251             InferenceEngine::TBlob<float> dst_ref2(item.second->getTensorDesc());
252             dst_ref2.allocate();
253
254             graph.Infer(srcs, outputBlobs);
255
256             ref_split(*srcPtr, dst_ref1, dst_ref2);
257
258             compare(*output1, dst_ref1);
259             compare(*output2, dst_ref2);
260         } catch (const InferenceEngine::details::InferenceEngineException &e) {
261             FAIL() << e.what();
262         }
263     }
264 };
265
266 TEST_P(MKLDNNGraphSplitTests, TestsSplit) {}
267
268 INSTANTIATE_TEST_CASE_P(
269         TestsSplit, MKLDNNGraphSplitTests,
270         ::testing::Values(
271                 split_test_params {
272                         {1, 24, 2, 5},
273                         {1, 16, 2, 5},
274                         {1, 8, 2, 5},
275                         3, MKLDNNPlugin::impl_desc_type::unknown, {}, {
276                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
277                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::ref, impl.getImplementationType());
278                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
279                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
280                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().inConfs.at(0).desc.getLayout());
281                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(0).desc.getLayout());
282                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(1).desc.getLayout());
283                                 },
284                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
285                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
286                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
287                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
288                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
289                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
290                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(1).desc.getLayout());
291                                 },
292                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
293                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
294                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
295                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
296                                     ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().inConfs.at(0).desc.getLayout());
297                                     ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().outConfs.at(0).desc.getLayout());
298                                     ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().outConfs.at(1).desc.getLayout());
299                                 }
300                         }
301                 },
302                 split_test_params {
303                         {1, 20, 2, 5},
304                         {1, 13, 2, 5},
305                         {1, 7, 2, 5},
306                         2, MKLDNNPlugin::impl_desc_type::unknown, {}, {
307                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
308                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::ref, impl.getImplementationType());
309                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
310                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
311                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().inConfs.at(0).desc.getLayout());
312                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(0).desc.getLayout());
313                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(1).desc.getLayout());
314                                 },
315                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
316                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
317                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
318                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
319                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
320                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
321                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(1).desc.getLayout());
322                                 }
323                         }
324                 },
325                 split_test_params {
326                         {1, 20, 2, 5},
327                         {1, 10, 2, 5},
328                         {1, 10, 2, 5},
329                         2, MKLDNNPlugin::impl_desc_type::unknown, {}, {
330                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
331                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::ref, impl.getImplementationType());
332                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
333                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
334                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().inConfs.at(0).desc.getLayout());
335                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(0).desc.getLayout());
336                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(1).desc.getLayout());
337                                 },
338                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
339                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
340                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
341                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
342                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
343                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
344                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(1).desc.getLayout());
345                                 }
346                         }
347                 },
348                 split_test_params {
349                         {2, 20, 2, 5},
350                         {2, 10, 2, 5},
351                         {2, 10, 2, 5},
352                         2, MKLDNNPlugin::impl_desc_type::unknown, {}, {
353                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
354                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::ref, impl.getImplementationType());
355                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
356                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
357                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().inConfs.at(0).desc.getLayout());
358                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(0).desc.getLayout());
359                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(1).desc.getLayout());
360                                 },
361                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
362                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
363                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
364                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
365                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
366                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
367                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(1).desc.getLayout());
368                                 }
369                         }
370                 },
371                 split_test_params {
372                         {1, 24, 2, 5},
373                         {1, 16, 2, 5},
374                         {1, 8, 2, 5},
375                         3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
376                 },
377                 split_test_params {
378                         {1, 20, 2, 5},
379                         {1, 13, 2, 5},
380                         {1, 7, 2, 5},
381                         2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
382                 },
383                 split_test_params {
384                         {1, 20, 2, 5},
385                         {1, 10, 2, 5},
386                         {1, 10, 2, 5},
387                         2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
388                 },
389                 split_test_params {
390                         {2, 20, 2, 5},
391                         {2, 10, 2, 5},
392                         {2, 10, 2, 5},
393                         2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}}));
394
395 class MKLDNNGraphDynBatchSplitTests: public MKLDNNGraphSplitTests {
396 protected:
397     virtual void SetUp() {
398         try {
399             split_test_params p = ::testing::WithParamInterface<split_test_params>::GetParam();
400             std::string model = getModel(p);
401             size_t MB = p.in.n;
402             if (MB < 2)
403                 MB = 2;
404
405             InferenceEngine::CNNNetReader net_reader;
406             ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
407             InferenceEngine::CNNNetwork network = net_reader.getNetwork();
408             auto implNet = dynamic_cast<InferenceEngine::details::CNNNetworkImpl *>(&((InferenceEngine::ICNNNetwork&)network));
409             ASSERT_NE(nullptr, implNet) << "Failed to cast ICNNNetwork to CNNNetworkImpl";
410             InferenceEngine::ResponseDesc resp;
411             InferenceEngine::StatusCode sts  = implNet->setBatchSizeReshape(MB, &resp);
412             ASSERT_EQ((int)InferenceEngine::StatusCode::OK, sts) << resp.msg;
413
414             MKLDNNGraphTestClass graph;
415             graph.setProperty({{InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_ENABLED, InferenceEngine::PluginConfigParams::YES}});
416             graph.CreateGraph(net_reader.getNetwork());
417
418             InferenceEngine::SizeVector dims_src = {MB, p.in.c, p.in.h, p.in.w};
419
420             InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, InferenceEngine::NCHW, dims_src);
421             src->allocate();
422             fill_data(src->buffer(), src->size());
423
424             InferenceEngine::BlobMap srcs;
425             srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
426
427             InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
428
429             if (srcPtr == nullptr)
430                 FAIL() << "Cannot cast blob to TBlob<float>.";
431
432             InferenceEngine::OutputsDataMap out;
433             out = net_reader.getNetwork().getOutputsInfo();
434             InferenceEngine::BlobMap outputBlobs;
435             auto it = out.begin();
436
437             std::pair<std::string, InferenceEngine::DataPtr> item = *it;
438
439             InferenceEngine::TBlob<float>::Ptr output1;
440             output1 = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
441             output1->allocate();
442             outputBlobs[item.first] = output1;
443
444             item = *(++it);
445             InferenceEngine::TBlob<float>::Ptr output2;
446             output2 = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
447             output2->allocate();
448             outputBlobs[item.first] = output2;
449
450             auto checkSplit = [](const MKLDNNPlugin::MKLDNNNodePtr& node) {
451                 return node->getType() == MKLDNNPlugin::Split;
452             };
453
454             graph.checkDynBatch(srcs, outputBlobs, MB, MB, checkSplit);
455             graph.checkDynBatch(srcs, outputBlobs, 1, MB, checkSplit);
456         } catch (const InferenceEngine::details::InferenceEngineException &e) {
457             FAIL() << e.what();
458         }
459     }
460 };
461
462 TEST_P(MKLDNNGraphDynBatchSplitTests, TestsDynBatchSplit) {}
463
464 INSTANTIATE_TEST_CASE_P(
465         TestsDynBatchSplit, MKLDNNGraphDynBatchSplitTests,
466         ::testing::Values(
467                 split_test_params {
468                         {1, 24, 2, 5},
469                         {1, 16, 2, 5},
470                         {1, 8, 2, 5},
471                         3, MKLDNNPlugin::impl_desc_type::unknown, {}, {
472                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
473                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::ref, impl.getImplementationType());
474                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
475                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
476                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().inConfs.at(0).desc.getLayout());
477                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(0).desc.getLayout());
478                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(1).desc.getLayout());
479                                 },
480                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
481                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
482                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
483                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
484                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
485                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
486                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(1).desc.getLayout());
487                                 },
488                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
489                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
490                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
491                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
492                                     ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().inConfs.at(0).desc.getLayout());
493                                     ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().outConfs.at(0).desc.getLayout());
494                                     ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().outConfs.at(1).desc.getLayout());
495                                 }
496                         }
497                 },
498                 split_test_params {
499                         {1, 20, 2, 5},
500                         {1, 13, 2, 5},
501                         {1, 7, 2, 5},
502                         2, MKLDNNPlugin::impl_desc_type::unknown, {}, {
503                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
504                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::ref, impl.getImplementationType());
505                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
506                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
507                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().inConfs.at(0).desc.getLayout());
508                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(0).desc.getLayout());
509                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(1).desc.getLayout());
510                                 },
511                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
512                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
513                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
514                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
515                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
516                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
517                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(1).desc.getLayout());
518                                 }
519                         }
520                 },
521                 split_test_params {
522                         {1, 20, 2, 5},
523                         {1, 10, 2, 5},
524                         {1, 10, 2, 5},
525                         2, MKLDNNPlugin::impl_desc_type::unknown, {}, {
526                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
527                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::ref, impl.getImplementationType());
528                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
529                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
530                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().inConfs.at(0).desc.getLayout());
531                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(0).desc.getLayout());
532                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(1).desc.getLayout());
533                                 },
534                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
535                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
536                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
537                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
538                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
539                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
540                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(1).desc.getLayout());
541                                 }
542                         }
543                 },
544                 split_test_params {
545                         {2, 20, 2, 5},
546                         {2, 10, 2, 5},
547                         {2, 10, 2, 5},
548                         2, MKLDNNPlugin::impl_desc_type::unknown, {}, {
549                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
550                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::ref, impl.getImplementationType());
551                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
552                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
553                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().inConfs.at(0).desc.getLayout());
554                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(0).desc.getLayout());
555                                     ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(1).desc.getLayout());
556                                 },
557                                 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
558                                     ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
559                                     ASSERT_EQ(1, impl.getConfig().inConfs.size());
560                                     ASSERT_EQ(2, impl.getConfig().outConfs.size());
561                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
562                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
563                                     ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(1).desc.getLayout());
564                                 }
565                         }
566                 },
567                 split_test_params {
568                         {1, 24, 2, 5},
569                         {1, 16, 2, 5},
570                         {1, 8, 2, 5},
571                         3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
572                 },
573                 split_test_params {
574                         {1, 20, 2, 5},
575                         {1, 13, 2, 5},
576                         {1, 7, 2, 5},
577                         2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
578                 },
579                 split_test_params {
580                         {1, 20, 2, 5},
581                         {1, 10, 2, 5},
582                         {1, 10, 2, 5},
583                         2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
584                 },
585                 split_test_params {
586                         {2, 20, 2, 5},
587                         {2, 10, 2, 5},
588                         {2, 10, 2, 5},
589                         2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}}));