2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
6 #include <boost/test/unit_test.hpp>
7 #include "armnnTfParser/ITfParser.hpp"
8 #include "ParserPrototxtFixture.hpp"
10 BOOST_AUTO_TEST_SUITE(TensorflowParser)
// Parser test fixture for the TensorFlow "Squeeze" op.
// Template parameters:
//   withDimZero - presumably selects whether dimension 0 is listed in the
//                 node's squeeze_dims attribute (the guarding condition is not
//                 visible in this fragment; verify).
//   withDimOne  - presumably the same for dimension 1.
// The fixture builds a prototext graph in m_Prototext in which a "graphInput"
// Placeholder feeds a "Squeeze" node. It then calls SetupSingleInputSingleOutput
// with { 1, 1, 2, 2 }, presumably the input tensor shape (confirm against
// ParserPrototxtFixture).
// NOTE(review): this fragment is incomplete. The embedded original line numbers
// skip (14 -> 20, 21 -> 37, 47 -> 53, ...), so the constructor, braces, the rest
// of the prototext and the if-guards are missing. Restore them from upstream
// before building.
13 template <bool withDimZero, bool withDimOne>
14 struct SqueezeFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
// Placeholder node that provides the graph's input tensor.
20 " name: \"graphInput\" \n"
21 " op: \"Placeholder\" \n"
// The Squeeze node under test, consuming "graphInput".
37 " name: \"Squeeze\" \n"
39 " input: \"graphInput\" \n"
// squeeze_dims attribute; the dimension indices are appended to the prototext below.
47 " key: \"squeeze_dims\" \n"
// NOTE(review): presumably appended only when withDimZero is true (guard not visible).
53 m_Prototext += "i:0\n";
// NOTE(review): presumably appended only when withDimOne is true (guard not visible).
58 m_Prototext += "i:1\n";
// Presumably parses m_Prototext into a network with input "graphInput" and output "Squeeze".
67 SetupSingleInputSingleOutput({ 1, 1, 2, 2 }, "graphInput", "Squeeze");
71 typedef SqueezeFixture<false, false> ImpliedDimensionsSqueezeFixture;
72 typedef SqueezeFixture<true, false> ExplicitDimensionZeroSqueezeFixture;
73 typedef SqueezeFixture<false, true> ExplicitDimensionOneSqueezeFixture;
74 typedef SqueezeFixture<true, true> ExplicitDimensionsSqueezeFixture;
76 BOOST_FIXTURE_TEST_CASE(ParseImplicitSqueeze, ImpliedDimensionsSqueezeFixture)
78 BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("Squeeze").second.GetShape() ==
79 armnn::TensorShape({2,2})));
80 RunTest<2>({ 1.0f, 2.0f, 3.0f, 4.0f },
81 { 1.0f, 2.0f, 3.0f, 4.0f });
84 BOOST_FIXTURE_TEST_CASE(ParseDimensionZeroSqueeze, ExplicitDimensionZeroSqueezeFixture)
86 BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("Squeeze").second.GetShape() ==
87 armnn::TensorShape({1,2,2})));
88 RunTest<3>({ 1.0f, 2.0f, 3.0f, 4.0f },
89 { 1.0f, 2.0f, 3.0f, 4.0f });
92 BOOST_FIXTURE_TEST_CASE(ParseDimensionOneSqueeze, ExplicitDimensionOneSqueezeFixture)
94 BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("Squeeze").second.GetShape() ==
95 armnn::TensorShape({1,2,2})));
96 RunTest<3>({ 1.0f, 2.0f, 3.0f, 4.0f },
97 { 1.0f, 2.0f, 3.0f, 4.0f });
100 BOOST_FIXTURE_TEST_CASE(ParseExplicitDimensionsSqueeze, ExplicitDimensionsSqueezeFixture)
102 BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("Squeeze").second.GetShape() ==
103 armnn::TensorShape({2,2})));
104 RunTest<2>({ 1.0f, 2.0f, 3.0f, 4.0f },
105 { 1.0f, 2.0f, 3.0f, 4.0f });
108 BOOST_AUTO_TEST_SUITE_END()