// Copyright (C) 2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
5 #include <gtest/gtest.h>
7 #include <transform/transform_network.hpp>
8 #include <ie_builders.hpp>
10 #include "builder_test.hpp"
12 using namespace testing;
13 using namespace InferenceEngine;
15 class TransformNetworkTest: public BuilderTestCommon {};
17 TEST_F(TransformNetworkTest, AddNewLayer) {
18 Builder::Network builder("test");
19 Transform::Network network(builder);
20 ASSERT_EQ(0, builder.size());
21 network.addLayer(Builder::InputLayer("in1").setPort(Port({1, 3, 27, 27})));
22 ASSERT_EQ(1, builder.size());
25 TEST_F(TransformNetworkTest, RemoveLayer) {
26 Builder::Network builder("test");
27 Transform::Network network(builder);
28 ASSERT_EQ(0, builder.size());
29 Transform::Layer layer = network.addLayer(Builder::InputLayer("in1").setPort(Port({1, 3, 27, 27})));
30 ASSERT_EQ(1, builder.size());
32 network.removeLayer(layer);
33 ASSERT_EQ(0, builder.size());
36 TEST_F(TransformNetworkTest, GetIncorrectPort) {
37 Builder::Network builder("test");
38 Transform::Network network(builder);
39 Transform::Layer layer = network.addLayer(Builder::InputLayer("in1").setPort(Port({1, 3, 27, 27})));
40 ASSERT_THROW(layer.getInPort(), InferenceEngine::details::InferenceEngineException);
41 ASSERT_THROW(layer.getOutPort(1), InferenceEngine::details::InferenceEngineException);
45 TEST_F(TransformNetworkTest, GetCorrectPort) {
46 Builder::Network builder("test");
47 Transform::Network network(builder);
48 Transform::Layer layer = network.addLayer(Builder::InputLayer("in1").setPort(Port({1, 3, 27, 27})));
49 ASSERT_NO_THROW(layer.getOutPort());
50 ASSERT_NO_THROW(layer.getOutPort(0));
53 TEST_F(TransformNetworkTest, GetLayerById) {
54 Builder::Network builder("test");
55 Transform::Network network(builder);
56 Transform::Layer layer = network.addLayer(Builder::InputLayer("in1").setPort(Port({1, 3, 27, 27})));
57 ASSERT_NO_THROW(network.getLayer(layer.getId()));
60 TEST_F(TransformNetworkTest, GetLayerByName) {
61 Builder::Network builder("test");
62 Transform::Network network(builder);
63 network.addLayer(Builder::InputLayer("in1").setPort(Port({1, 3, 27, 27})));
64 ASSERT_NO_THROW(network.getLayer("in1"));
67 TEST_F(TransformNetworkTest, ConnectTwoLayers) {
68 Builder::Network builder("test");
69 Transform::Network network(builder);
70 Transform::Layer input = network.addLayer(Builder::InputLayer("in1").setPort(Port({1, 3, 27, 27})));
71 Transform::Layer relu = network.addLayer(Builder::ReLULayer("relu1"));
72 ASSERT_EQ(2, builder.size());
73 ASSERT_EQ(0, builder.getConnections().size());
74 network.connect(input, relu);
75 ASSERT_EQ(1, builder.getConnections().size());
78 TEST_F(TransformNetworkTest, ConnectTwoPorts) {
79 Builder::Network builder("test");
80 Transform::Network network(builder);
81 Transform::Port inputPort = network.addLayer(Builder::InputLayer("in1").setPort(Port({1, 3, 27, 27}))).getOutPort();
82 Transform::Port reluPort = network.addLayer(Builder::ReLULayer("relu1")).getInPort();
83 ASSERT_EQ(2, builder.size());
84 ASSERT_EQ(0, builder.getConnections().size());
85 network.connect(inputPort, reluPort);
86 ASSERT_EQ(1, builder.getConnections().size());
89 TEST_F(TransformNetworkTest, DisconnectTwoLayers) {
90 Builder::Network builder("test");
91 Transform::Network network(builder);
92 Transform::Layer input = network.addLayer(Builder::InputLayer("in1").setPort(Port({1, 3, 27, 27})));
93 Transform::Layer relu = network.addLayer(Builder::ReLULayer("relu1"));
94 ASSERT_EQ(2, builder.size());
95 ASSERT_EQ(0, builder.getConnections().size());
96 network.connect(input, relu);
97 ASSERT_EQ(1, builder.getConnections().size());
98 network.disconnect(input, relu);
99 ASSERT_EQ(0, builder.getConnections().size());
102 TEST_F(TransformNetworkTest, DisonnectTwoPorts) {
103 Builder::Network builder("test");
104 Transform::Network network(builder);
105 Transform::Port inputPort = network.addLayer(Builder::InputLayer("in1").setPort(Port({1, 3, 27, 27}))).getOutPort();
106 Transform::Port reluPort = network.addLayer(Builder::ReLULayer("relu1")).getInPort();
107 ASSERT_EQ(2, builder.size());
108 ASSERT_EQ(0, builder.getConnections().size());
109 network.connect(inputPort, reluPort);
110 ASSERT_EQ(1, builder.getConnections().size());
111 network.disconnect(inputPort, reluPort);
112 ASSERT_EQ(0, builder.getConnections().size());
115 TEST_F(TransformNetworkTest, RemoveLayerAndConnection) {
116 Builder::Network builder("test");
117 Transform::Network network(builder);
118 Transform::Layer input = network.addLayer(Builder::InputLayer("in1").setPort(Port({1, 3, 27, 27})));
119 Transform::Layer relu = network.addLayer(Builder::ReLULayer("relu1"));
120 network.connect(input, relu);
121 ASSERT_EQ(1, builder.getConnections().size());
122 ASSERT_EQ(2, builder.size());
123 network.removeLayer(relu);
124 ASSERT_EQ(0, builder.getConnections().size());
125 ASSERT_EQ(1, builder.size());
128 TEST_F(TransformNetworkTest, GetInitializedConnection) {
129 Builder::Network builder("test");
130 Transform::Network network(builder);
131 Transform::Layer input = network.addLayer(Builder::InputLayer("in1").setPort(Port({1, 3, 27, 27})));
132 Transform::Layer relu = network.addLayer(Builder::ReLULayer("relu1"));
133 network.connect(input, relu);
134 ASSERT_EQ(input.getOutPort(), relu.getInPort().getConnection().getSource());
137 TEST_F(TransformNetworkTest, GetIncorrectConnections) {
138 Builder::Network builder("test");
139 Transform::Network network(builder);
140 Transform::Layer input = network.addLayer(Builder::InputLayer("in1").setPort(Port({1, 3, 27, 27})));
141 Transform::Layer relu = network.addLayer(Builder::ReLULayer("relu1"));
142 ASSERT_THROW(relu.getInPort().getConnection().getSource(), InferenceEngine::details::InferenceEngineException);
143 ASSERT_THROW(input.getOutPort().getConnection().getDestination(), InferenceEngine::details::InferenceEngineException);
144 ASSERT_NO_THROW(input.getOutPort().getConnection().getSource());
145 ASSERT_NO_THROW(relu.getInPort().getConnection().getDestination());
148 TEST_F(TransformNetworkTest, ConnectToSourcePortsFromConnection) {
149 Builder::Network builder("test");
150 Transform::Network network(builder);
151 Transform::Port inputPort = network.addLayer(Builder::InputLayer("in1").setPort(Port({1, 3, 27, 27}))).getOutPort();
152 Transform::Port reluPort = network.addLayer(Builder::ReLULayer("relu1")).getInPort();
153 ASSERT_EQ(2, builder.size());
154 ASSERT_EQ(0, builder.getConnections().size());
155 ASSERT_NO_THROW(inputPort.getConnection().setDestination(reluPort));
156 ASSERT_EQ(1, builder.getConnections().size());
159 TEST_F(TransformNetworkTest, ConnectWithTwoDestinations) {
160 Builder::Network builder("test");
161 Transform::Network network(builder);
162 Transform::Port inputPort = network.addLayer(Builder::InputLayer("in1").setPort(Port({1, 3, 27, 27}))).getOutPort();
163 Transform::Port reluPort1 = network.addLayer(Builder::ReLULayer("relu1")).getInPort();
164 Transform::Port reluPort2 = network.addLayer(Builder::ReLULayer("relu2")).getInPort();
165 ASSERT_EQ(3, builder.size());
166 ASSERT_EQ(0, builder.getConnections().size());
167 ASSERT_NO_THROW(inputPort.getConnection().setDestination(reluPort1));
168 ASSERT_NO_THROW(inputPort.getConnection().addDestination(reluPort2));
169 ASSERT_THROW(inputPort.getConnection().addDestination(reluPort2), InferenceEngine::details::InferenceEngineException);
170 ASSERT_EQ(2, builder.getConnections().size());
171 ASSERT_THROW(inputPort.getConnection().setDestination(reluPort2), InferenceEngine::details::InferenceEngineException);
172 ASSERT_NO_THROW(inputPort.getConnection().setDestinations({reluPort2, reluPort1}));
173 ASSERT_EQ(2, builder.getConnections().size());
176 TEST_F(TransformNetworkTest, ConnectToDestinationPortsFromConnection) {
177 Builder::Network builder("test");
178 Transform::Network network(builder);
179 Transform::Port inputPort = network.addLayer(Builder::InputLayer("in1").setPort(Port({1, 3, 27, 27}))).getOutPort();
180 Transform::Port reluPort = network.addLayer(Builder::ReLULayer("relu1")).getInPort();
181 ASSERT_EQ(2, builder.size());
182 ASSERT_EQ(0, builder.getConnections().size());
183 reluPort.getConnection().setSource(inputPort);
184 ASSERT_EQ(1, builder.getConnections().size());