Release 18.08
[platform/upstream/armnn.git] / src / armnnTfParser / test / Constant.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5
6 #include <boost/test/unit_test.hpp>
7
8 #include "armnnTfParser/ITfParser.hpp"
9
10 #include "ParserPrototxtFixture.hpp"
11
12 BOOST_AUTO_TEST_SUITE(TensorflowParser)
13
14 // Tests that a Const node in Tensorflow can be converted to a ConstLayer in armnn (as opposed to most
15 // Const nodes which are used as weight inputs for convolutions etc. and are therefore not converted to
16 // armnn ConstLayers).
17 struct ConstantFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
18 {
19     ConstantFixture()
20     {
21         // Input = tf.placeholder(tf.float32, name = "input")
22         // Const = tf.constant([17], tf.float32, [1])
23         // Output = tf.add(input, const, name = "output")
24         m_Prototext =
25             R"(
26 node {
27   name: "input"
28   op: "Placeholder"
29   attr {
30     key: "dtype"
31     value {
32       type: DT_FLOAT
33     }
34   }
35   attr {
36     key: "shape"
37     value {
38       shape {
39         unknown_rank: true
40       }
41     }
42   }
43 }
44 node {
45   name: "Const"
46   op: "Const"
47   attr {
48     key: "dtype"
49     value {
50       type: DT_FLOAT
51     }
52   }
53   attr {
54     key: "value"
55     value {
56       tensor {
57         dtype: DT_FLOAT
58         tensor_shape {
59           dim {
60             size: 1
61           }
62         }
63         float_val: 17.0
64       }
65     }
66   }
67 }
68 node {
69   name: "output"
70   op: "Add"
71   input: "input"
72   input: "Const"
73   attr {
74     key: "T"
75     value {
76       type: DT_FLOAT
77     }
78   }
79 }
80             )";
81         SetupSingleInputSingleOutput({ 1 }, "input", "output");
82     }
83 };
84
85 BOOST_FIXTURE_TEST_CASE(Constant, ConstantFixture)
86 {
87     RunTest<1>({1}, {18});
88 }
89
90
91 // Tests that a single Const node in Tensorflow can be used twice by a dependant node. This should result in only
92 // a single armnn ConstLayer being created.
93 struct ConstantReusedFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
94 {
95     ConstantReusedFixture()
96     {
97         // Const = tf.constant([17], tf.float32, [1])
98         // Output = tf.add(const, const, name = "output")
99         m_Prototext =
100             R"(
101 node {
102   name: "Const"
103   op: "Const"
104   attr {
105     key: "dtype"
106     value {
107       type: DT_FLOAT
108     }
109   }
110   attr {
111     key: "value"
112     value {
113       tensor {
114         dtype: DT_FLOAT
115         tensor_shape {
116           dim {
117             size: 1
118           }
119         }
120         float_val: 17.0
121       }
122     }
123   }
124 }
125 node {
126   name: "output"
127   op: "Add"
128   input: "Const"
129   input: "Const"
130   attr {
131     key: "T"
132     value {
133       type: DT_FLOAT
134     }
135   }
136 }
137             )";
138         Setup({}, { "output" });
139     }
140 };
141
142 BOOST_FIXTURE_TEST_CASE(ConstantReused, ConstantReusedFixture)
143 {
144     RunTest<1>({}, { { "output", { 34 } } });
145 }
146
147 template <int ListSize>
148 struct ConstantValueListFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
149 {
150     ConstantValueListFixture()
151     {
152         m_Prototext =
153             R"(
154 node {
155   name: "output"
156   op: "Const"
157   attr {
158     key: "dtype"
159     value {
160       type: DT_FLOAT
161     }
162   }
163   attr {
164     key: "value"
165     value {
166       tensor {
167         dtype: DT_FLOAT
168         tensor_shape {
169           dim {
170             size: 2
171           }
172           dim {
173             size: 3
174           }
175         })";
176
177         double value = 0.75;
178         for (int i = 0; i < ListSize; i++, value += 0.25)
179         {
180             m_Prototext += std::string("float_val : ") + std::to_string(value) + "\n";
181         }
182
183         m_Prototext +=
184             R"(
185       }
186     }
187   }
188 }
189             )";
190         Setup({}, { "output" });
191     }
192 };
193
194 using ConstantSingleValueListFixture = ConstantValueListFixture<1>;
195 using ConstantMultipleValueListFixture = ConstantValueListFixture<4>;
196 using ConstantMaxValueListFixture = ConstantValueListFixture<6>;
197
198 BOOST_FIXTURE_TEST_CASE(ConstantSingleValueList, ConstantSingleValueListFixture)
199 {
200     RunTest<2>({}, { { "output", { 0.75f, 0.75f, 0.75f, 0.75f, 0.75f, 0.75f } } });
201 }
202 BOOST_FIXTURE_TEST_CASE(ConstantMultipleValueList, ConstantMultipleValueListFixture)
203 {
204     RunTest<2>({}, { { "output", { 0.75f, 1.f, 1.25f, 1.5f,  1.5f,  1.5f } } });
205 }
206 BOOST_FIXTURE_TEST_CASE(ConstantMaxValueList, ConstantMaxValueListFixture)
207 {
208     RunTest<2>({}, { { "output", { 0.75f, 1.f, 1.25f, 1.50f, 1.75f, 2.f } } });
209 }
210
211 template <bool WithShape, bool WithContent, bool WithValueList>
212 struct ConstantCreateFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
213 {
214     ConstantCreateFixture()
215     {
216         m_Prototext =
217             R"(
218 node {
219     name: "output"
220     op: "Const"
221     attr {
222     key: "dtype"
223     value {
224         type: DT_FLOAT
225     }
226     }
227     attr {
228     key: "value"
229     value {
230         tensor {
231         dtype: DT_FLOAT
232             )";
233
234         if (WithShape)
235         {
236             m_Prototext +=
237                 R"(
238 tensor_shape {
239     dim {
240     size: 2
241     }
242     dim {
243     size: 2
244     }
245 }
246                 )";
247         }
248         else
249         {
250             m_Prototext +=
251                 R"(
252 tensor_shape {
253 }
254                 )";
255         }
256
257         if (WithContent)
258         {
259             m_Prototext +=
260                 R"(
261 tensor_content: "\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?"
262                 )";
263         }
264
265         if (WithValueList)
266         {
267             m_Prototext +=
268                 R"(
269 float_val: 1.0
270 float_val: 1.0
271 float_val: 1.0
272 float_val: 1.0
273 float_val: 1.0
274                 )";
275         }
276
277         m_Prototext +=
278             R"(
279             }
280         }
281     }
282 }
283             )";
284     }
285 };
286
287 using ConstantCreateNoValueListFixture = ConstantCreateFixture<true, false, true>;
288 using ConstantCreateNoValueList2Fixture = ConstantCreateFixture<true, false, false>;
289 using ConstantCreateNoContentFixture = ConstantCreateFixture<true, true, false>;
290 using ConstantCreateNoContent2Fixture = ConstantCreateFixture<true, false, false>;
291 using ConstantCreateNoShapeFixture = ConstantCreateFixture<false, false, false>;
292 using ConstantCreateNoShape2Fixture = ConstantCreateFixture<false, true, false>;
293 using ConstantCreateNoShape3Fixture = ConstantCreateFixture<false, false, true>;
294
295 BOOST_FIXTURE_TEST_CASE(ConstantCreateInvalidValueList, ConstantCreateNoValueListFixture)
296 {
297     BOOST_REQUIRE_THROW(Setup({}, { "output" }), armnn::ParseException);
298 }
299 BOOST_FIXTURE_TEST_CASE(ConstantCreateInvalidValueList2, ConstantCreateNoValueList2Fixture)
300 {
301     BOOST_REQUIRE_THROW(Setup({}, { "output" }), armnn::ParseException);
302 }
303 BOOST_FIXTURE_TEST_CASE(ConstantCreateInvalidContent, ConstantCreateNoContentFixture)
304 {
305     BOOST_REQUIRE_THROW(Setup({}, { "output" }), armnn::ParseException);
306 }
307 BOOST_FIXTURE_TEST_CASE(ConstantCreateInvalidShape, ConstantCreateNoShapeFixture)
308 {
309     BOOST_REQUIRE_THROW(Setup({}, { "output" }), armnn::ParseException);
310 }
311 BOOST_FIXTURE_TEST_CASE(ConstantCreateNoShape2, ConstantCreateNoShape2Fixture)
312 {
313     BOOST_REQUIRE_THROW(Setup({}, { "output" }), armnn::ParseException);
314 }
315 BOOST_FIXTURE_TEST_CASE(ConstantCreateNoShape3, ConstantCreateNoShape3Fixture)
316 {
317     Setup({}, { "output" });
318     RunTest<1>({}, { { "output", { 1.f, 1.f, 1.f, 1.f, 1.f } } });
319 }
320
321 BOOST_AUTO_TEST_SUITE_END()