1 // Copyright (C) 2018 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
6 #include "cnn_network_impl.hpp"
7 #include <../graph_tools/graph_test_base.hpp>
9 using namespace testing;
10 using namespace InferenceEngine;
11 using namespace InferenceEngine::details;
13 using namespace GraphTest;
// Test fixture for the CNNNetworkImpl tests in this file. It derives from
// GraphTestsBase and adds CONNECT_* helpers that wire the inherited
// layers/datas together, mostly in deliberately inconsistent ways.
15 class CNNNetworkImplTest : public GraphTestsBase {
19 * @brief connect layers with wrong input data
20 * @param x - output layer index
21 * @param y - input layer index
22 * @param wrongID - index of the data that is wrongly added to the inputs of layer y
// Wires layer x to layer y inconsistently: layers[y] is registered as a
// consumer of datas[x].front(), but the data appended to layers[y]->insData
// is datas[wrongID].front() instead.
24 void CONNECT_WRONGLY(int x, int y, int wrongID) {
// Consumer entry on x's first data, keyed by the string form of y.
25 datas[x].front()->getInputTo()[std::to_string(y)] = layers[y];
// Intentional mismatch: the input recorded on y comes from wrongID, not x.
26 layers[y]->insData.push_back(datas[wrongID].front());
// Track both endpoints in lhsLayers / rhsLayers (declared outside this
// excerpt, presumably in GraphTestsBase - confirm).
27 lhsLayers.insert(layers[x]);
28 rhsLayers.insert(layers[y]);
31 * @brief connect layers without adding the data to the input layer's insData
32 * @param x - output layer index
33 * @param y - input layer index
// Registers layers[y] as a consumer of datas[x].front() but never appends
// that data to layers[y]->insData, leaving the link one-sided.
35 void CONNECT_WITHOUT_INS_DATA(int x, int y) {
// Consumer entry on x's first data, keyed by the string form of y.
36 datas[x].front()->getInputTo()[std::to_string(y)] = layers[y];
// Track both endpoints in lhsLayers / rhsLayers.
37 lhsLayers.insert(layers[x]);
38 rhsLayers.insert(layers[y]);
41 * @brief connect layers with wrong data name
42 * @param x - output layer index
43 * @param y - input layer index
44 * @param name - number whose string form becomes the name of the data between x and y
// Connects layer x to layer y after renaming x's first data to the decimal
// string of `name`. Apart from the rename the link is consistent: the same
// data is registered on both sides.
46 void CONNECT_WITH_DATA_NAME(int x, int y, int name) {
// Overwrite the data's name before wiring it up.
47 datas[x].front()->name = std::to_string(name);
// Consumer entry keyed by the string form of y.
48 datas[x].front()->getInputTo()[std::to_string(y)] = layers[y];
// Same data is recorded as an input of y.
49 layers[y]->insData.push_back(datas[x].front());
// Track both endpoints in lhsLayers / rhsLayers.
50 lhsLayers.insert(layers[x]);
51 rhsLayers.insert(layers[y]);
54 * @brief connect layers with wrong layer name
55 * @param x - output layer index
56 * @param y - input layer index
57 * @param name - number whose string form becomes the name of layer y
// Connects layer x to layer y after renaming layer y to the decimal string
// of `name`.
59 void CONNECT_WITH_LAYER_NAME(int x, int y, int name) {
// Overwrite the consumer layer's name.
60 layers[y]->name = std::to_string(name);
// The consumer entry is still keyed by y's index, not by the new name.
61 datas[x].front()->getInputTo()[std::to_string(y)] = layers[y];
// Same data is recorded as an input of y.
62 layers[y]->insData.push_back(datas[x].front());
// Track both endpoints in lhsLayers / rhsLayers.
63 lhsLayers.insert(layers[x]);
64 rhsLayers.insert(layers[y]);
67 * @brief insert into layer y data whose creator layer link is cleared, although layer x is still added to lhsLayers
68 * @param x - index of the data fed into layer y
69 * @param y - input layer index
// Connects datas[x].front() to layer y and resets the data's creator-layer
// reference to an empty weak_ptr, while layers[x] is still inserted into
// lhsLayers.
71 void CONNECT_WITHOUT_CREATOR_LAYER_WHICH_EXIST(int x, int y) {
// Consumer entry keyed by the string form of y.
72 datas[x].front()->getInputTo()[std::to_string(y)] = layers[y];
// Replace the creator reference with a default-constructed (empty) weak_ptr.
73 datas[x].front()->getCreatorLayer() = std::weak_ptr<CNNLayer>();
// Same data is recorded as an input of y.
74 layers[y]->insData.push_back(datas[x].front());
// Unlike CONNECT_WITHOUT_CREATOR_LAYER, layer x is also tracked in lhsLayers.
75 lhsLayers.insert(layers[x]);
76 rhsLayers.insert(layers[y]);
79 * @brief insert data, which has no creator layer, into layer
80 * @param x - index of the data fed into layer y
81 * @param y - input layer index
// Connects datas[x].front() to layer y and resets the data's creator-layer
// reference to an empty weak_ptr. Only layers[y] is tracked (in rhsLayers);
// layers[x] is not inserted into lhsLayers.
// NOTE(review): one listing line between the two statements below is not
// visible in this excerpt - confirm nothing else happens there.
83 void CONNECT_WITHOUT_CREATOR_LAYER(int x, int y) {
// Consumer entry keyed by the string form of y.
84 datas[x].front()->getInputTo()[std::to_string(y)] = layers[y];
// Replace the creator reference with a default-constructed (empty) weak_ptr.
85 datas[x].front()->getCreatorLayer() = std::weak_ptr<CNNLayer>();
// Same data is recorded as an input of y.
87 layers[y]->insData.push_back(datas[x].front());
88 rhsLayers.insert(layers[y]);
91 * @brief connect layers and set the type of layer x to the given type
92 * @param x - output layer index
93 * @param y - input layer index
// Connects layer x to layer y consistently and overwrites layers[x]->type
// with `type`.
// `type` is the string assigned to the type field of layer x (the producer
// side of the link); it is taken by value.
95 void CONNECT_WITH_INPUT_TYPE(int x, int y, std::string type) {
// Consumer entry keyed by the string form of y.
96 datas[x].front()->getInputTo()[std::to_string(y)] = layers[y];
// Same data is recorded as an input of y.
97 layers[y]->insData.push_back(datas[x].front());
// The type is set on layer x, not on layer y.
98 layers[x]->type = type;
// Track both endpoints in lhsLayers / rhsLayers.
99 lhsLayers.insert(layers[x]);
100 rhsLayers.insert(layers[y]);
// Expects validateNetwork() to throw InferenceEngineException after layer 1
// is given the type "const" and connected to layer 2.
104 TEST_F(CNNNetworkImplTest, throwOnWrongInputType) {
105 MockCNNNetworkImpl network;
// Sets layers[1]->type to "const" while wiring 1 -> 2.
106 CONNECT_WITH_INPUT_TYPE(1, 2, "const");
// getInputsInfo must be called exactly twice; the action receives the
// InputsDataMap argument.
// NOTE(review): the lambda body and the closing of this statement are not
// visible in this excerpt.
108 EXPECT_CALL(network, getInputsInfo(_)).Times(2).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap& maps) {
112 ASSERT_THROW(network.validateNetwork(), InferenceEngineException);
// Expects validateNetwork() not to throw when layers 3 and 0, both given the
// type "input", feed layer 1.
// NOTE(review): listing lines between the mock declaration and the first
// CONNECT call are not visible here.
115 TEST_F(CNNNetworkImplTest, severalRightInputTypes) {
116 MockCNNNetworkImpl network;
// Each call sets the producer layer's type to "input" and wires it to layer 1.
119 CONNECT_WITH_INPUT_TYPE(3, 1, "input");
120 CONNECT_WITH_INPUT_TYPE(0, 1, "input");
// getInputsInfo must be called exactly twice.
// NOTE(review): the lambda body is not visible in this excerpt.
122 EXPECT_CALL(network, getInputsInfo(_)).Times(2).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap& maps) {
126 ASSERT_NO_THROW(network.validateNetwork());
// Expects validateNetwork() to throw InferenceEngineException when the data
// feeding layer 1 has had its creator-layer reference cleared.
// NOTE(review): listing lines before the CONNECT call are not visible here.
129 TEST_F(CNNNetworkImplTest, noCreatorLayers) {
130 MockCNNNetworkImpl network;
// Wires datas[0] into layer 1 and resets the data's creator layer.
134 CONNECT_WITHOUT_CREATOR_LAYER(0, 1);
// getInputsInfo must be called exactly twice.
// NOTE(review): the lambda body is not visible in this excerpt.
135 EXPECT_CALL(network, getInputsInfo(_)).Times(2).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap& maps) {
139 ASSERT_THROW(network.validateNetwork(), InferenceEngineException);
// Expects validateNetwork() to throw InferenceEngineException when the data's
// creator-layer reference is cleared although layer 0 is still tracked in
// lhsLayers.
// NOTE(review): listing lines before the CONNECT call are not visible here.
142 TEST_F(CNNNetworkImplTest, dataHasNoCreatorLayerButItIs) {
143 MockCNNNetworkImpl network;
// Wires datas[0] into layer 1, resets its creator layer, keeps layer 0 in lhsLayers.
147 CONNECT_WITHOUT_CREATOR_LAYER_WHICH_EXIST(0, 1);
// getInputsInfo must be called exactly twice.
// NOTE(review): the lambda body is not visible in this excerpt.
148 EXPECT_CALL(network, getInputsInfo(_)).Times(2).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap& maps) {
152 ASSERT_THROW(network.validateNetwork(), InferenceEngineException);
// Expects validateNetwork() to throw InferenceEngineException after layer 3
// is renamed to "1".
// NOTE(review): the new name presumably collides with an existing layer's
// name set up outside this excerpt - confirm against GraphTestsBase.
155 TEST_F(CNNNetworkImplTest, layerNameIsNotUnique) {
156 MockCNNNetworkImpl network;
// Renames layers[3] to "1" while wiring 2 -> 3.
159 CONNECT_WITH_LAYER_NAME(2, 3, 1);
// getInputsInfo must be called exactly twice.
// NOTE(review): the lambda body is not visible in this excerpt.
161 EXPECT_CALL(network, getInputsInfo(_)).Times(2).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap& maps) {
165 ASSERT_THROW(network.validateNetwork(), InferenceEngineException);
// Expects validateNetwork() to throw InferenceEngineException after the first
// data of layer 2 is renamed to "1".
// NOTE(review): the new name presumably collides with another data's name set
// up outside this excerpt - confirm against GraphTestsBase.
168 TEST_F(CNNNetworkImplTest, dataNameIsNotUnique) {
169 MockCNNNetworkImpl network;
// Renames datas[2].front() to "1" while wiring 2 -> 3.
173 CONNECT_WITH_DATA_NAME(2, 3, 1);
// getInputsInfo must be called exactly twice.
// NOTE(review): the lambda body is not visible in this excerpt.
175 EXPECT_CALL(network, getInputsInfo(_)).Times(2).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap& maps) {
179 ASSERT_THROW(network.validateNetwork(), InferenceEngineException);
// Expects validateNetwork() to throw InferenceEngineException when layer 3 is
// registered as a consumer of datas[1] but that data is missing from its
// insData.
// NOTE(review): listing lines before the CONNECT call are not visible here.
182 TEST_F(CNNNetworkImplTest, layerDoesNotHaveInputData) {
184 MockCNNNetworkImpl network;
// One-sided link 1 -> 3: no insData entry is added on layer 3.
189 CONNECT_WITHOUT_INS_DATA(1, 3);
// getInputsInfo must be called exactly twice.
// NOTE(review): the lambda body is not visible in this excerpt.
191 EXPECT_CALL(network, getInputsInfo(_)).Times(2).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap& maps) {
195 ASSERT_THROW(network.validateNetwork(), InferenceEngineException);
// Expects validateNetwork() to throw InferenceEngineException when layer 4 is
// registered as a consumer of datas[3] but lists datas[2] among its inputs.
// NOTE(review): listing lines before the CONNECT call are not visible here.
198 TEST_F(CNNNetworkImplTest, layerDataNotCoresspondEachOtherOneInput) {
200 MockCNNNetworkImpl network;
// Mismatched link: consumer entry on datas[3], insData entry from datas[2].
204 CONNECT_WRONGLY(3, 4, 2);
// getInputsInfo must be called exactly twice.
// NOTE(review): the lambda body is not visible in this excerpt.
206 EXPECT_CALL(network, getInputsInfo(_)).Times(2).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap& maps) {
210 ASSERT_THROW(network.validateNetwork(), InferenceEngineException);
// Expects validateNetwork() to throw InferenceEngineException for the same
// mismatched 3 -> 4 link (insData taken from datas[2]).
// NOTE(review): the setup that gives layer 4 its second input, as the test
// name suggests, is on listing lines not visible in this excerpt.
213 TEST_F(CNNNetworkImplTest, layerDataNotCoresspondEachOtherTwoInputs) {
215 MockCNNNetworkImpl network;
// Mismatched link: consumer entry on datas[3], insData entry from datas[2].
220 CONNECT_WRONGLY(3, 4, 2);
// getInputsInfo must be called exactly twice.
// NOTE(review): the lambda body is not visible in this excerpt.
222 EXPECT_CALL(network, getInputsInfo(_)).Times(2).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap& maps) {
225 ASSERT_THROW(network.validateNetwork(), InferenceEngineException);
// Checks that a name set with setName() is read back through the
// buffer-and-size overload of getName().
228 TEST_F(CNNNetworkImplTest, canGetName) {
229 InferenceEngine::details::CNNNetworkImpl net;
230 net.setName("myName");
// NOTE(review): `p` is not used by any line visible in this excerpt.
231 const char* p = "33333333333";
// NOTE(review): the declaration of `name` is on a listing line not visible
// here; sizeof(name) implies it is an array rather than a pointer - confirm.
233 net.getName(name, sizeof(name));
234 ASSERT_STREQ(name, "myName");
// Checks that a name set with setName() is read back through the
// no-argument getName() overload.
237 TEST_F(CNNNetworkImplTest, canGetNameStr) {
238 InferenceEngine::details::CNNNetworkImpl net;
239 net.setName("myName");
// The returned object exposes c_str(), which is compared as a C string.
240 auto name = net.getName();
241 ASSERT_STREQ(name.c_str(), "myName");
// Expects validateNetwork() to throw InferenceEngineException for a network
// that contains a cycle, as sketched in the diagram below.
// NOTE(review): the statements that build the graph and most of the diagram
// are on listing lines not visible in this excerpt.
244 TEST_F(CNNNetworkImplTest, cycleIsDetectedInNetwork) {
246 MockCNNNetworkImpl network;
248 // 1->2->3-> 4->5->6->7-> 8
252 // └----------------┘
// getInputsInfo must be called exactly twice.
// NOTE(review): the lambda body is not visible in this excerpt.
265 EXPECT_CALL(network, getInputsInfo(_)).Times(2).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap& maps) {
269 ASSERT_THROW(network.validateNetwork(), InferenceEngineException);