Release 18.03
[platform/upstream/armnn.git] / src / armnnTfParser / test / TestDependencies.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "armnnTfParser/ITfParser.hpp"
8 #include "ParserPrototxtFixture.hpp"
9
10 BOOST_AUTO_TEST_SUITE(TensorflowParser)
11
12 // Graph which tests that nodes are re-ordered in the queue when they are encountered a second time.
13 // In this case R0 will be encountered first via R1 and then via R2. At that time
14 // we need to make sure that R0 (and the I on which it is dependent) is moved to the front again
15 // so that it is before both R1 and R2.
16 //    I
17 //    |
18 //    R0
19 //   / \'
20 //  R1  R2
21 //   \  |
22 //    \ R3
23 //     \|
24 //      O
25 struct RediscoveredDependenciesFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
26 {
27     RediscoveredDependenciesFixture()
28     {
29         // input = tf.placeholder(tf.float32, 1, "input")
30         // relu0 = tf.nn.relu(input, "relu0")
31         // relu1 = tf.nn.relu(relu0, "relu1")
32         // relu2 = tf.nn.relu(relu0, "relu2")
33         // relu3 = tf.nn.relu(relu2, "relu3")
34         // output = tf.add(relu1, relu3, "output")
35         m_Prototext = R"(
36             node {
37               name: "input"
38               op: "Placeholder"
39               attr {
40                 key: "dtype"
41                 value {
42                   type: DT_FLOAT
43                 }
44               }
45               attr {
46                 key: "shape"
47                 value {
48                   shape {
49                     dim {
50                       size: 1
51                     }
52                   }
53                 }
54               }
55             }
56             node {
57               name: "relu0"
58               op: "Relu"
59               input: "input"
60               attr {
61                 key: "T"
62                 value {
63                   type: DT_FLOAT
64                 }
65               }
66             }
67             node {
68               name: "relu1"
69               op: "Relu"
70               input: "relu0"
71               attr {
72                 key: "T"
73                 value {
74                   type: DT_FLOAT
75                 }
76               }
77             }
78             node {
79               name: "relu2"
80               op: "Relu"
81               input: "relu0"
82               attr {
83                 key: "T"
84                 value {
85                   type: DT_FLOAT
86                 }
87               }
88             }
89             node {
90               name: "relu3"
91               op: "Relu"
92               input: "relu2"
93               attr {
94                 key: "T"
95                 value {
96                   type: DT_FLOAT
97                 }
98               }
99             }
100             node {
101               name: "output"
102               op: "Add"
103               input: "relu1"
104               input: "relu3"
105               attr {
106                 key: "T"
107                 value {
108                   type: DT_FLOAT
109                 }
110               }
111             }
112         )";
113         SetupSingleInputSingleOutput({ 1 }, "input", "output");
114     }
115 };
116
117 BOOST_FIXTURE_TEST_CASE(RediscoveredDependencies, RediscoveredDependenciesFixture)
118 {
119     RunTest<1>({1}, {2});
120 }
121
122 // Tests that a simple cycle in the tensorflow graph will be detected and an exception thrown, rather than the TfParser
123 // getting stuck in an infinite loop.
124 BOOST_AUTO_TEST_CASE(SimpleCycle)
125 {
126     const char* prototext = R"(
127 node {
128   name: "r1"
129   op: "Relu"
130   input: "r2"
131   attr {
132     key: "T"
133     value {
134       type: DT_FLOAT
135     }
136   }
137 }
138 node {
139   name: "r2"
140   op: "Relu"
141   input: "r1"
142   attr {
143     key: "T"
144     value {
145       type: DT_FLOAT
146     }
147   }
148 }
149     )";
150     armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create();
151     BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "r2" }), armnn::ParseException);
152 }
153
154 // Similar to the above SimpleCycle test, but has a single node which connects to itself.
155 BOOST_AUTO_TEST_CASE(SingleNodeCycle)
156 {
157     const char* prototext = R"(
158 node {
159   name: "r1"
160   op: "Relu"
161   input: "r1"
162   attr {
163     key: "T"
164     value {
165       type: DT_FLOAT
166     }
167   }
168 }
169     )";
170     armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create();
171     BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "r1" }), armnn::ParseException);
172 }
173
174 // Similar to the above SimpleCycle test, but with a more complicated graph.
175 //    I
176 //    |
177 //    A2---<---<-
178 //   / \'        |
179 //  R1  R2       |
180 //   \  |        |
181 //    \ R3       |
182 //     \|        |
183 //      A1-->--->|
184 //
185 BOOST_AUTO_TEST_CASE(ComplexCycle)
186 {
187     // input = tf.placeholder(tf.float32, 1, "input")
188     // add2 = tf.nn.relu(input, add1, "add2") // This line won't actually run in TF, because add1 is not yet defined
189     // relu1 = tf.nn.relu(relu0, "relu1")
190     // relu2 = tf.nn.relu(relu0, "relu2")
191     // relu3 = tf.nn.relu(relu2, "relu3")
192     // add1 = tf.add(relu1, relu3, "add1")
193     const char* prototext = R"(
194         node {
195             name: "input"
196             op: "Placeholder"
197             attr {
198             key: "dtype"
199             value {
200                 type: DT_FLOAT
201             }
202             }
203             attr {
204             key: "shape"
205             value {
206                 shape {
207                 dim {
208                     size: 1
209                 }
210                 }
211             }
212             }
213         }
214         node {
215             name: "add2"
216             op: "Add"
217             input: "input"
218             input: "add1"
219             attr {
220             key: "T"
221             value {
222                 type: DT_FLOAT
223             }
224             }
225         }
226         node {
227             name: "relu1"
228             op: "Relu"
229             input: "add2"
230             attr {
231             key: "T"
232             value {
233                 type: DT_FLOAT
234             }
235             }
236         }
237         node {
238             name: "relu2"
239             op: "Relu"
240             input: "add2"
241             attr {
242             key: "T"
243             value {
244                 type: DT_FLOAT
245             }
246             }
247         }
248         node {
249             name: "relu3"
250             op: "Relu"
251             input: "relu2"
252             attr {
253             key: "T"
254             value {
255                 type: DT_FLOAT
256             }
257             }
258         }
259         node {
260             name: "add1"
261             op: "Add"
262             input: "relu1"
263             input: "relu3"
264             attr {
265             key: "T"
266             value {
267                 type: DT_FLOAT
268             }
269             }
270         }
271     )";
272     armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create();
273     BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "add1" }), armnn::ParseException);
274 }
275
276 // Tests that a graph with an input that is not present throws a ParseException.
277 BOOST_AUTO_TEST_CASE(InvalidInput)
278 {
279     const char* prototext = R"(
280 node {
281   name: "r1"
282   op: "Relu"
283   input: "a-node-that-does-not-exist"
284   attr {
285     key: "T"
286     value {
287       type: DT_FLOAT
288     }
289   }
290 }
291     )";
292     armnnTfParser::ITfParserPtr parser = armnnTfParser::ITfParser::Create();
293     BOOST_CHECK_THROW(parser->CreateNetworkFromString(prototext, {}, { "r1" }), armnn::ParseException);
294 }
295
296 BOOST_AUTO_TEST_SUITE_END()