1 // Copyright (C) 2018-2019 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
6 #include <inference_engine/graph_tools.hpp>
7 #include "graph_test_base.hpp"
8 #include <unordered_set>
9 #include "mock_icnn_network.hpp"
10 #include <gmock/gmock-generated-function-mockers.h>
11 #include <gmock/gmock-generated-matchers.h>
12 #include <gmock/gmock-more-actions.h>
13 #include "xml_father.hpp"
14 #include "ie_common.h"
16 #include "details/ie_cnn_network_tools.h"
18 using namespace testing;
19 using namespace InferenceEngine;
21 using namespace GraphTest;
23 class GraphToolsTest : public GraphTestsBase {
27 TEST_F(GraphToolsTest, canRunSimpleDFS) {
34 EXPECT_CALL(*this, visited(0 ,0)).Times(1);
35 EXPECT_CALL(*this, visited(1, IsBetween(1,3))).Times(1);
36 EXPECT_CALL(*this, visited(2, IsBetween(1,3))).Times(1);
37 EXPECT_CALL(*this, visited(3, 2)).Times(1);
40 CNNNetDFS(layers[0], [&] (const CNNLayerPtr & layer) {
41 visited(ID(layer), idx++);
46 TEST_F(GraphToolsTest, canRunCycleDFS) {
52 EXPECT_CALL(*this, visited(0 ,0)).Times(1);
53 EXPECT_CALL(*this, visited(1, 1)).Times(1);
54 EXPECT_CALL(*this, visited(2, 2)).Times(1);
57 CNNNetDFS(layers[0], [&] (const CNNLayerPtr & layer) {
58 visited(ID(layer), idx++);
63 TEST_F(GraphToolsTest, canRunBFS) {
70 EXPECT_CALL(*this, visited(0 ,0)).Times(1);
71 EXPECT_CALL(*this, visited(1, IsBetween(1,3))).Times(1);
72 EXPECT_CALL(*this, visited(2, IsBetween(1,3))).Times(1);
73 EXPECT_CALL(*this, visited(3, IsBetween(1,3))).Times(1);
74 EXPECT_CALL(*this, visited(4, 4)).Times(1);
77 CNNNetBFS(layers[0], [&] (const InferenceEngine::CNNLayerPtr & layer) {
78 visited(ID(layer), idx++);
83 TEST_F(GraphToolsTest, canRunNBFS) {
90 EXPECT_CALL(*this, visited(0 ,0)).Times(1);
91 EXPECT_CALL(*this, visited(1, IsBetween(1,3))).Times(1);
92 EXPECT_CALL(*this, visited(2, IsBetween(1,3))).Times(1);
93 EXPECT_CALL(*this, visited(3, IsBetween(1,3))).Times(1);
96 CNNNetNBFS(layers[0], 1, [&] (const InferenceEngine::CNNLayerPtr & layer) {
97 visited(ID(layer), idx++);
101 TEST_F(GraphToolsTest, canSortTopologically) {
107 EXPECT_CALL(*mockNet, getInputsInfo(_)).WillOnce(WithArg<0>(Invoke([&](InputsDataMap & maps){
110 auto sorted = CNNNetSortTopologically(*mockNet);
112 EXPECT_EQ(sorted.size(), 4);
114 // first element can be 0 or 2 depending on implementation
116 sorted[0]->name=="0" && sorted[1]->name=="2" ||
117 sorted[0]->name=="2" && sorted[1]->name=="0");
119 EXPECT_STREQ(sorted[2]->name.c_str(), "1");
120 EXPECT_STREQ(sorted[3]->name.c_str(), "4");
123 TEST_F(GraphToolsTest, canDetectLoopsWhileSortTing) {
125 // 1->2->3-> 4->5->6->7-> 8
129 // └----------------┘
142 EXPECT_CALL(*mockNet, getInputsInfo(_)).WillOnce(WithArg<0>(Invoke([&](InputsDataMap & maps){
145 ASSERT_ANY_THROW(CNNNetSortTopologically(*mockNet));
149 TEST_F(GraphToolsTest, canSortIfInputsPointsToLayerWithMultiInputs) {
157 EXPECT_CALL(*mockNet, getInputsInfo(_)).WillOnce(WithArg<0>(Invoke([&](InputsDataMap & maps){
161 auto sorted = CNNNetSortTopologically(*mockNet);
163 vector<vector<string>> expected = {
164 {"1", "3", "4", "5", "2"},
165 {"3", "4", "5", "1", "2"},
166 {"3", "5", "4", "1", "2"},
167 {"1", "3", "5", "4", "2"},
171 for (auto ex: expected) {
173 for (auto i = 0; i < ex.size(); i++) {
174 if (sorted[i]->name != ex[i]) {
181 std::stringstream actual;
182 for (auto x : sorted) {
183 actual << x->name << " ";
186 EXPECT_FALSE(bFailed) << actual.str() << "doesn't match: one of expected" ;
189 TEST_F(GraphToolsTest, canGetAllMemoryInputsLayersFromStandardInputs) {
206 EXPECT_CALL(*mockNet, getInputsInfo(_)).WillOnce(WithArg<0>(Invoke([&](InputsDataMap & maps){
207 prepareSomeInputs(maps, {1});
209 auto allInputLayers = CNNNetGetAllInputLayers(*mockNet);
210 ASSERT_EQ(3, allInputLayers.size());
211 auto element = allInputLayers.begin();
212 ASSERT_STREQ("1", element->get()->name.c_str());
214 ASSERT_STREQ("2", element->get()->name.c_str());
216 ASSERT_STREQ("3", element->get()->name.c_str());
219 TEST_F(GraphToolsTest, canGetSingleInputLayer) {
223 EXPECT_CALL(*mockNet, getInputsInfo(_)).WillOnce(WithArg<0>(Invoke([&](InputsDataMap & maps){
224 prepareSomeInputs(maps, {1});
226 auto allInputLayers = CNNNetGetAllInputLayers(*mockNet);
227 ASSERT_EQ(1, allInputLayers.size());
230 TEST_F(GraphToolsTest, canIterateOverCNNNetwork) {
242 EXPECT_CALL(*mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap & maps){
246 std::vector<CNNLayerPtr>resultedOrder;
247 for (auto l : wrap) {
248 resultedOrder.push_back(l);
251 ASSERT_EQ(wrap.size(), 8);
252 ASSERT_STREQ(resultedOrder[0]->name.c_str(), "2");
253 ASSERT_STREQ(resultedOrder[1]->name.c_str(), "6");
254 ASSERT_STREQ(resultedOrder[2]->name.c_str(), "1");
255 ASSERT_STREQ(resultedOrder[3]->name.c_str(), "7");
256 ASSERT_STREQ(resultedOrder[4]->name.c_str(), "3");
257 ASSERT_STREQ(resultedOrder[5]->name.c_str(), "8");
258 ASSERT_STREQ(resultedOrder[6]->name.c_str(), "4");
259 ASSERT_STREQ(resultedOrder[7]->name.c_str(), "5");
262 TEST_F(GraphToolsTest, canIterateOverCNNNetworkWithCycle) {
268 EXPECT_CALL(*mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap & maps){
272 std::vector<CNNLayerPtr>resultedOrder;
273 for (auto l : wrap) {
274 resultedOrder.push_back(l);
277 ASSERT_EQ(wrap.size(), 4);
278 ASSERT_STREQ(resultedOrder[0]->name.c_str(), "2");
279 ASSERT_STREQ(resultedOrder[1]->name.c_str(), "3");
280 ASSERT_STREQ(resultedOrder[2]->name.c_str(), "1");
281 ASSERT_STREQ(resultedOrder[3]->name.c_str(), "4");
284 TEST_F(GraphToolsTest, canCompareCNNNetworkIterators) {
288 EXPECT_CALL(*mockNet, getInputsInfo(_)).WillOnce(WithArg<0>(Invoke([&](InputsDataMap & maps){
292 auto i = std::begin(wrap);
301 TEST_F(GraphToolsTest, canIterateOverEmptyNetwork) {
305 EXPECT_CALL(*mockNet, getInputsInfo(_)).WillOnce(WithArg<0>(Invoke([&](InputsDataMap & maps){
309 ASSERT_EQ(std::begin(wrap), std::end(wrap));
312 TEST_F(GraphToolsTest, CNNNetSwapLayersThrowsForNullPointers) {
313 CNNLayerPtr nullLayer;
314 ASSERT_ANY_THROW(CNNNetSwapLayers(nullLayer, nullLayer));
317 TEST_F(GraphToolsTest, CNNNetSwapLayersSwapWithItself) {
321 EXPECT_CALL(*mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap & maps){
325 EXPECT_CALL(*mockNet, getLayerByName(_,_,_)).WillRepeatedly(WithArgs<0,1>(Invoke([&](const char* name, InferenceEngine::CNNLayerPtr& l){
326 l = layerByName(name);
327 return l== nullptr ? GENERAL_ERROR : OK;
330 auto l = wrap.getLayerByName("2");
332 ASSERT_NO_THROW(CNNNetSwapLayers(l, l));
334 ASSERT_CONNECTION(1, 2);
335 ASSERT_CONNECTION(2, 3);
338 TEST_F(GraphToolsTest, CNNNetSwapLayersSimpleCase_1) {
341 EXPECT_CALL(*mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap & maps){
345 EXPECT_CALL(*mockNet, getLayerByName(_, _, _)).WillRepeatedly(WithArgs<0,1>(Invoke([&](const char* name, InferenceEngine::CNNLayerPtr& l){
346 l = layerByName(name);
347 return l== nullptr ? GENERAL_ERROR : OK;
350 auto l = wrap.getLayerByName("1");
351 auto r = wrap.getLayerByName("2");
353 ASSERT_NO_THROW(CNNNetSwapLayers(l, r));
355 ASSERT_CONNECTION(2, 1);
358 TEST_F(GraphToolsTest, CNNNetSwapLayersSimpleCase_2) {
362 EXPECT_CALL(*mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap & maps){
366 EXPECT_CALL(*mockNet, getLayerByName(_, _, _)).WillRepeatedly(WithArgs<0,1>(Invoke([&](const char* name, InferenceEngine::CNNLayerPtr& l){
367 l = layerByName(name);
368 return l== nullptr ? GENERAL_ERROR : OK;
371 auto l = wrap.getLayerByName("2");
372 auto r = wrap.getLayerByName("3");
374 ASSERT_NO_THROW(CNNNetSwapLayers(l, r));
376 ASSERT_CONNECTION(1, 3);
377 ASSERT_CONNECTION(3, 2);
380 TEST_F(GraphToolsTest, CNNNetSwapLayersSimpleCase_3) {
384 EXPECT_CALL(*mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap & maps){
388 EXPECT_CALL(*mockNet, getLayerByName(_, _, _)).WillRepeatedly(WithArgs<0,1>(Invoke([&](const char* name, InferenceEngine::CNNLayerPtr& l){
389 l = layerByName(name);
390 return l== nullptr ? GENERAL_ERROR : OK;
393 auto l = wrap.getLayerByName("1");
394 auto r = wrap.getLayerByName("2");
396 ASSERT_NO_THROW(CNNNetSwapLayers(l, r));
398 ASSERT_CONNECTION(2, 1);
399 ASSERT_CONNECTION(1, 3);
402 TEST_F(GraphToolsTest, CNNNetSwapLayersDoesSwapDims) {
406 SET_DIMS(1, {10, 1});
407 SET_DIMS(2, {20, 1});
408 SET_DIMS(3, {30, 1});
410 EXPECT_CALL(*mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap & maps){
414 EXPECT_CALL(*mockNet, getLayerByName(_, _, _)).WillRepeatedly(WithArgs<0,1>(Invoke([&](const char* name, InferenceEngine::CNNLayerPtr& l){
415 l = layerByName(name);
416 return l== nullptr ? GENERAL_ERROR : OK;
419 auto l = wrap.getLayerByName("1");
420 auto r = wrap.getLayerByName("2");
422 ASSERT_NO_THROW(CNNNetSwapLayers(l, r));
424 ASSERT_CONNECTION(2, 1);
425 ASSERT_CONNECTION(1, 3);
427 ASSERT_DIMS(1, {20, 1});
428 ASSERT_DIMS(2, {20, 1});
431 TEST_F(GraphToolsTest, CNNNetSwapLayersSimpleCase_4) {
437 EXPECT_CALL(*mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap & maps){
441 EXPECT_CALL(*mockNet, getLayerByName(_, _, _)).WillRepeatedly(WithArgs<0,1>(Invoke([&](const char* name, InferenceEngine::CNNLayerPtr& l){
442 l = layerByName(name);
443 return l== nullptr ? GENERAL_ERROR : OK;
446 auto l = wrap.getLayerByName("2");
447 auto r = wrap.getLayerByName("4");
449 ASSERT_NO_THROW(CNNNetSwapLayers(l, r));
451 ASSERT_CONNECTION(1, 4);
452 ASSERT_CONNECTION(4, 3);
453 ASSERT_CONNECTION(3, 2);
454 ASSERT_CONNECTION(2, 5);
457 TEST_F(GraphToolsTest, CNNNetSwapLayersSplit) {
461 EXPECT_CALL(*mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap & maps){
465 EXPECT_CALL(*mockNet, getLayerByName(_, _, _)).WillRepeatedly(WithArgs<0,1>(Invoke([&](const char* name, InferenceEngine::CNNLayerPtr& l){
466 l = layerByName(name);
467 return l== nullptr ? GENERAL_ERROR : OK;
470 auto l = wrap.getLayerByName("2");
471 auto r = wrap.getLayerByName("3");
473 ASSERT_NO_THROW(CNNNetSwapLayers(l, r));
475 ASSERT_CONNECTION(1, 2);
476 ASSERT_CONNECTION(1, 3);
478 TEST_F(GraphToolsTest, CNNNetSwapLayersSplit_2) {
482 EXPECT_CALL(*mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap & maps){
486 EXPECT_CALL(*mockNet, getLayerByName(_, _, _)).WillRepeatedly(WithArgs<0,1>(Invoke([&](const char* name, InferenceEngine::CNNLayerPtr& l){
487 l = layerByName(name);
488 return l== nullptr ? GENERAL_ERROR : OK;
491 auto l = wrap.getLayerByName("1");
492 auto r = wrap.getLayerByName("2");
494 ASSERT_NO_THROW(CNNNetSwapLayers(l, r));
496 ASSERT_CONNECTION(2, 1);
497 ASSERT_CONNECTION(2, 3);
500 TEST_F(GraphToolsTest, CNNNetSwapLayersSplit_3) {
507 EXPECT_CALL(*mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap & maps){
511 EXPECT_CALL(*mockNet, getLayerByName(_, _, _)).WillRepeatedly(WithArgs<0,1>(Invoke([&](const char* name, InferenceEngine::CNNLayerPtr& l){
512 l = layerByName(name);
513 return l== nullptr ? GENERAL_ERROR : OK;
516 auto l = wrap.getLayerByName("1");
517 auto r = wrap.getLayerByName("2");
519 ASSERT_NO_THROW(CNNNetSwapLayers(l, r));
521 ASSERT_CONNECTION(2, 1);
522 ASSERT_CONNECTION(2, 6);
523 ASSERT_CONNECTION(1, 3);
524 ASSERT_CONNECTION(1, 4);
525 ASSERT_CONNECTION(1, 5);
528 TEST_F(GraphToolsTest, CNNNetSwapLayersSplit_4) {
535 EXPECT_CALL(*mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap & maps){
539 EXPECT_CALL(*mockNet, getLayerByName(_, _, _)).WillRepeatedly(WithArgs<0,1>(Invoke([&](const char* name, InferenceEngine::CNNLayerPtr& l){
540 l = layerByName(name);
541 return l== nullptr ? GENERAL_ERROR : OK;
544 auto l = wrap.getLayerByName("1");
545 auto r = wrap.getLayerByName("2");
547 ASSERT_NO_THROW(CNNNetSwapLayers(l, r));
549 ASSERT_CONNECTION(4, 2);
550 ASSERT_CONNECTION(4, 1);
551 ASSERT_CONNECTION(2, 1);
552 ASSERT_CONNECTION(2, 3);
553 ASSERT_CONNECTION(1, 3);
556 TEST_F(GraphToolsTest, CNNNetworkInsertLayerThrowsForNullPointers) {
557 CNNLayerPtr nullLayer;
558 ASSERT_ANY_THROW(CNNNetworkInsertLayer(nullLayer, nullLayer, nullLayer));
561 TEST_F(GraphToolsTest, CanNotInsertLayerIntoNonAdjiacendLayers) {
565 EXPECT_CALL(*mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap & maps){
569 EXPECT_CALL(*mockNet, getLayerByName(_,_,_)).WillRepeatedly(WithArgs<0,1>(Invoke([&](const char* name, InferenceEngine::CNNLayerPtr& l){
570 l = layerByName(name);
571 return l== nullptr ? GENERAL_ERROR : OK;
574 auto l = wrap.getLayerByName("1");
575 auto r = wrap.getLayerByName("3");
577 ASSERT_ANY_THROW(CNNNetworkInsertLayer(l, r, createGenericLayer("3")));
580 TEST_F(GraphToolsTest, CNNNetworkInsertLayerSimpleCase) {
583 EXPECT_CALL(*mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap & maps){
587 EXPECT_CALL(*mockNet, getLayerByName(_,_,_)).WillRepeatedly(WithArgs<0, 1>(Invoke([&](const char* name, InferenceEngine::CNNLayerPtr& l){
588 l = layerByName(name);
589 return l== nullptr ? GENERAL_ERROR : OK;
592 auto l = wrap.getLayerByName("1");
593 auto r = wrap.getLayerByName("2");
595 CNNNetworkInsertLayer(l, r, createGenericLayer("3"));
597 ASSERT_CONNECTION(3, 2);
598 ASSERT_CONNECTION(1, 3);
601 TEST_F(GraphToolsTest, CNNNetworkInsertLayerSimpleCaseWithMultipleOutputs) {
605 EXPECT_CALL(*mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap & maps){
609 EXPECT_CALL(*mockNet, getLayerByName(_,_,_)).WillRepeatedly(WithArgs<0,1>(Invoke([&](const char* name, InferenceEngine::CNNLayerPtr& l){
610 l = layerByName(name);
611 return l== nullptr ? GENERAL_ERROR : OK;
614 auto l = wrap.getLayerByName("1");
615 auto r = wrap.getLayerByName("3");
617 CNNNetworkInsertLayer(l, r, createGenericLayer("4"));
619 ASSERT_CONNECTION(4, 3);
620 ASSERT_CONNECTION(1, 4);
621 ASSERT_CONNECTION(1, 2);
625 TEST_F(GraphToolsTest, CNNNetworkInsertLayerSimpleCaseWithMultipleInputs) {
629 EXPECT_CALL(*mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap & maps){
633 EXPECT_CALL(*mockNet, getLayerByName(_,_,_)).WillRepeatedly(WithArgs<0,1>(Invoke([&](const char* name, InferenceEngine::CNNLayerPtr& l){
634 l = layerByName(name);
635 return l== nullptr ? GENERAL_ERROR : OK;
638 auto l = wrap.getLayerByName("3");
639 auto r = wrap.getLayerByName("2");
641 CNNNetworkInsertLayer(l, r, createGenericLayer("4"));
643 ASSERT_CONNECTION(4, 2);
644 ASSERT_CONNECTION(3, 4);
645 ASSERT_CONNECTION(1, 2);
648 TEST_F(GraphToolsTest, CNNNetworkInsertLayerSplitAndConcat) {
649 CONNECT_FROM_PORT(1, 0, 2);
650 CONNECT_FROM_PORT(1, 1, 2);
651 CONNECT_FROM_PORT(1, 2, 3);
653 EXPECT_CALL(*mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap & maps){
657 EXPECT_CALL(*mockNet, getLayerByName(_,_,_)).WillRepeatedly(WithArgs<0,1>(Invoke([&](const char* name, InferenceEngine::CNNLayerPtr& l){
658 l = layerByName(name);
659 return l== nullptr ? GENERAL_ERROR : OK;
662 auto l = wrap.getLayerByName("1");
663 auto r = wrap.getLayerByName("2");
664 auto r2 = wrap.getLayerByName("3");
666 CNNNetworkInsertLayer(l, r, createGenericLayer("4"), 1);
667 CNNNetworkInsertLayer(l, r2, createGenericLayer("5"), 2);
669 ASSERT_PORT_CONNECTION(1, 0, 2, 0);
670 ASSERT_PORT_CONNECTION(1, 1, 4, 0);
671 ASSERT_PORT_CONNECTION(4, 0, 2, 1);
672 ASSERT_PORT_CONNECTION(1, 2, 5, 0);
673 ASSERT_CONNECTION(5, 3);
677 TEST_F(GraphToolsTest, CNNNetworkInsertAfterLastLayer) {
680 EXPECT_CALL(*mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap & maps){
684 EXPECT_CALL(*mockNet, getLayerByName(_,_,_)).WillRepeatedly(WithArgs<0, 1>(Invoke([&](const char* name, InferenceEngine::CNNLayerPtr& l){
685 l = layerByName(name);
686 return l== nullptr ? GENERAL_ERROR : OK;
689 auto l = wrap.getLayerByName("2");
691 CNNNetworkInsertLayer(l, nullptr, createGenericLayer("3"));
693 ASSERT_CONNECTION(1, 2);
694 ASSERT_CONNECTION(2, 3);
697 TEST_F(GraphToolsTest, CNNNetworkInsertAfterAll) {
701 EXPECT_CALL(*mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap & maps){
705 EXPECT_CALL(*mockNet, getLayerByName(_,_,_)).WillRepeatedly(WithArgs<0, 1>(Invoke([&](const char* name, InferenceEngine::CNNLayerPtr& l){
706 l = layerByName(name);
707 return l== nullptr ? GENERAL_ERROR : OK;
710 CNNNetworkInsertLayer(wrap.getLayerByName("1"), nullptr, createGenericLayer("5"));
712 ASSERT_CONNECTION(1, 5);
713 ASSERT_CONNECTION(5, 2);
714 ASSERT_CONNECTION(5, 3);
717 TEST_F(GraphToolsTest, CNNNetworkInsertAllAfterSplit) {
719 CONNECT_FROM_PORT(1, 0, 2);
720 CONNECT_FROM_PORT(1, 1, 3);
722 EXPECT_CALL(*mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap & maps){
726 EXPECT_CALL(*mockNet, getLayerByName(_,_,_)).WillRepeatedly(WithArgs<0, 1>(Invoke([&](const char* name, InferenceEngine::CNNLayerPtr& l){
727 l = layerByName(name);
728 return l== nullptr ? GENERAL_ERROR : OK;
731 CNNNetworkInsertLayer(wrap.getLayerByName("1"), nullptr, createGenericLayer("5"));
733 ASSERT_CONNECTION(1, 5);
734 ASSERT_CONNECTION(5, 2);
735 ASSERT_CONNECTION(1, 3);
738 TEST_F(GraphToolsTest, CNNNetworkInsert1AfterSplit) {
740 CONNECT_FROM_PORT(1, 0, 2);
741 CONNECT_FROM_PORT(1, 1, 3);
742 CONNECT_FROM_PORT(1, 2, 4);
744 EXPECT_CALL(*mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap & maps){
748 EXPECT_CALL(*mockNet, getLayerByName(_,_,_)).WillRepeatedly(WithArgs<0, 1>(Invoke([&](const char* name, InferenceEngine::CNNLayerPtr& l){
749 l = layerByName(name);
750 return l== nullptr ? GENERAL_ERROR : OK;
753 CNNNetworkInsertLayer(wrap.getLayerByName("1"), wrap.getLayerByName("4"), createGenericLayer("5"));
755 ASSERT_CONNECTION(1, 2);
756 ASSERT_CONNECTION(1, 3);
757 ASSERT_CONNECTION(1, 5);
758 ASSERT_CONNECTION(5, 4);
762 TEST_F(GraphToolsTest, CNNNetworkInsertAfter2ConnectionsToEltwise) {
763 // multiple 1->2 connections like square operation using eltwise mull with itself
767 EXPECT_CALL(*mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap & maps){
771 EXPECT_CALL(*mockNet, getLayerByName(_,_,_)).WillRepeatedly(WithArgs<0, 1>(Invoke([&](const char* name, InferenceEngine::CNNLayerPtr& l){
772 l = layerByName(name);
773 return l== nullptr ? GENERAL_ERROR : OK;
776 CNNNetworkInsertLayer(wrap.getLayerByName("1"), wrap.getLayerByName("2"), createGenericLayer("5"));
778 ASSERT_CONNECTION(1, 5);
779 ASSERT_MN_CONNECTIONS(5, 2, 1, 2);
783 TEST_F(GraphToolsTest, CNNNetworkRemoveNullPointerLayer) {
785 CONNECT_FROM_PORT(1, 0, 2);
786 CONNECT_FROM_PORT(1, 1, 3);
787 CONNECT_FROM_PORT(1, 2, 4);
789 EXPECT_CALL(*mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap & maps){
793 EXPECT_CALL(*mockNet, getLayerByName(_,_,_)).WillRepeatedly(WithArgs<0, 1>(Invoke([&](const char* name, InferenceEngine::CNNLayerPtr& l){
794 l = layerByName(name);
795 return l== nullptr ? GENERAL_ERROR : OK;
798 ASSERT_ANY_THROW(CNNNetworkRemoveLayer(nullptr));
801 TEST_F(GraphToolsTest, CNNNetworkRemoveInputOrOutputLayer) {
803 CONNECT_FROM_PORT(1, 0, 2);
804 CONNECT_FROM_PORT(2, 0, 3);
805 CONNECT_FROM_PORT(1, 0, 3);
807 EXPECT_CALL(*mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap & maps){
811 EXPECT_CALL(*mockNet, getLayerByName(_,_,_)).WillRepeatedly(WithArgs<0, 1>(Invoke([&](const char* name, InferenceEngine::CNNLayerPtr& l){
812 l = layerByName(name);
813 return l== nullptr ? GENERAL_ERROR : OK;
816 ASSERT_ANY_THROW(CNNNetworkRemoveLayer(wrap.getLayerByName("1")));
817 ASSERT_ANY_THROW(CNNNetworkRemoveLayer(wrap.getLayerByName("3")));
820 TEST_F(GraphToolsTest, CNNNetworkRemoveLayerThaHas2Outputs) {
822 CONNECT_FROM_PORT(1, 0, 2);
823 CONNECT_FROM_PORT(2, 0, 3);
824 CONNECT_FROM_PORT(2, 0, 4);
825 CONNECT_FROM_PORT(1, 0, 3);
826 CONNECT_FROM_PORT(5, 0, 4);
828 EXPECT_CALL(*mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap & maps){
832 EXPECT_CALL(*mockNet, getLayerByName(_,_,_)).WillRepeatedly(WithArgs<0, 1>(Invoke([&](const char* name, InferenceEngine::CNNLayerPtr& l){
833 l = layerByName(name);
834 return l== nullptr ? GENERAL_ERROR : OK;
837 CNNNetworkRemoveLayer(wrap.getLayerByName("2"));
839 ASSERT_2_CONNECTIONS(1, 3);
840 ASSERT_CONNECTION(1, 4);
841 ASSERT_CONNECTION(5, 4);
843 // means all remained references removed
844 ASSERT_NO_CONNECTION(1, 2);
845 ASSERT_NO_CONNECTION(2, 2);
846 ASSERT_NO_CONNECTION(3, 2);
847 ASSERT_NO_CONNECTION(4, 2);
850 TEST_F(GraphToolsTest, CNNNetworkRemoveLayerSplit) {
852 CONNECT_FROM_PORT(1, 0, 2);
853 CONNECT_FROM_PORT(1, 1, 3);
854 CONNECT_FROM_PORT(2, 0, 3);
856 EXPECT_CALL(*mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap & maps){
860 EXPECT_CALL(*mockNet, getLayerByName(_,_,_)).WillRepeatedly(WithArgs<0, 1>(Invoke([&](const char* name, InferenceEngine::CNNLayerPtr& l){
861 l = layerByName(name);
862 return l== nullptr ? GENERAL_ERROR : OK;
865 CNNNetworkRemoveLayer(wrap.getLayerByName("2"));
867 ASSERT_2_CONNECTIONS(1, 3);
868 // means all remained references removed
869 ASSERT_NO_CONNECTION(1, 2);
870 ASSERT_NO_CONNECTION(2, 2);
871 ASSERT_NO_CONNECTION(3, 2);
874 TEST_F(GraphToolsTest, CNNNetworkRemoveLayerSplit2) {
876 CONNECT_FROM_PORT(1, 0, 2);
877 CONNECT_FROM_PORT(1, 0, 3);
878 CONNECT_FROM_PORT(1, 0, 4);
879 CONNECT_FROM_PORT(1, 1, 4);
880 CONNECT_FROM_PORT(1, 2, 5);
882 CONNECT_FROM_PORT(2, 0, 3);
883 CONNECT_FROM_PORT(2, 0, 4);
884 CONNECT_FROM_PORT(2, 0, 5);
886 EXPECT_CALL(*mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap & maps){
890 EXPECT_CALL(*mockNet, getLayerByName(_,_,_)).WillRepeatedly(WithArgs<0, 1>(Invoke([&](const char* name, InferenceEngine::CNNLayerPtr& l){
891 l = layerByName(name);
892 return l== nullptr ? GENERAL_ERROR : OK;
895 CNNNetworkRemoveLayer(wrap.getLayerByName("2"));
897 ASSERT_2_CONNECTIONS(1, 3);
898 ASSERT_3_CONNECTIONS(1, 4);
899 ASSERT_2_CONNECTIONS(1, 5);
901 // means all remained references removed
902 ASSERT_NO_CONNECTION(1, 2);
903 ASSERT_NO_CONNECTION(2, 2);
904 ASSERT_NO_CONNECTION(3, 2);
905 ASSERT_NO_CONNECTION(4, 2);
906 ASSERT_NO_CONNECTION(5, 2);
909 TEST_F(GraphToolsTest, CNNNetworkRemoveSimpleLayer) {
911 CONNECT_FROM_PORT(1, 0, 2);
912 CONNECT_FROM_PORT(2, 0, 3);
914 EXPECT_CALL(*mockNet, getInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap & maps){
918 EXPECT_CALL(*mockNet, getLayerByName(_,_,_)).WillRepeatedly(WithArgs<0, 1>(Invoke([&](const char* name, InferenceEngine::CNNLayerPtr& l){
919 l = layerByName(name);
920 return l== nullptr ? GENERAL_ERROR : OK;
923 CNNNetworkRemoveLayer(wrap.getLayerByName("2"));
925 ASSERT_CONNECTION(1, 3);
927 // means all remained references removed
928 ASSERT_NO_CONNECTION(1, 2);
929 ASSERT_NO_CONNECTION(2, 2);
930 ASSERT_NO_CONNECTION(3, 2);
934 //TEST_F(GraphToolsTest, CNNNetworkInsertLayerBeforeAll) {
937 // EXPECT_CALL(mockNet, GetInputsInfo(_)).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap & maps){
938 // prepareInputs(maps);
941 // EXPECT_CALL(mockNet, getLayerByName(_,_,_)).WillRepeatedly(WithArgs<0, 1>(Invoke([&](const char* name, InferenceEngine::CNNLayerPtr& l){
942 // l = layerByName(name);
943 // return l== nullptr ? GENERAL_ERROR : OK;
946 // CNNNetworkInsertLayer(wrap.getLayerByName("1"), nullptr, createGenericLayer("3"));
948 // ASSERT_STREQ("2", wrap.getLayerByName("3")->outData[0]->getInputTo().begin()->second->name.c_str());
949 // ASSERT_STREQ("1", CNNNetPrevLayerName(wrap.getLayerByName("3")).c_str());
950 // ASSERT_STREQ("3", wrap.getLayerByName("1")->outData[0]->getInputTo().begin()->second->name.c_str());
951 // ASSERT_STREQ("3", CNNNetPrevLayerName(wrap.getLayerByName("2")).c_str());