IVGCVSW-2467 Remove GetDataType<T> function
[platform/upstream/armnn.git] / src / armnnTfLiteParser / test / FullyConnected.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "ParserFlatbuffersFixture.hpp"
8 #include "../TfLiteParser.hpp"
9
10 #include <string>
11 #include <iostream>
12
13 BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
14
15 struct FullyConnectedFixture : public ParserFlatbuffersFixture
16 {
17     explicit FullyConnectedFixture(const std::string& inputShape,
18                                            const std::string& outputShape,
19                                            const std::string& filterShape,
20                                            const std::string& filterData,
21                                            const std::string biasShape = "",
22                                            const std::string biasData = "")
23     {
24         std::string inputTensors = "[ 0, 2 ]";
25         std::string biasTensor = "";
26         std::string biasBuffer = "";
27         if (biasShape.size() > 0 && biasData.size() > 0)
28         {
29             inputTensors = "[ 0, 2, 3 ]";
30             biasTensor = R"(
31                         {
32                             "shape": )" + biasShape + R"( ,
33                             "type": "INT32",
34                             "buffer": 3,
35                             "name": "biasTensor",
36                             "quantization": {
37                                 "min": [ 0.0 ],
38                                 "max": [ 255.0 ],
39                                 "scale": [ 1.0 ],
40                                 "zero_point": [ 0 ],
41                             }
42                         } )";
43             biasBuffer = R"(
44                     { "data": )" + biasData + R"(, }, )";
45         }
46         m_JsonString = R"(
47             {
48                 "version": 3,
49                 "operator_codes": [ { "builtin_code": "FULLY_CONNECTED" } ],
50                 "subgraphs": [ {
51                     "tensors": [
52                         {
53                             "shape": )" + inputShape + R"(,
54                             "type": "UINT8",
55                             "buffer": 0,
56                             "name": "inputTensor",
57                             "quantization": {
58                                 "min": [ 0.0 ],
59                                 "max": [ 255.0 ],
60                                 "scale": [ 1.0 ],
61                                 "zero_point": [ 0 ],
62                             }
63                         },
64                         {
65                             "shape": )" + outputShape + R"(,
66                             "type": "UINT8",
67                             "buffer": 1,
68                             "name": "outputTensor",
69                             "quantization": {
70                                 "min": [ 0.0 ],
71                                 "max": [ 511.0 ],
72                                 "scale": [ 2.0 ],
73                                 "zero_point": [ 0 ],
74                             }
75                         },
76                         {
77                             "shape": )" + filterShape + R"(,
78                             "type": "UINT8",
79                             "buffer": 2,
80                             "name": "filterTensor",
81                             "quantization": {
82                                 "min": [ 0.0 ],
83                                 "max": [ 255.0 ],
84                                 "scale": [ 1.0 ],
85                                 "zero_point": [ 0 ],
86                             }
87                         }, )" + biasTensor + R"(
88                     ],
89                     "inputs": [ 0 ],
90                     "outputs": [ 1 ],
91                     "operators": [
92                         {
93                             "opcode_index": 0,
94                             "inputs": )" + inputTensors + R"(,
95                             "outputs": [ 1 ],
96                             "builtin_options_type": "FullyConnectedOptions",
97                             "builtin_options": {
98                                 "fused_activation_function": "NONE"
99                             },
100                             "custom_options_format": "FLEXBUFFERS"
101                         }
102                     ],
103                 } ],
104                 "buffers" : [
105                     { },
106                     { },
107                     { "data": )" + filterData + R"(, }, )"
108                        + biasBuffer + R"(
109                 ]
110             }
111         )";
112         SetupSingleInputSingleOutput("inputTensor", "outputTensor");
113     }
114 };
115
116 struct FullyConnectedWithNoBiasFixture : FullyConnectedFixture
117 {
118     FullyConnectedWithNoBiasFixture()
119         : FullyConnectedFixture("[ 1, 4, 1, 1 ]",     // inputShape
120                                 "[ 1, 1 ]",           // outputShape
121                                 "[ 1, 4 ]",           // filterShape
122                                 "[ 2, 3, 4, 5 ]")     // filterData
123     {}
124 };
125
126 BOOST_FIXTURE_TEST_CASE(FullyConnectedWithNoBias, FullyConnectedWithNoBiasFixture)
127 {
128     RunTest<2, armnn::DataType::QuantisedAsymm8>(
129         0,
130         { 10, 20, 30, 40 },
131         { 400/2 });
132 }
133
134 struct FullyConnectedWithBiasFixture : FullyConnectedFixture
135 {
136     FullyConnectedWithBiasFixture()
137         : FullyConnectedFixture("[ 1, 4, 1, 1 ]",     // inputShape
138                                 "[ 1, 1 ]",           // outputShape
139                                 "[ 1, 4 ]",           // filterShape
140                                 "[ 2, 3, 4, 5 ]",     // filterData
141                                 "[ 1 ]",              // biasShape
142                                 "[ 10, 0, 0, 0 ]" )   // biasData
143     {}
144 };
145
146 BOOST_FIXTURE_TEST_CASE(ParseFullyConnectedWithBias, FullyConnectedWithBiasFixture)
147 {
148     RunTest<2, armnn::DataType::QuantisedAsymm8>(
149         0,
150         { 10, 20, 30, 40 },
151         { (400+10)/2 });
152 }
153
154 BOOST_AUTO_TEST_SUITE_END()