Release 18.08
[platform/upstream/armnn.git] / src / armnnTfParser / test / Squeeze.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "armnnTfParser/ITfParser.hpp"
8 #include "ParserPrototxtFixture.hpp"
9
10 BOOST_AUTO_TEST_SUITE(TensorflowParser)
11
12 template <bool withDimZero, bool withDimOne>
13 struct SqueezeFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
14 {
15     SqueezeFixture()
16     {
17         m_Prototext =
18                 "node { \n"
19                 "    name: \"graphInput\" \n"
20                 "    op: \"Placeholder\" \n"
21                 "    attr { \n"
22                 "      key: \"dtype\" \n"
23                 "      value { \n"
24                 "        type: DT_FLOAT \n"
25                 "      } \n"
26                 "    } \n"
27                 "    attr { \n"
28                 "      key: \"shape\" \n"
29                 "      value { \n"
30                 "        shape { \n"
31                 "        } \n"
32                 "      } \n"
33                 "    } \n"
34                 "  } \n"
35                 "node { \n"
36                 "  name: \"Squeeze\" \n"
37                 "  op: \"Squeeze\" \n"
38                 "  input: \"graphInput\" \n"
39                 "  attr { \n"
40                 "    key: \"T\" \n"
41                 "    value { \n"
42                 "      type: DT_FLOAT \n"
43                 "    } \n"
44                 "  } \n"
45                 "  attr { \n"
46                 "    key: \"squeeze_dims\" \n"
47                 "    value { \n"
48                 "      list {\n";
49
50         if (withDimZero)
51         {
52             m_Prototext += "i:0\n";
53         }
54
55         if (withDimOne)
56         {
57             m_Prototext += "i:1\n";
58         }
59
60         m_Prototext +=
61                 "      } \n"
62                 "    } \n"
63                 "  } \n"
64                 "} \n";
65
66         SetupSingleInputSingleOutput({ 1, 1, 2, 2 }, "graphInput", "Squeeze");
67     }
68 };
69
70 typedef SqueezeFixture<false, false> ImpliedDimensionsSqueezeFixture;
71 typedef SqueezeFixture<true, false>  ExplicitDimensionZeroSqueezeFixture;
72 typedef SqueezeFixture<false, true>  ExplicitDimensionOneSqueezeFixture;
73 typedef SqueezeFixture<true, true>   ExplicitDimensionsSqueezeFixture;
74
75 BOOST_FIXTURE_TEST_CASE(ParseImplicitSqueeze, ImpliedDimensionsSqueezeFixture)
76 {
77     BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("Squeeze").second.GetShape() ==
78                armnn::TensorShape({2,2})));
79     RunTest<2>({ 1.0f, 2.0f, 3.0f, 4.0f },
80                { 1.0f, 2.0f, 3.0f, 4.0f });
81 }
82
83 BOOST_FIXTURE_TEST_CASE(ParseDimensionZeroSqueeze, ExplicitDimensionZeroSqueezeFixture)
84 {
85     BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("Squeeze").second.GetShape() ==
86                armnn::TensorShape({1,2,2})));
87     RunTest<3>({ 1.0f, 2.0f, 3.0f, 4.0f },
88                { 1.0f, 2.0f, 3.0f, 4.0f });
89 }
90
91 BOOST_FIXTURE_TEST_CASE(ParseDimensionOneSqueeze, ExplicitDimensionOneSqueezeFixture)
92 {
93     BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("Squeeze").second.GetShape() ==
94                armnn::TensorShape({1,2,2})));
95     RunTest<3>({ 1.0f, 2.0f, 3.0f, 4.0f },
96                { 1.0f, 2.0f, 3.0f, 4.0f });
97 }
98
99 BOOST_FIXTURE_TEST_CASE(ParseExplicitDimensionsSqueeze, ExplicitDimensionsSqueezeFixture)
100 {
101     BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("Squeeze").second.GetShape() ==
102                armnn::TensorShape({2,2})));
103     RunTest<2>({ 1.0f, 2.0f, 3.0f, 4.0f },
104                { 1.0f, 2.0f, 3.0f, 4.0f });
105 }
106
107 BOOST_AUTO_TEST_SUITE_END()