IVGCVSW-2467 Remove GetDataType<T> function
[platform/upstream/armnn.git] / src / armnnTfLiteParser / test / Reshape.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "ParserFlatbuffersFixture.hpp"
8 #include "../TfLiteParser.hpp"
9
10 #include <string>
11 #include <iostream>
12
13 BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
14
15 struct ReshapeFixture : public ParserFlatbuffersFixture
16 {
17     explicit ReshapeFixture(const std::string& inputShape,
18                             const std::string& outputShape,
19                             const std::string& newShape)
20     {
21         m_JsonString = R"(
22             {
23                 "version": 3,
24                 "operator_codes": [ { "builtin_code": "RESHAPE" } ],
25                 "subgraphs": [ {
26                     "tensors": [
27                         {)";
28         m_JsonString += R"(
29                             "shape" : )" + inputShape + ",";
30         m_JsonString += R"(
31                             "type": "UINT8",
32                             "buffer": 0,
33                             "name": "inputTensor",
34                             "quantization": {
35                                 "min": [ 0.0 ],
36                                 "max": [ 255.0 ],
37                                 "scale": [ 1.0 ],
38                                 "zero_point": [ 0 ],
39                             }
40                         },
41                         {)";
42         m_JsonString += R"(
43                             "shape" : )" + outputShape;
44         m_JsonString += R"(,
45                             "type": "UINT8",
46                             "buffer": 1,
47                             "name": "outputTensor",
48                             "quantization": {
49                                 "min": [ 0.0 ],
50                                 "max": [ 255.0 ],
51                                 "scale": [ 1.0 ],
52                                 "zero_point": [ 0 ],
53                             }
54                         }
55                     ],
56                     "inputs": [ 0 ],
57                     "outputs": [ 1 ],
58                     "operators": [
59                         {
60                             "opcode_index": 0,
61                             "inputs": [ 0 ],
62                             "outputs": [ 1 ],
63                             "builtin_options_type": "ReshapeOptions",
64                             "builtin_options": {)";
65         if (!newShape.empty())
66         {
67             m_JsonString += R"("new_shape" : )" + newShape;
68         }
69         m_JsonString += R"(},
70                             "custom_options_format": "FLEXBUFFERS"
71                         }
72                     ],
73                 } ],
74                 "buffers" : [ {}, {} ]
75             }
76         )";
77
78     }
79 };
80
81 struct ReshapeFixtureWithReshapeDims : ReshapeFixture
82 {
83     ReshapeFixtureWithReshapeDims() : ReshapeFixture("[ 1, 9 ]", "[ 3, 3 ]", "[ 3, 3 ]") {}
84 };
85
86 BOOST_FIXTURE_TEST_CASE(ParseReshapeWithReshapeDims, ReshapeFixtureWithReshapeDims)
87 {
88     SetupSingleInputSingleOutput("inputTensor", "outputTensor");
89     RunTest<2, armnn::DataType::QuantisedAsymm8>(0,
90                                                  { 1, 2, 3, 4, 5, 6, 7, 8, 9 },
91                                                  { 1, 2, 3, 4, 5, 6, 7, 8, 9 });
92     BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
93                 == armnn::TensorShape({3,3})));
94 }
95
96 struct ReshapeFixtureWithReshapeDimsFlatten : ReshapeFixture
97 {
98     ReshapeFixtureWithReshapeDimsFlatten() : ReshapeFixture("[ 3, 3 ]", "[ 1, 9 ]", "[ -1 ]") {}
99 };
100
101 BOOST_FIXTURE_TEST_CASE(ParseReshapeWithReshapeDimsFlatten, ReshapeFixtureWithReshapeDimsFlatten)
102 {
103     SetupSingleInputSingleOutput("inputTensor", "outputTensor");
104     RunTest<2, armnn::DataType::QuantisedAsymm8>(0,
105                                                  { 1, 2, 3, 4, 5, 6, 7, 8, 9 },
106                                                  { 1, 2, 3, 4, 5, 6, 7, 8, 9 });
107     BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
108                 == armnn::TensorShape({1,9})));
109 }
110
111 struct ReshapeFixtureWithReshapeDimsFlattenTwoDims : ReshapeFixture
112 {
113     ReshapeFixtureWithReshapeDimsFlattenTwoDims() : ReshapeFixture("[ 3, 2, 3 ]", "[ 2, 9 ]", "[ 2, -1 ]") {}
114 };
115
116 BOOST_FIXTURE_TEST_CASE(ParseReshapeWithReshapeDimsFlattenTwoDims, ReshapeFixtureWithReshapeDimsFlattenTwoDims)
117 {
118     SetupSingleInputSingleOutput("inputTensor", "outputTensor");
119     RunTest<2, armnn::DataType::QuantisedAsymm8>(0,
120                                                  { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 },
121                                                  { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 });
122     BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
123                 == armnn::TensorShape({2,9})));
124 }
125
126 struct ReshapeFixtureWithReshapeDimsFlattenOneDim : ReshapeFixture
127 {
128     ReshapeFixtureWithReshapeDimsFlattenOneDim() : ReshapeFixture("[ 2, 9 ]", "[ 2, 3, 3 ]", "[ 2, -1, 3 ]") {}
129 };
130
131 BOOST_FIXTURE_TEST_CASE(ParseReshapeWithReshapeDimsFlattenOneDim, ReshapeFixtureWithReshapeDimsFlattenOneDim)
132 {
133     SetupSingleInputSingleOutput("inputTensor", "outputTensor");
134     RunTest<3, armnn::DataType::QuantisedAsymm8>(0,
135                                                  { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 },
136                                                  { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 });
137     BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
138                 == armnn::TensorShape({2,3,3})));
139 }
140
141 BOOST_AUTO_TEST_SUITE_END()