Publishing R5 content (#72)
[platform/upstream/dldt.git] / inference-engine / tests / unit / cnn_network / cnn_network_impl_test.cpp
1 // Copyright (C) 2018 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4
5 #include <gtest/gtest.h>
6 #include "cnn_network_impl.hpp"
7 #include <../graph_tools/graph_test_base.hpp>
8
9 using namespace testing;
10 using namespace InferenceEngine;
11 using namespace InferenceEngine::details;
12 using namespace std;
13 using namespace GraphTest;
14
15 class CNNNetworkImplTest : public GraphTestsBase {
16 public:
17     StatusCode sts = OK;
18 /**
19  * @brief connect layers with wrong input data
20  * @param x - output layer index
21  * @param y - input layer index
22  * @param wrongID - data index, which falsely displayed among inputs in y
23  */
24     void CONNECT_WRONGLY(int x, int y, int wrongID) {
25         datas[x].front()->getInputTo()[std::to_string(y)] = layers[y];
26         layers[y]->insData.push_back(datas[wrongID].front());
27         lhsLayers.insert(layers[x]);
28         rhsLayers.insert(layers[y]);
29     }
30 /**
31  * @brief connect layers where no input data in layer
32  * @param x - output layer index
33  * @param y - input layer index
34  */
35     void CONNECT_WITHOUT_INS_DATA(int x, int y) {
36         datas[x].front()->getInputTo()[std::to_string(y)] = layers[y];
37         lhsLayers.insert(layers[x]);
38         rhsLayers.insert(layers[y]);
39     }
40 /**
41  * @brief connect layers with wrong data name
42  * @param x - output layer index
43  * @param y - input layer index
44  * @param wrongName - wrong data name, which displayed between x and y
45  */
46     void CONNECT_WITH_DATA_NAME(int x, int y, int name) {
47         datas[x].front()->name = std::to_string(name);
48         datas[x].front()->getInputTo()[std::to_string(y)] = layers[y];
49         layers[y]->insData.push_back(datas[x].front());
50         lhsLayers.insert(layers[x]);
51         rhsLayers.insert(layers[y]);
52     }
53 /**
54  * @brief connect layers with wrong layer name
55  * @param x - output layer index
56  * @param y - input layer index
57  * @param name - wrong layer name, which displayed instead of y
58  */
59     void CONNECT_WITH_LAYER_NAME(int x, int y, int name) {
60         layers[y]->name = std::to_string(name);
61         datas[x].front()->getInputTo()[std::to_string(y)] = layers[y];
62         layers[y]->insData.push_back(datas[x].front());
63         lhsLayers.insert(layers[x]);
64         rhsLayers.insert(layers[y]);
65     }
66 /**
67  * @brief insert data, which has no creator layer, but it is, into layer
68  * @param x - data input to y index
69  * @param y - input layer index
70  */
71     void CONNECT_WITHOUT_CREATOR_LAYER_WHICH_EXIST(int x, int y) {
72         datas[x].front()->getInputTo()[std::to_string(y)] = layers[y];
73         datas[x].front()->getCreatorLayer() = std::weak_ptr<CNNLayer>();
74         layers[y]->insData.push_back(datas[x].front());
75         lhsLayers.insert(layers[x]);
76         rhsLayers.insert(layers[y]);
77     }
78 /**
79  * @brief insert data, which has no creator layer, into layer
80  * @param x - data input to y index
81  * @param y - input layer index
82  */
83     void CONNECT_WITHOUT_CREATOR_LAYER(int x, int y) {
84         datas[x].front()->getInputTo()[std::to_string(y)] = layers[y];
85         datas[x].front()->getCreatorLayer() = std::weak_ptr<CNNLayer>();
86         layers[x] = nullptr;
87         layers[y]->insData.push_back(datas[x].front());
88         rhsLayers.insert(layers[y]);
89     }
90 /**
91  * @brief connect and specify input layer type
92  * @param x - output  layer index
93  * @param y - input layer index
94  */
95     void CONNECT_WITH_INPUT_TYPE(int x, int y, std::string type) {
96         datas[x].front()->getInputTo()[std::to_string(y)] = layers[y];
97         layers[y]->insData.push_back(datas[x].front());
98         layers[x]->type = type;
99         lhsLayers.insert(layers[x]);
100         rhsLayers.insert(layers[y]);
101     }
102 };
103
104 TEST_F(CNNNetworkImplTest, throwOnWrongInputType) {
105     MockCNNNetworkImpl network;
106     CONNECT_WITH_INPUT_TYPE(1, 2, "const");
107
108     EXPECT_CALL(network, getInputsInfo(_)).Times(2).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap& maps) {
109         prepareInputs(maps);
110     })));
111
112     ASSERT_THROW(network.validateNetwork(), InferenceEngineException);
113 }
114
115 TEST_F(CNNNetworkImplTest, severalRightInputTypes) {
116     MockCNNNetworkImpl network;
117
118     CONNECT(1, 2);
119     CONNECT_WITH_INPUT_TYPE(3, 1, "input");
120     CONNECT_WITH_INPUT_TYPE(0, 1, "input");
121
122     EXPECT_CALL(network, getInputsInfo(_)).Times(2).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap& maps) {
123         prepareInputs(maps);
124     })));
125
126     ASSERT_NO_THROW(network.validateNetwork());
127 }
128
129 TEST_F(CNNNetworkImplTest, noCreatorLayers) {
130     MockCNNNetworkImpl network;
131
132     CONNECT(1, 2);
133     CONNECT(3, 1);
134     CONNECT_WITHOUT_CREATOR_LAYER(0, 1);
135     EXPECT_CALL(network, getInputsInfo(_)).Times(2).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap& maps) {
136         prepareInputs(maps);
137     })));
138
139     ASSERT_THROW(network.validateNetwork(), InferenceEngineException);
140 }
141
142 TEST_F(CNNNetworkImplTest, dataHasNoCreatorLayerButItIs) {
143     MockCNNNetworkImpl network;
144
145     CONNECT(1, 2);
146     CONNECT(3, 1);
147     CONNECT_WITHOUT_CREATOR_LAYER_WHICH_EXIST(0, 1);
148     EXPECT_CALL(network, getInputsInfo(_)).Times(2).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap& maps) {
149         prepareInputs(maps);
150     })));
151
152     ASSERT_THROW(network.validateNetwork(), InferenceEngineException);
153 }
154
155 TEST_F(CNNNetworkImplTest, layerNameIsNotUnique) {
156     MockCNNNetworkImpl network;
157
158     CONNECT(1, 2);
159     CONNECT_WITH_LAYER_NAME(2, 3, 1);
160
161     EXPECT_CALL(network, getInputsInfo(_)).Times(2).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap& maps) {
162         prepareInputs(maps);
163     })));
164
165     ASSERT_THROW(network.validateNetwork(), InferenceEngineException);
166 }
167
168 TEST_F(CNNNetworkImplTest, dataNameIsNotUnique) {
169     MockCNNNetworkImpl network;
170
171     CONNECT(1, 2);
172     CONNECT(1, 4);
173     CONNECT_WITH_DATA_NAME(2, 3, 1);
174
175     EXPECT_CALL(network, getInputsInfo(_)).Times(2).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap& maps) {
176         prepareInputs(maps);
177     })));
178
179     ASSERT_THROW(network.validateNetwork(), InferenceEngineException);
180 }
181
182 TEST_F(CNNNetworkImplTest, layerDoesNotHaveInputData) {
183
184     MockCNNNetworkImpl network;
185
186     CONNECT(1, 2);
187     CONNECT(9, 3);
188     CONNECT(2, 9);
189     CONNECT_WITHOUT_INS_DATA(1, 3);
190
191     EXPECT_CALL(network, getInputsInfo(_)).Times(2).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap& maps) {
192         prepareInputs(maps);
193     })));
194
195     ASSERT_THROW(network.validateNetwork(), InferenceEngineException);
196 }
197
198 TEST_F(CNNNetworkImplTest, layerDataNotCoresspondEachOtherOneInput) {
199
200     MockCNNNetworkImpl network;
201
202     CONNECT(1, 2);
203     CONNECT(2, 3);
204     CONNECT_WRONGLY(3, 4, 2);
205
206     EXPECT_CALL(network, getInputsInfo(_)).Times(2).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap& maps) {
207         prepareInputs(maps);
208     })));
209
210     ASSERT_THROW(network.validateNetwork(), InferenceEngineException);
211 }
212
213 TEST_F(CNNNetworkImplTest, layerDataNotCoresspondEachOtherTwoInputs) {
214
215     MockCNNNetworkImpl network;
216
217     CONNECT(1, 2);
218     CONNECT(2, 3);
219     CONNECT(7, 4);
220     CONNECT_WRONGLY(3, 4, 2);
221
222     EXPECT_CALL(network, getInputsInfo(_)).Times(2).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap& maps) {
223         prepareInputs(maps);
224     })));
225     ASSERT_THROW(network.validateNetwork(), InferenceEngineException);
226 }
227
228 TEST_F(CNNNetworkImplTest, canGetName) {
229     InferenceEngine::details::CNNNetworkImpl net;
230     net.setName("myName");
231     const char* p = "33333333333";
232     char name[20];
233     net.getName(name, sizeof(name));
234     ASSERT_STREQ(name, "myName");
235 }
236
237 TEST_F(CNNNetworkImplTest, canGetNameStr) {
238     InferenceEngine::details::CNNNetworkImpl net;
239     net.setName("myName");
240     auto name = net.getName();
241     ASSERT_STREQ(name.c_str(), "myName");
242 }
243
244 TEST_F(CNNNetworkImplTest, cycleIsDetectedInNetwork) {
245
246     MockCNNNetworkImpl network;
247
248     // 1->2->3-> 4->5->6->7-> 8
249     //       ^  |^        |  ^|
250     //       |  |└--------┘  ||
251     //       |  └------------┘|
252     //       └----------------┘
253
254     CONNECT(1, 2);
255     CONNECT(2, 3);
256     CONNECT(3, 4);
257     CONNECT(4, 5);
258     CONNECT(5, 6);
259     CONNECT(6, 7);
260     CONNECT(7, 8);
261     CONNECT(7, 4);
262     CONNECT(4, 8);
263     CONNECT(8, 3);
264
265     EXPECT_CALL(network, getInputsInfo(_)).Times(2).WillRepeatedly(WithArg<0>(Invoke([&](InputsDataMap& maps) {
266         prepareInputs(maps);
267     })));
268
269     ASSERT_THROW(network.validateNetwork(), InferenceEngineException);
270 }