1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
6 #include <gmock/gmock-spec-builders.h>
7 #include "mkldnn_plugin/mkldnn_graph.h"
9 #include "test_graph.hpp"
11 #include "single_layer_common.hpp"
12 #include <mkldnn_plugin/mkldnn_extension_utils.h>
13 #include <inference_engine/cnn_network_impl.hpp>
14 #include "tests_common.hpp"
16 using namespace ::testing;
18 using namespace mkldnn;
// Parameters describing one Split test case.
// NOTE(review): only part of this struct is visible in this view. The tests
// below also read p.dims, p.axis and p.num_prim_desc, whose declarations (and
// the closing brace of the struct) are on lines not shown here - confirm
// against the full file.
20 struct split_test_params {
21 // Formats: NCHW, NCDHW
// Dimensions of the Split outputs; outs[k] holds the dims of output port k.
23 std::vector<vector<size_t>> outs;
// Implementation type the selected primitive descriptor is expected to report.
29 MKLDNNPlugin::impl_desc_type selectedType;
// Implementation types rendered into the PrimitivesPriority attribute of the layer.
30 std::vector<MKLDNNPlugin::impl_desc_type> preferTypes;
// Checks for the supported primitive descriptors; comp[j] is called with descriptor j.
32 std::vector<std::function<void(MKLDNNPlugin::PrimitiveDescInfo)>> comp;
// Reference implementation of Split used to validate the plugin output.
// The source buffer is read strictly sequentially: for every outer index and
// for every destination in turn, innerSize consecutive source elements are
// copied into that destination at offset osIdx * innerSize.
//   src  - input blob
//   dsts - destination blobs, filled in place
//   prm  - test parameters; only prm.axis is read in the visible lines
// NOTE(review): srcData and dstData are typed as float pointers although the
// function is templated on data_t - confirm that only float is ever used.
// NOTE(review): the declaration of outerSize and the closing braces are on
// lines not visible in this view.
35 template <typename data_t>
36 void ref_split(InferenceEngine::TBlob<data_t> &src, std::vector<InferenceEngine::TBlob<data_t>>& dsts, split_test_params& prm) {
37 const float * srcData = src.readOnly();
// outerSize accumulates the product of src.dims()[i] for i in [0, prm.axis).
40 for (int i = 0; i < prm.axis; i++)
41 outerSize *= src.dims()[i];
43 for (size_t osIdx = 0; osIdx < outerSize; osIdx++) {
44 for (size_t dstIdx = 0; dstIdx < dsts.size(); dstIdx++) {
45 float* dstData = dsts[dstIdx].data();
// Number of elements this destination receives per outer index.
46 int innerSize = dsts[dstIdx].size() / outerSize;
// srcData advances together with j, so it is never reset between destinations.
48 for (size_t j = 0; j < innerSize; j++, srcData++) {
49 dstData[osIdx*innerSize + j] = *srcData;
// Value-parameterized fixture for the Split layer.
// For each split_test_params value it renders an IR XML from the model_t and
// port_t templates, builds an MKLDNNGraphTestClass graph from it, checks the
// primitive descriptors of the Split node, runs inference and compares every
// output with ref_split.
// model_t is the network template (one Input layer feeding one Split layer);
// port_t is the template from which each output port is rendered in getModel.
// NOTE(review): several lines of this class (access specifiers, switch cases,
// the try opener, closing braces and most of the XML) are not visible here.
55 class MKLDNNGraphSplitTests: public TestsCommon,
56 public WithParamInterface<split_test_params> {
57 std::string model_t = R"V0G0N(
58 <net name="ConcatOnly" version="3" precision="FP32" batch="1">
60 <layer name="in1" type="Input" precision="FP32" id="1">
71 <layer name="split" id="2" type="Split" precision="FP32">
72 <split_data axis="_AXIS_" PrimitivesPriority="_IMPLS_"/>
88 <edge from-layer="1" from-port="1" to-layer="2" to-port="1"/>
93 std::string port_t = R"V0G0N(
// Builds the IR XML for case p by substituting the placeholders of model_t
// and port_t with values taken from p.
104 std::string getModel(split_test_params p) {
105 std::string model = model_t;
106 auto dims_size = p.dims.size();
// NOTE(review): the conditions guarding the REMOVE_LINE and the _ID_/_IH_
// substitutions are not visible here; presumably they depend on dims_size.
110 REMOVE_LINE(model, "<dim>_IH_</dim>");
112 REMOVE_LINE(model, "<dim>_ID_</dim>");
// Batch and channel come from the first two dims; width, height and depth
// are addressed from the end of p.dims.
114 REPLACE_WITH_NUM(model, "_IN_", p.dims[0]);
115 REPLACE_WITH_NUM(model, "_IC_", p.dims[1]);
116 REPLACE_WITH_NUM(model, "_IW_", p.dims[dims_size - 1]);
119 REPLACE_WITH_NUM(model, "_ID_", p.dims[dims_size - 3]);
121 REPLACE_WITH_NUM(model, "_IH_", p.dims[dims_size - 2]);
// One port is rendered from port_t for every entry of p.outs.
124 std::string outPorts;
125 for (int idx = 0; idx < p.outs.size(); idx++) {
126 std::string outPort = port_t;
129 REMOVE_LINE(outPort, "<dim>_H_</dim>");
131 REMOVE_LINE(outPort, "<dim>_D_</dim>");
// In the port template _ID_ is replaced with the output index.
133 REPLACE_WITH_NUM(outPort, "_ID_", idx);
134 REPLACE_WITH_NUM(outPort, "_N_", p.outs[idx][0]);
135 REPLACE_WITH_NUM(outPort, "_C_", p.outs[idx][1]);
136 REPLACE_WITH_NUM(outPort, "_W_", p.outs[idx][dims_size - 1]);
139 REPLACE_WITH_NUM(outPort, "_D_", p.outs[idx][dims_size - 3]);
141 REPLACE_WITH_NUM(outPort, "_H_", p.outs[idx][dims_size - 2]);
146 REPLACE_WITH_STR(model, "_OP_", outPorts);
148 REPLACE_WITH_NUM(model, "_AXIS_", p.axis);
// Every preferred implementation type is appended to impls with a cpu: prefix.
151 for (const auto& preferType : p.preferTypes) {
154 impls += "cpu:" + MKLDNNGraphTestClass::getStrPrimitiveDescriptorType(preferType);
156 REPLACE_WITH_STR(model, "_IMPLS_", impls);
160 virtual void TearDown() {
// Runs the complete scenario for the current parameter value; the TEST_P
// body of this fixture is empty, so all assertions are made here.
163 virtual void SetUp() {
165 split_test_params p = ::testing::WithParamInterface<split_test_params>::GetParam();
166 std::string model = getModel(p);
168 InferenceEngine::CNNNetReader net_reader;
169 net_reader.ReadNetwork(model.data(), model.length());
171 MKLDNNGraphTestClass graph;
172 graph.CreateGraph(net_reader.getNetwork());
173 auto& nodes = graph.getNodes();
// For every Split node: check the number of supported descriptors, apply
// the per-descriptor checks from p.comp, then check the selected descriptor.
174 for (int i = 0; i < nodes.size(); i++) {
175 if (nodes[i]->getType() == MKLDNNPlugin::Split) {
176 ASSERT_EQ(p.num_prim_desc, nodes[i]->getSupportedPrimitiveDescriptors().size());
// Only as many descriptors are checked as there are entries in p.comp.
177 for (size_t j = 0; j < p.num_prim_desc && j < p.comp.size(); j++) {
178 p.comp.at(j)(nodes[i]->getSupportedPrimitiveDescriptors().at(j));
180 ASSERT_NE(nullptr, nodes[i]->getSelectedPrimitiveDescriptor());
181 ASSERT_EQ(p.selectedType, nodes[i]->getSelectedPrimitiveDescriptor()->getImplementationType());
// The graph is expected to contain at least three nodes.
184 ASSERT_LE(3, nodes.size());
186 InferenceEngine::SizeVector dims_src = p.dims;
// The layout is chosen from the rank of p.dims; the case labels of the
// switch are not visible in this view.
187 InferenceEngine::Layout layout = InferenceEngine::ANY;
188 switch (p.dims.size()) {
190 layout = InferenceEngine::NCHW;
193 layout = InferenceEngine::NCDHW;
197 InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, layout, dims_src);
199 fill_data(src->buffer(), src->size());
// The input blob is registered under the name of the Input layer, in1.
201 InferenceEngine::BlobMap srcs;
202 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
204 auto srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
206 if (srcPtr == nullptr)
207 FAIL() << "Cannot cast blob to TBlob<float>.";
// For every network output create the blob the graph writes to and a
// reference blob with the same tensor descriptor.
209 InferenceEngine::OutputsDataMap out;
210 out = net_reader.getNetwork().getOutputsInfo();
211 InferenceEngine::BlobMap outputBlobs;
212 std::vector<InferenceEngine::TBlob<float>> dst_refs;
213 for (auto& item : out) {
214 InferenceEngine::TBlob<float>::Ptr output;
215 output = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
217 outputBlobs[item.first] = output;
219 InferenceEngine::TBlob<float> dst_ref(item.second->getTensorDesc());
221 dst_refs.push_back(dst_ref);
224 graph.Infer(srcs, outputBlobs);
226 ref_split(*srcPtr, dst_refs, p);
// Outputs are paired with the reference blobs by position.
// NOTE(review): this relies on outputBlobs iterating in the same order in
// which dst_refs was filled from out - confirm.
229 for (auto& output : outputBlobs) {
230 compare(*output.second, dst_refs[ref_idx++], 0.0005f);
// The handler body is not visible in this view.
232 } catch (const InferenceEngine::details::InferenceEngineException &e) {
// Empty test body: the assertions for each case run in
// MKLDNNGraphSplitTests::SetUp().
238 TEST_P(MKLDNNGraphSplitTests, TestsSplit) {}
// Parameter values for TestsSplit.
// In every visible row the nested brace list matches the type of outs (one
// inner list per output), and the trailing fields follow the member order
// selectedType, preferTypes, comp.
// NOTE(review): the two leading integers look like axis and num_prim_desc,
// and the input dims of each case are on lines not visible here - confirm
// against the full file.
240 INSTANTIATE_TEST_CASE_P(
241 TestsSplit, MKLDNNGraphSplitTests,
245 {{1, 16, 2, 5}, {1, 8, 2, 5}},
246 1, 3, MKLDNNPlugin::impl_desc_type::unknown, {}, {
// Descriptor 0: ref implementation, one input and two outputs, layout ANY.
247 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
248 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::ref, impl.getImplementationType());
249 ASSERT_EQ(1, impl.getConfig().inConfs.size());
250 ASSERT_EQ(2, impl.getConfig().outConfs.size());
251 ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().inConfs.at(0).desc.getLayout());
252 ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(0).desc.getLayout());
253 ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(1).desc.getLayout());
// Descriptor 1: implementation type unknown, NCHW on the input and both outputs.
255 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
256 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
257 ASSERT_EQ(1, impl.getConfig().inConfs.size());
258 ASSERT_EQ(2, impl.getConfig().outConfs.size());
259 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
260 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
261 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(1).desc.getLayout());
// Descriptor 2: implementation type unknown, BLOCKED on the input and both outputs.
263 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
264 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
265 ASSERT_EQ(1, impl.getConfig().inConfs.size());
266 ASSERT_EQ(2, impl.getConfig().outConfs.size());
267 ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().inConfs.at(0).desc.getLayout());
268 ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().outConfs.at(0).desc.getLayout());
269 ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().outConfs.at(1).desc.getLayout());
// The next three cases list two descriptor checks each: ref with layout ANY,
// then unknown with layout NCHW.
275 {{1, 13, 2, 5}, {1, 7, 2, 5}},
276 1, 2, MKLDNNPlugin::impl_desc_type::unknown, {}, {
277 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
278 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::ref, impl.getImplementationType());
279 ASSERT_EQ(1, impl.getConfig().inConfs.size());
280 ASSERT_EQ(2, impl.getConfig().outConfs.size());
281 ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().inConfs.at(0).desc.getLayout());
282 ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(0).desc.getLayout());
283 ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(1).desc.getLayout());
285 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
286 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
287 ASSERT_EQ(1, impl.getConfig().inConfs.size());
288 ASSERT_EQ(2, impl.getConfig().outConfs.size());
289 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
290 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
291 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(1).desc.getLayout());
297 {{1, 10, 2, 5}, {1, 10, 2, 5}},
298 1, 2, MKLDNNPlugin::impl_desc_type::unknown, {}, {
299 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
300 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::ref, impl.getImplementationType());
301 ASSERT_EQ(1, impl.getConfig().inConfs.size());
302 ASSERT_EQ(2, impl.getConfig().outConfs.size());
303 ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().inConfs.at(0).desc.getLayout());
304 ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(0).desc.getLayout());
305 ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(1).desc.getLayout());
307 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
308 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
309 ASSERT_EQ(1, impl.getConfig().inConfs.size());
310 ASSERT_EQ(2, impl.getConfig().outConfs.size());
311 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
312 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
313 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(1).desc.getLayout());
319 {{2, 10, 2, 5}, {2, 10, 2, 5}},
320 1, 2, MKLDNNPlugin::impl_desc_type::unknown, {}, {
321 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
322 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::ref, impl.getImplementationType());
323 ASSERT_EQ(1, impl.getConfig().inConfs.size());
324 ASSERT_EQ(2, impl.getConfig().outConfs.size());
325 ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().inConfs.at(0).desc.getLayout());
326 ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(0).desc.getLayout());
327 ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(1).desc.getLayout());
329 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
330 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
331 ASSERT_EQ(1, impl.getConfig().inConfs.size());
332 ASSERT_EQ(2, impl.getConfig().outConfs.size());
333 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
334 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
335 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(1).desc.getLayout());
// The remaining visible rows, except the last one, list ref in preferTypes and
// expect ref as the selected type; no comp checks are visible for them.
341 {{1, 16, 2, 5}, {1, 8, 2, 5}},
342 1, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
346 {{1, 13, 2, 5}, {1, 7, 2, 5}},
347 1, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
351 {{1, 10, 2, 5}, {1, 10, 2, 5}},
352 1, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
356 {{2, 10, 2, 5}, {2, 10, 2, 5}},
357 1, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
361 {{2, 15, 2, 5}, {2, 5, 2, 5}},
362 1, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
366 {{3, 11, 7, 5}, {6, 11, 7, 5}},
367 0, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
371 {{3, 11, 4, 5}, {3, 11, 3, 5}},
372 2, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
376 {{3, 11, 7, 1}, {3, 11, 7, 4}},
377 3, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
// Cases with four outputs.
381 {{1, 6, 7, 15}, {2, 6, 7, 15}, {1, 6, 7, 15}, {1, 6, 7, 15}},
382 0, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
386 {{5, 1, 7, 15}, {5, 2, 7, 15}, {5, 1, 7, 15}, {5, 2, 7, 15}},
387 1, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
391 {{5, 6, 3, 15}, {5, 6, 1, 15}, {5, 6, 2, 15}, {5, 6, 1, 15}},
392 2, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
396 {{5, 6, 7, 5}, {5, 6, 7, 3}, {5, 6, 7, 4}, {5, 6, 7, 3}},
397 3, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
402 1, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}},
// The last two visible cases use five-dimensional outputs; the final one has
// empty preferTypes and expects the selected type unknown.
405 {{1, 8, 16, 16, 16}, {1, 8, 16, 16, 16}, {1, 8, 16, 16, 16}, {1, 8, 16, 16, 16}},
406 1, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}},
409 {{1, 8, 16, 16, 16}, {1, 8, 16, 16, 16}, {1, 8, 16, 16, 16}, {1, 8, 16, 16, 16}},
410 1, 3, MKLDNNPlugin::impl_desc_type::unknown, {}}));
// Dynamic-batch variant of the Split fixture.
// It derives from MKLDNNGraphSplitTests, reuses its getModel and defines its
// own SetUp, which builds the graph with dynamic batch enabled and delegates
// the checks to MKLDNNGraphTestClass::checkDynBatch.
// NOTE(review): access specifiers, switch cases, the try opener and the
// closing braces are on lines not visible in this view.
412 class MKLDNNGraphDynBatchSplitTests: public MKLDNNGraphSplitTests {
414 virtual void SetUp() {
416 split_test_params p = ::testing::WithParamInterface<split_test_params>::GetParam();
417 std::string model = getModel(p);
// The batch size is taken from the first input dimension.
418 size_t MB = p.dims[0];
422 InferenceEngine::CNNNetReader net_reader;
423 ASSERT_NO_THROW(net_reader.ReadNetwork(model.data(), model.length()));
424 InferenceEngine::CNNNetwork network = net_reader.getNetwork();
// The network is cast to CNNNetworkImpl so that setBatchSizeReshape can be called.
425 auto implNet = dynamic_cast<InferenceEngine::details::CNNNetworkImpl *>(&((InferenceEngine::ICNNNetwork&)network));
426 ASSERT_NE(nullptr, implNet) << "Failed to cast ICNNNetwork to CNNNetworkImpl";
427 InferenceEngine::ResponseDesc resp;
428 InferenceEngine::StatusCode sts = implNet->setBatchSizeReshape(MB, &resp);
429 ASSERT_EQ((int)InferenceEngine::StatusCode::OK, sts) << resp.msg;
// KEY_DYN_BATCH_ENABLED is set on the graph before it is created.
431 MKLDNNGraphTestClass graph;
432 graph.setProperty({{InferenceEngine::PluginConfigParams::KEY_DYN_BATCH_ENABLED, InferenceEngine::PluginConfigParams::YES}});
433 graph.CreateGraph(net_reader.getNetwork());
435 InferenceEngine::SizeVector dims_src = p.dims;
// The layout is chosen from the rank of p.dims; the case labels of the
// switch are not visible in this view.
436 InferenceEngine::Layout layout = InferenceEngine::ANY;
437 switch (p.dims.size()) {
439 layout = InferenceEngine::NCHW;
442 layout = InferenceEngine::NCDHW;
446 InferenceEngine::Blob::Ptr src = InferenceEngine::make_shared_blob<float, const InferenceEngine::SizeVector>(InferenceEngine::Precision::FP32, layout, dims_src);
448 fill_data(src->buffer(), src->size());
// The input blob is registered under the name of the Input layer, in1.
450 InferenceEngine::BlobMap srcs;
451 srcs.insert(std::pair<std::string, InferenceEngine::Blob::Ptr>("in1", src));
453 InferenceEngine::TBlob<float>* srcPtr = dynamic_cast<InferenceEngine::TBlob<float>*>(src.get());
455 if (srcPtr == nullptr)
456 FAIL() << "Cannot cast blob to TBlob<float>.";
// Output blobs are created one by one through the iterator instead of a loop.
458 InferenceEngine::OutputsDataMap out;
459 out = net_reader.getNetwork().getOutputsInfo();
460 InferenceEngine::BlobMap outputBlobs;
461 auto it = out.begin();
463 std::pair<std::string, InferenceEngine::DataPtr> item = *it;
465 InferenceEngine::TBlob<float>::Ptr output1;
466 output1 = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
468 outputBlobs[item.first] = output1;
// NOTE(review): in the visible lines item is not advanced between output1
// and output2. If no hidden statement moves it to the next output, output2
// replaces output1 under the same key - confirm against the full file.
471 InferenceEngine::TBlob<float>::Ptr output2;
472 output2 = InferenceEngine::make_shared_blob<float>(item.second->getTensorDesc());
474 outputBlobs[item.first] = output2;
// Predicate handed to checkDynBatch; it is true for Split nodes only.
476 auto checkSplit = [](const MKLDNNPlugin::MKLDNNNodePtr& node) {
477 return node->getType() == MKLDNNPlugin::Split;
// checkDynBatch runs twice: with (MB, MB) and with (1, MB).
// NOTE(review): presumably the arguments are the batch to run and the
// maximum batch - confirm against MKLDNNGraphTestClass::checkDynBatch.
480 graph.checkDynBatch(srcs, outputBlobs, MB, MB, checkSplit);
481 graph.checkDynBatch(srcs, outputBlobs, 1, MB, checkSplit);
// The handler body is not visible in this view.
482 } catch (const InferenceEngine::details::InferenceEngineException &e) {
// Empty test body: the checks for each case run in
// MKLDNNGraphDynBatchSplitTests::SetUp().
488 TEST_P(MKLDNNGraphDynBatchSplitTests, TestsDynBatchSplit) {}
// Parameter values for TestsDynBatchSplit.
// In every visible row the nested brace list matches the type of outs (one
// inner list per output), and the trailing fields follow the member order
// selectedType, preferTypes, comp.
// NOTE(review): the two leading integers look like axis and num_prim_desc,
// and the input dims of each case are on lines not visible here - confirm
// against the full file.
// NOTE(review): the visible SetUp of this fixture never reads num_prim_desc,
// selectedType or comp, so the descriptor lambdas below appear unused by this
// test - confirm against the full file.
490 INSTANTIATE_TEST_CASE_P(
491 TestsDynBatchSplit, MKLDNNGraphDynBatchSplitTests,
495 {{1, 16, 2, 5}, {1, 8, 2, 5}},
496 1, 3, MKLDNNPlugin::impl_desc_type::unknown, {}, {
// Descriptor 0: ref implementation, one input and two outputs, layout ANY.
497 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
498 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::ref, impl.getImplementationType());
499 ASSERT_EQ(1, impl.getConfig().inConfs.size());
500 ASSERT_EQ(2, impl.getConfig().outConfs.size());
501 ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().inConfs.at(0).desc.getLayout());
502 ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(0).desc.getLayout());
503 ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(1).desc.getLayout());
// Descriptor 1: implementation type unknown, NCHW on the input and both outputs.
505 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
506 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
507 ASSERT_EQ(1, impl.getConfig().inConfs.size());
508 ASSERT_EQ(2, impl.getConfig().outConfs.size());
509 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
510 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
511 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(1).desc.getLayout());
// Descriptor 2: implementation type unknown, BLOCKED on the input and both outputs.
513 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
514 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
515 ASSERT_EQ(1, impl.getConfig().inConfs.size());
516 ASSERT_EQ(2, impl.getConfig().outConfs.size());
517 ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().inConfs.at(0).desc.getLayout());
518 ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().outConfs.at(0).desc.getLayout());
519 ASSERT_EQ(InferenceEngine::Layout::BLOCKED, impl.getConfig().outConfs.at(1).desc.getLayout());
// The next three cases list two descriptor checks each: ref with layout ANY,
// then unknown with layout NCHW.
525 {{1, 13, 2, 5}, {1, 7, 2, 5}},
526 1, 2, MKLDNNPlugin::impl_desc_type::unknown, {}, {
527 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
528 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::ref, impl.getImplementationType());
529 ASSERT_EQ(1, impl.getConfig().inConfs.size());
530 ASSERT_EQ(2, impl.getConfig().outConfs.size());
531 ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().inConfs.at(0).desc.getLayout());
532 ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(0).desc.getLayout());
533 ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(1).desc.getLayout());
535 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
536 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
537 ASSERT_EQ(1, impl.getConfig().inConfs.size());
538 ASSERT_EQ(2, impl.getConfig().outConfs.size());
539 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
540 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
541 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(1).desc.getLayout());
547 {{1, 10, 2, 5}, {1, 10, 2, 5}},
548 1, 2, MKLDNNPlugin::impl_desc_type::unknown, {}, {
549 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
550 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::ref, impl.getImplementationType());
551 ASSERT_EQ(1, impl.getConfig().inConfs.size());
552 ASSERT_EQ(2, impl.getConfig().outConfs.size());
553 ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().inConfs.at(0).desc.getLayout());
554 ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(0).desc.getLayout());
555 ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(1).desc.getLayout());
557 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
558 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
559 ASSERT_EQ(1, impl.getConfig().inConfs.size());
560 ASSERT_EQ(2, impl.getConfig().outConfs.size());
561 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
562 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
563 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(1).desc.getLayout());
569 {{2, 10, 2, 5}, {2, 10, 2, 5}},
570 1, 2, MKLDNNPlugin::impl_desc_type::unknown, {}, {
571 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
572 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::ref, impl.getImplementationType());
573 ASSERT_EQ(1, impl.getConfig().inConfs.size());
574 ASSERT_EQ(2, impl.getConfig().outConfs.size());
575 ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().inConfs.at(0).desc.getLayout());
576 ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(0).desc.getLayout());
577 ASSERT_EQ(InferenceEngine::Layout::ANY, impl.getConfig().outConfs.at(1).desc.getLayout());
579 [](MKLDNNPlugin::PrimitiveDescInfo impl) {
580 ASSERT_EQ(MKLDNNPlugin::impl_desc_type::unknown, impl.getImplementationType());
581 ASSERT_EQ(1, impl.getConfig().inConfs.size());
582 ASSERT_EQ(2, impl.getConfig().outConfs.size());
583 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().inConfs.at(0).desc.getLayout());
584 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(0).desc.getLayout());
585 ASSERT_EQ(InferenceEngine::Layout::NCHW, impl.getConfig().outConfs.at(1).desc.getLayout());
// The remaining visible rows list ref in preferTypes and give ref as the
// selected type; no comp checks are visible for them.
591 {{2, 16, 2, 5}, {2, 8, 2, 5}},
592 1, 3, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
596 {{1, 13, 2, 5}, {1, 7, 2, 5}},
597 1, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
601 {{1, 10, 2, 5}, {1, 10, 2, 5}},
602 1, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
606 {{2, 10, 2, 5}, {2, 10, 2, 5}},
607 1, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
611 {{2, 15, 2, 5}, {2, 5, 2, 5}},
612 1, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
616 {{3, 11, 4, 5}, {3, 11, 3, 5}},
617 2, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
621 {{3, 11, 7, 1}, {3, 11, 7, 4}},
622 3, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
// Cases with four outputs.
626 {{5, 1, 7, 15}, {5, 2, 7, 15}, {5, 1, 7, 15}, {5, 2, 7, 15}},
627 1, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
631 {{5, 6, 3, 15}, {5, 6, 1, 15}, {5, 6, 2, 15}, {5, 6, 1, 15}},
632 2, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}
636 {{5, 6, 7, 5}, {5, 6, 7, 3}, {5, 6, 7, 4}, {5, 6, 7, 3}},
637 3, 2, MKLDNNPlugin::impl_desc_type::ref, {MKLDNNPlugin::impl_desc_type::ref}}));