2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
6 #include <boost/test/unit_test.hpp>
7 #include "ParserFlatbuffersFixture.hpp"
8 #include "../TfLiteParser.hpp"
// Register the tests below under the TensorflowLiteParser Boost.Test suite.
13 BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
// Base fixture: appends a TfLite model description as JSON text to m_JsonString.
// The model's operator code is SQUEEZE (configured through SqueezeOptions), and
// it includes tensors named "inputTensor" and "outputTensor".
//   inputShape  - JSON array text spliced in as the input tensor's "shape".
//   outputShape - JSON array text spliced in as the output tensor's "shape".
//   squeezeDims - JSON array text for "squeeze_dims"; pass "" to omit the option.
15 struct SqueezeFixture : public ParserFlatbuffersFixture
17 explicit SqueezeFixture(const std::string& inputShape,
18 const std::string& outputShape,
19 const std::string& squeezeDims)
24 "operator_codes": [ { "builtin_code": "SQUEEZE" } ],
29 "shape" : )" + inputShape + ",";
33 "name": "inputTensor",
43 "shape" : )" + outputShape;
47 "name": "outputTensor",
63 "builtin_options_type": "SqueezeOptions",
64 "builtin_options": {)";
// squeeze_dims is written only when the argument is non-empty. "[ ]" counts as
// non-empty, so it is emitted as an explicit empty list rather than omitted.
65 if (!squeezeDims.empty())
67 m_JsonString += R"("squeeze_dims" : )" + squeezeDims;
70 "custom_options_format": "FLEXBUFFERS"
74 "buffers" : [ {}, {} ]
// Input shape [1, 2, 2, 1], declared output shape [2, 2, 1], squeeze_dims [0, 1, 2].
80 struct SqueezeFixtureWithSqueezeDims : SqueezeFixture
82 SqueezeFixtureWithSqueezeDims() : SqueezeFixture("[ 1, 2, 2, 1 ]", "[ 2, 2, 1 ]", "[ 0, 1, 2 ]") {}
// Binds inputTensor/outputTensor, runs four QuantisedAsymm8 values through the
// network and expects them back unchanged. Then checks that the output binding
// reports shape {2, 2, 1}.
// NOTE(review): RunTest's first template argument (3) presumably is the output
// rank — confirm against ParserFlatbuffersFixture.
85 BOOST_FIXTURE_TEST_CASE(ParseSqueezeWithSqueezeDims, SqueezeFixtureWithSqueezeDims)
87 SetupSingleInputSingleOutput("inputTensor", "outputTensor");
88 RunTest<3, armnn::DataType::QuantisedAsymm8>(0, { 1, 2, 3, 4 }, { 1, 2, 3, 4 });
89 BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
90 == armnn::TensorShape({2,2,1})));
// Input shape [1, 2, 2, 1] with no squeeze_dims option (empty string, so the
// option is omitted). The declared output shape is [2, 2].
94 struct SqueezeFixtureWithoutSqueezeDims : SqueezeFixture
96 SqueezeFixtureWithoutSqueezeDims() : SqueezeFixture("[ 1, 2, 2, 1 ]", "[ 2, 2 ]", "") {}
// With squeeze_dims omitted, expects the four QuantisedAsymm8 values to pass
// through unchanged and the output binding to report shape {2, 2}.
99 BOOST_FIXTURE_TEST_CASE(ParseSqueezeWithoutSqueezeDims, SqueezeFixtureWithoutSqueezeDims)
101 SetupSingleInputSingleOutput("inputTensor", "outputTensor");
102 RunTest<2, armnn::DataType::QuantisedAsymm8>(0, { 1, 2, 3, 4 }, { 1, 2, 3, 4 });
103 BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
104 == armnn::TensorShape({2,2})));
// Rank-5 input shape [1, 2, 2, 1, 2] and declared output [1, 2, 2, 1]. The
// squeeze_dims list is explicitly empty ("[ ]").
107 struct SqueezeFixtureWithInvalidInput : SqueezeFixture
109 SqueezeFixtureWithInvalidInput() : SqueezeFixture("[ 1, 2, 2, 1, 2 ]", "[ 1, 2, 2, 1 ]", "[ ]") {}
// Setting up the network from this model is expected to throw
// armnn::InvalidArgumentException.
// NOTE(review): presumably because the 5-D input is rejected — confirm.
112 BOOST_FIXTURE_TEST_CASE(ParseSqueezeInvalidInput, SqueezeFixtureWithInvalidInput)
114 BOOST_CHECK_THROW((SetupSingleInputSingleOutput("inputTensor", "outputTensor")),
115 armnn::InvalidArgumentException);
// Rank-4 input shape [1, 2, 2, 1]. The final constructor argument, the
// squeezeDims parameter, is [1, 2, 2, 2, 2]: five entries for a four-dimensional
// input.
118 struct SqueezeFixtureWithSqueezeDimsSizeInvalid : SqueezeFixture
120 SqueezeFixtureWithSqueezeDimsSizeInvalid() : SqueezeFixture("[ 1, 2, 2, 1 ]",
122 "[ 1, 2, 2, 2, 2 ]") {}
// Expects setup to throw armnn::ParseException for the oversized squeeze_dims list.
125 BOOST_FIXTURE_TEST_CASE(ParseSqueezeInvalidSqueezeDims, SqueezeFixtureWithSqueezeDimsSizeInvalid)
127 BOOST_CHECK_THROW((SetupSingleInputSingleOutput("inputTensor", "outputTensor")), armnn::ParseException);
// Rank-4 input shape [1, 2, 2, 1]. Judging by the name, squeeze_dims contains
// negative indices — TODO confirm against the remaining constructor arguments.
131 struct SqueezeFixtureWithNegativeSqueezeDims : SqueezeFixture
133 SqueezeFixtureWithNegativeSqueezeDims() : SqueezeFixture("[ 1, 2, 2, 1 ]",
// Expects setup to throw armnn::ParseException for the negative-squeeze_dims fixture.
138 BOOST_FIXTURE_TEST_CASE(ParseSqueezeNegativeSqueezeDims, SqueezeFixtureWithNegativeSqueezeDims)
140 BOOST_CHECK_THROW((SetupSingleInputSingleOutput("inputTensor", "outputTensor")), armnn::ParseException);
// Closes the TensorflowLiteParser test suite opened near the top of the file.
144 BOOST_AUTO_TEST_SUITE_END()