Release 18.08
[platform/upstream/armnn.git] / src / armnnTfLiteParser / test / Squeeze.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "ParserFlatbuffersFixture.hpp"
8 #include "../TfLiteParser.hpp"
9
10 #include <string>
11 #include <iostream>
12
// Every Squeeze parser test case below is registered under the TensorflowLiteParser suite.
BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
14
15 struct SqueezeFixture : public ParserFlatbuffersFixture
16 {
17     explicit SqueezeFixture(const std::string& inputShape,
18                             const std::string& outputShape,
19                             const std::string& squeezeDims)
20     {
21         m_JsonString = R"(
22             {
23                 "version": 3,
24                 "operator_codes": [ { "builtin_code": "SQUEEZE" } ],
25                 "subgraphs": [ {
26                     "tensors": [
27                         {)";
28         m_JsonString += R"(
29                             "shape" : )" + inputShape + ",";
30         m_JsonString += R"(
31                             "type": "UINT8",
32                             "buffer": 0,
33                             "name": "inputTensor",
34                             "quantization": {
35                                 "min": [ 0.0 ],
36                                 "max": [ 255.0 ],
37                                 "scale": [ 1.0 ],
38                                 "zero_point": [ 0 ],
39                             }
40                         },
41                         {)";
42         m_JsonString += R"(
43                             "shape" : )" + outputShape;
44         m_JsonString += R"(,
45                             "type": "UINT8",
46                             "buffer": 1,
47                             "name": "outputTensor",
48                             "quantization": {
49                                 "min": [ 0.0 ],
50                                 "max": [ 255.0 ],
51                                 "scale": [ 1.0 ],
52                                 "zero_point": [ 0 ],
53                             }
54                         }
55                     ],
56                     "inputs": [ 0 ],
57                     "outputs": [ 1 ],
58                     "operators": [
59                         {
60                             "opcode_index": 0,
61                             "inputs": [ 0 ],
62                             "outputs": [ 1 ],
63                             "builtin_options_type": "SqueezeOptions",
64                             "builtin_options": {)";
65         if (!squeezeDims.empty())
66         {
67             m_JsonString += R"("squeeze_dims" : )" + squeezeDims;
68         }
69         m_JsonString += R"(},
70                             "custom_options_format": "FLEXBUFFERS"
71                         }
72                     ],
73                 } ],
74                 "buffers" : [ {}, {} ]
75             }
76         )";
77     }
78 };
79
80 struct SqueezeFixtureWithSqueezeDims : SqueezeFixture
81 {
82     SqueezeFixtureWithSqueezeDims() : SqueezeFixture("[ 1, 2, 2, 1 ]", "[ 2, 2, 1 ]", "[ 0, 1, 2 ]") {}
83 };
84
85 BOOST_FIXTURE_TEST_CASE(ParseSqueezeWithSqueezeDims, SqueezeFixtureWithSqueezeDims)
86 {
87     SetupSingleInputSingleOutput("inputTensor", "outputTensor");
88     RunTest<3, uint8_t>(0, { 1, 2, 3, 4 }, { 1, 2, 3, 4 });
89     BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
90         == armnn::TensorShape({2,2,1})));
91
92 }
93
94 struct SqueezeFixtureWithoutSqueezeDims : SqueezeFixture
95 {
96     SqueezeFixtureWithoutSqueezeDims() : SqueezeFixture("[ 1, 2, 2, 1 ]", "[ 2, 2 ]", "") {}
97 };
98
99 BOOST_FIXTURE_TEST_CASE(ParseSqueezeWithoutSqueezeDims, SqueezeFixtureWithoutSqueezeDims)
100 {
101     SetupSingleInputSingleOutput("inputTensor", "outputTensor");
102     RunTest<2, uint8_t>(0, { 1, 2, 3, 4 }, { 1, 2, 3, 4 });
103     BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
104         == armnn::TensorShape({2,2})));
105 }
106
107 struct SqueezeFixtureWithInvalidInput : SqueezeFixture
108 {
109     SqueezeFixtureWithInvalidInput() : SqueezeFixture("[ 1, 2, 2, 1, 2 ]", "[ 1, 2, 2, 1 ]", "[ ]") {}
110 };
111
112 BOOST_FIXTURE_TEST_CASE(ParseSqueezeInvalidInput, SqueezeFixtureWithInvalidInput)
113 {
114     BOOST_CHECK_THROW((SetupSingleInputSingleOutput("inputTensor", "outputTensor")),
115                       armnn::InvalidArgumentException);
116 }
117
118 struct SqueezeFixtureWithSqueezeDimsSizeInvalid : SqueezeFixture
119 {
120     SqueezeFixtureWithSqueezeDimsSizeInvalid() : SqueezeFixture("[ 1, 2, 2, 1 ]",
121                                                                 "[ 1, 2, 2, 1 ]",
122                                                                 "[ 1, 2, 2, 2, 2 ]") {}
123 };
124
125 BOOST_FIXTURE_TEST_CASE(ParseSqueezeInvalidSqueezeDims, SqueezeFixtureWithSqueezeDimsSizeInvalid)
126 {
127     BOOST_CHECK_THROW((SetupSingleInputSingleOutput("inputTensor", "outputTensor")), armnn::ParseException);
128 }
129
130
131 struct SqueezeFixtureWithNegativeSqueezeDims : SqueezeFixture
132 {
133     SqueezeFixtureWithNegativeSqueezeDims() : SqueezeFixture("[ 1, 2, 2, 1 ]",
134                                                              "[ 1, 2, 2, 1 ]",
135                                                              "[ -2 , 2 ]") {}
136 };
137
138 BOOST_FIXTURE_TEST_CASE(ParseSqueezeNegativeSqueezeDims, SqueezeFixtureWithNegativeSqueezeDims)
139 {
140     BOOST_CHECK_THROW((SetupSingleInputSingleOutput("inputTensor", "outputTensor")), armnn::ParseException);
141 }
142
143
// Closes the TensorflowLiteParser test suite.
BOOST_AUTO_TEST_SUITE_END()