Release 18.03
[platform/upstream/armnn.git] / src / armnnTfParser / test / Squeeze.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "armnnTfParser/ITfParser.hpp"
8 #include "ParserPrototxtFixture.hpp"
9
10 BOOST_AUTO_TEST_SUITE(TensorflowParser)
11
12
13 template <bool withDimZero, bool withDimOne>
14 struct SqueezeFixture : public ParserPrototxtFixture<armnnTfParser::ITfParser>
15 {
16     SqueezeFixture()
17     {
18         m_Prototext =
19                 "node { \n"
20                 "    name: \"graphInput\" \n"
21                 "    op: \"Placeholder\" \n"
22                 "    attr { \n"
23                 "      key: \"dtype\" \n"
24                 "      value { \n"
25                 "        type: DT_FLOAT \n"
26                 "      } \n"
27                 "    } \n"
28                 "    attr { \n"
29                 "      key: \"shape\" \n"
30                 "      value { \n"
31                 "        shape { \n"
32                 "        } \n"
33                 "      } \n"
34                 "    } \n"
35                 "  } \n"
36                 "node { \n"
37                 "  name: \"Squeeze\" \n"
38                 "  op: \"Squeeze\" \n"
39                 "  input: \"graphInput\" \n"
40                 "  attr { \n"
41                 "    key: \"T\" \n"
42                 "    value { \n"
43                 "      type: DT_FLOAT \n"
44                 "    } \n"
45                 "  } \n"
46                 "  attr { \n"
47                 "    key: \"squeeze_dims\" \n"
48                 "    value { \n"
49                 "      list {\n";
50
51         if (withDimZero)
52         {
53             m_Prototext += "i:0\n";
54         }
55
56         if (withDimOne)
57         {
58             m_Prototext += "i:1\n";
59         }
60
61         m_Prototext +=
62                 "      } \n"
63                 "    } \n"
64                 "  } \n"
65                 "} \n";
66
67         SetupSingleInputSingleOutput({ 1, 1, 2, 2 }, "graphInput", "Squeeze");
68     }
69 };
70
71 typedef SqueezeFixture<false, false> ImpliedDimensionsSqueezeFixture;
72 typedef SqueezeFixture<true, false>  ExplicitDimensionZeroSqueezeFixture;
73 typedef SqueezeFixture<false, true>  ExplicitDimensionOneSqueezeFixture;
74 typedef SqueezeFixture<true, true>   ExplicitDimensionsSqueezeFixture;
75
76 BOOST_FIXTURE_TEST_CASE(ParseImplicitSqueeze, ImpliedDimensionsSqueezeFixture)
77 {
78     BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("Squeeze").second.GetShape() ==
79                armnn::TensorShape({2,2})));
80     RunTest<2>({ 1.0f, 2.0f, 3.0f, 4.0f },
81                { 1.0f, 2.0f, 3.0f, 4.0f });
82 }
83
84 BOOST_FIXTURE_TEST_CASE(ParseDimensionZeroSqueeze, ExplicitDimensionZeroSqueezeFixture)
85 {
86     BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("Squeeze").second.GetShape() ==
87                armnn::TensorShape({1,2,2})));
88     RunTest<3>({ 1.0f, 2.0f, 3.0f, 4.0f },
89                { 1.0f, 2.0f, 3.0f, 4.0f });
90 }
91
92 BOOST_FIXTURE_TEST_CASE(ParseDimensionOneSqueeze, ExplicitDimensionOneSqueezeFixture)
93 {
94     BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("Squeeze").second.GetShape() ==
95                armnn::TensorShape({1,2,2})));
96     RunTest<3>({ 1.0f, 2.0f, 3.0f, 4.0f },
97                { 1.0f, 2.0f, 3.0f, 4.0f });
98 }
99
100 BOOST_FIXTURE_TEST_CASE(ParseExplicitDimensionsSqueeze, ExplicitDimensionsSqueezeFixture)
101 {
102     BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("Squeeze").second.GetShape() ==
103                armnn::TensorShape({2,2})));
104     RunTest<2>({ 1.0f, 2.0f, 3.0f, 4.0f },
105                { 1.0f, 2.0f, 3.0f, 4.0f });
106 }
107
108 BOOST_AUTO_TEST_SUITE_END()