2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
6 #include <boost/test/unit_test.hpp>
7 #include "armnnTfParser/ITfParser.hpp"
8 #include "ParserPrototxtFixture.hpp"
10 BOOST_AUTO_TEST_SUITE(TensorflowParser)
12 template <bool withDimZero, bool withDimOne>
13 struct SqueezeFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
19 " name: \"graphInput\" \n"
20 " op: \"Placeholder\" \n"
36 " name: \"Squeeze\" \n"
38 " input: \"graphInput\" \n"
46 " key: \"squeeze_dims\" \n"
52 m_Prototext += "i:0\n";
57 m_Prototext += "i:1\n";
66 SetupSingleInputSingleOutput({ 1, 1, 2, 2 }, "graphInput", "Squeeze");
70 typedef SqueezeFixture<false, false> ImpliedDimensionsSqueezeFixture;
71 typedef SqueezeFixture<true, false> ExplicitDimensionZeroSqueezeFixture;
72 typedef SqueezeFixture<false, true> ExplicitDimensionOneSqueezeFixture;
73 typedef SqueezeFixture<true, true> ExplicitDimensionsSqueezeFixture;
// ParseImplicitSqueeze: the fixture lists neither dimension 0 nor dimension 1
// explicitly (SqueezeFixture<false, false>). The fixture's input is declared as
// { 1, 1, 2, 2 }, and the parser is expected to drop both leading size-1
// dimensions, so the "Squeeze" output binding should report shape { 2, 2 }.
75 BOOST_FIXTURE_TEST_CASE(ParseImplicitSqueeze, ImpliedDimensionsSqueezeFixture)
77 BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("Squeeze").second.GetShape() ==
78 armnn::TensorShape({2,2})));
// Squeeze only reshapes, so the four input values must come out unchanged.
// NOTE(review): the template argument 2 presumably matches the output rank -- confirm.
79 RunTest<2>({ 1.0f, 2.0f, 3.0f, 4.0f },
80 { 1.0f, 2.0f, 3.0f, 4.0f });
// ParseDimensionZeroSqueeze: only dimension 0 is requested
// (SqueezeFixture<true, false>). From the { 1, 1, 2, 2 } input, just the first
// size-1 dimension is removed, so the output binding should be { 1, 2, 2 }.
83 BOOST_FIXTURE_TEST_CASE(ParseDimensionZeroSqueeze, ExplicitDimensionZeroSqueezeFixture)
85 BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("Squeeze").second.GetShape() ==
86 armnn::TensorShape({1,2,2})));
// Data is passed through unchanged; the output here is rank 3.
87 RunTest<3>({ 1.0f, 2.0f, 3.0f, 4.0f },
88 { 1.0f, 2.0f, 3.0f, 4.0f });
// ParseDimensionOneSqueeze: only dimension 1 is requested
// (SqueezeFixture<false, true>). Dimensions 0 and 1 of the { 1, 1, 2, 2 } input
// are both size 1, so removing either one gives the same { 1, 2, 2 } shape as
// the dimension-0 case.
91 BOOST_FIXTURE_TEST_CASE(ParseDimensionOneSqueeze, ExplicitDimensionOneSqueezeFixture)
93 BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("Squeeze").second.GetShape() ==
94 armnn::TensorShape({1,2,2})));
// Data is passed through unchanged; the output here is rank 3.
95 RunTest<3>({ 1.0f, 2.0f, 3.0f, 4.0f },
96 { 1.0f, 2.0f, 3.0f, 4.0f });
// ParseExplicitDimensionsSqueeze: dimensions 0 and 1 are both requested
// (SqueezeFixture<true, true>). Both leading size-1 dimensions are removed,
// which yields the same { 2, 2 } output as the implicit case.
99 BOOST_FIXTURE_TEST_CASE(ParseExplicitDimensionsSqueeze, ExplicitDimensionsSqueezeFixture)
101 BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("Squeeze").second.GetShape() ==
102 armnn::TensorShape({2,2})));
// Data is passed through unchanged; the output here is rank 2.
103 RunTest<2>({ 1.0f, 2.0f, 3.0f, 4.0f },
104 { 1.0f, 2.0f, 3.0f, 4.0f });
107 BOOST_AUTO_TEST_SUITE_END()