2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
6 #include <boost/test/unit_test.hpp>
8 #include "armnnTfParser/ITfParser.hpp"
10 #include "ParserPrototxtFixture.hpp"
12 BOOST_AUTO_TEST_SUITE(TensorflowParser)
14 // Tests that a Const node in Tensorflow can be converted to a ConstLayer in armnn (as opposed to most
15 // Const nodes which are used as weight inputs for convolutions etc. and are therefore not converted to
16 // armnn ConstLayers).
17 struct ConstantFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
// Graph under test, written as TensorFlow Python: output = input + 17, where 17 lives in a Const node.
21 // Input = tf.placeholder(tf.float32, name = "input")
22 // Const = tf.constant([17], tf.float32, [1])
23 // Output = tf.add(input, const, name = "output")
81 SetupSingleInputSingleOutput({ 1 }, "input", "output"); // binds the "input"/"output" nodes; { 1 } is presumably the input shape - TODO confirm
85 BOOST_FIXTURE_TEST_CASE(Constant, ConstantFixture)
87 RunTest<1>({1}, {18});
91 // Tests that a single Const node in Tensorflow can be used twice by a dependent node. This should result in only
92 // a single armnn ConstLayer being created.
93 struct ConstantReusedFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
95 ConstantReusedFixture()
// Graph under test, written as TensorFlow Python: both operands of the add are the same Const node (17).
97 // Const = tf.constant([17], tf.float32, [1])
98 // Output = tf.add(const, const, name = "output")
138 Setup({}, { "output" }); // the graph has no placeholder, so the first argument is empty (presumably the input shapes); only "output" is requested
142 BOOST_FIXTURE_TEST_CASE(ConstantReused, ConstantReusedFixture)
144 RunTest<1>({}, { { "output", { 34 } } });
// Fixture whose Const node lists ListSize float_val entries, each 0.25 larger than the previous one
// (the ConstantValueList tests below expect the first entry to be 0.75).
147 template <int ListSize>
148 struct ConstantValueListFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
150 ConstantValueListFixture()
178 for (int i = 0; i < ListSize; i++, value += 0.25)
180 m_Prototext += std::string("float_val : ") + std::to_string(value) + "\n"; // NOTE(review): std::to_string formats like "%f", which honours the C locale's decimal separator; a non-"C" LC_NUMERIC could emit "0,750000" and break prototext parsing - confirm the tests always run in the "C" locale
190 Setup({}, { "output" });
194 using ConstantSingleValueListFixture = ConstantValueListFixture<1>;
195 using ConstantMultipleValueListFixture = ConstantValueListFixture<4>;
196 using ConstantMaxValueListFixture = ConstantValueListFixture<6>;
198 BOOST_FIXTURE_TEST_CASE(ConstantSingleValueList, ConstantSingleValueListFixture)
200 RunTest<2>({}, { { "output", { 0.75f, 0.75f, 0.75f, 0.75f, 0.75f, 0.75f } } });
202 BOOST_FIXTURE_TEST_CASE(ConstantMultipleValueList, ConstantMultipleValueListFixture)
204 RunTest<2>({}, { { "output", { 0.75f, 1.f, 1.25f, 1.5f, 1.5f, 1.5f } } });
206 BOOST_FIXTURE_TEST_CASE(ConstantMaxValueList, ConstantMaxValueListFixture)
208 RunTest<2>({}, { { "output", { 0.75f, 1.f, 1.25f, 1.50f, 1.75f, 2.f } } });
// Fixture for a Const node whose tensor may or may not carry a shape, raw tensor_content and a float_val list,
// as selected by the three template flags (presumably each flag toggles the matching prototext section;
// TODO confirm against the constructor body). The tensor_content string is five copies of the bytes
// 00 00 80 3F, i.e. five little-endian 1.0f values.
211 template <bool WithShape, bool WithContent, bool WithValueList>
212 struct ConstantCreateFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
214 ConstantCreateFixture()
261 tensor_content: "\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?"
287 using ConstantCreateNoValueListFixture = ConstantCreateFixture<true, false, true>;
288 using ConstantCreateNoValueList2Fixture = ConstantCreateFixture<true, false, false>;
289 using ConstantCreateNoContentFixture = ConstantCreateFixture<true, true, false>;
290 using ConstantCreateNoContent2Fixture = ConstantCreateFixture<true, false, false>;
291 using ConstantCreateNoShapeFixture = ConstantCreateFixture<false, false, false>;
292 using ConstantCreateNoShape2Fixture = ConstantCreateFixture<false, true, false>;
293 using ConstantCreateNoShape3Fixture = ConstantCreateFixture<false, false, true>;
295 BOOST_FIXTURE_TEST_CASE(ConstantCreateInvalidValueList, ConstantCreateNoValueListFixture)
297 BOOST_REQUIRE_THROW(Setup({}, { "output" }), armnn::ParseException);
299 BOOST_FIXTURE_TEST_CASE(ConstantCreateInvalidValueList2, ConstantCreateNoValueList2Fixture)
301 BOOST_REQUIRE_THROW(Setup({}, { "output" }), armnn::ParseException);
303 BOOST_FIXTURE_TEST_CASE(ConstantCreateInvalidContent, ConstantCreateNoContentFixture)
305 BOOST_REQUIRE_THROW(Setup({}, { "output" }), armnn::ParseException);
307 BOOST_FIXTURE_TEST_CASE(ConstantCreateInvalidShape, ConstantCreateNoShapeFixture)
309 BOOST_REQUIRE_THROW(Setup({}, { "output" }), armnn::ParseException);
311 BOOST_FIXTURE_TEST_CASE(ConstantCreateNoShape2, ConstantCreateNoShape2Fixture)
313 BOOST_REQUIRE_THROW(Setup({}, { "output" }), armnn::ParseException);
315 BOOST_FIXTURE_TEST_CASE(ConstantCreateNoShape3, ConstantCreateNoShape3Fixture)
317 Setup({}, { "output" });
318 RunTest<1>({}, { { "output", { 1.f, 1.f, 1.f, 1.f, 1.f } } });
321 BOOST_AUTO_TEST_SUITE_END()