IVGCVSW-2467 Remove GetDataType<T> function
[platform/upstream/armnn.git] / src / armnnTfLiteParser / test / Concatenation.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "ParserFlatbuffersFixture.hpp"
8 #include "../TfLiteParser.hpp"
9
10 #include <string>
11 #include <iostream>
12
13 BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
14
15 struct ConcatenationFixture : public ParserFlatbuffersFixture
16 {
17     explicit ConcatenationFixture(const std::string & inputShape1,
18                                   const std::string & inputShape2,
19                                   const std::string & outputShape,
20                                   const std::string & axis,
21                                   const std::string & activation="NONE")
22     {
23         m_JsonString = R"(
24             {
25                 "version": 3,
26                 "operator_codes": [ { "builtin_code": "CONCATENATION" } ],
27                 "subgraphs": [ {
28                     "tensors": [
29                         {
30                             "shape": )" + inputShape1 + R"(,
31                             "type": "UINT8",
32                             "buffer": 0,
33                             "name": "inputTensor1",
34                             "quantization": {
35                                 "min": [ 0.0 ],
36                                 "max": [ 255.0 ],
37                                 "scale": [ 1.0 ],
38                                 "zero_point": [ 0 ],
39                             }
40                         },
41                         {
42                             "shape": )" + inputShape2 + R"(,
43                             "type": "UINT8",
44                             "buffer": 1,
45                             "name": "inputTensor2",
46                             "quantization": {
47                                 "min": [ 0.0 ],
48                                 "max": [ 255.0 ],
49                                 "scale": [ 1.0 ],
50                                 "zero_point": [ 0 ],
51                             }
52                         },
53                         {
54                             "shape": )" + outputShape + R"( ,
55                             "type": "UINT8",
56                             "buffer": 2,
57                             "name": "outputTensor",
58                             "quantization": {
59                                 "min": [ 0.0 ],
60                                 "max": [ 255.0 ],
61                                 "scale": [ 1.0 ],
62                                 "zero_point": [ 0 ],
63                             }
64                         }
65                     ],
66                     "inputs": [ 0, 1 ],
67                     "outputs": [ 2 ],
68                     "operators": [
69                         {
70                             "opcode_index": 0,
71                             "inputs": [ 0, 1 ],
72                             "outputs": [ 2 ],
73                             "builtin_options_type": "ConcatenationOptions",
74                             "builtin_options": {
75                                 "axis": )" + axis + R"(,
76                                 "fused_activation_function": )" + activation + R"(
77                             },
78                             "custom_options_format": "FLEXBUFFERS"
79                         }
80                     ],
81                 } ],
82                 "buffers" : [
83                     { },
84                     { }
85                 ]
86             }
87         )";
88         Setup();
89     }
90 };
91
92
93 struct ConcatenationFixtureNegativeDim : ConcatenationFixture
94 {
95     ConcatenationFixtureNegativeDim() : ConcatenationFixture("[ 1, 1, 2, 2 ]",
96                                                              "[ 1, 1, 2, 2 ]",
97                                                              "[ 1, 2, 2, 2 ]",
98                                                              "-3" ) {}
99 };
100
101 BOOST_FIXTURE_TEST_CASE(ParseConcatenationNegativeDim, ConcatenationFixtureNegativeDim)
102 {
103     RunTest<4, armnn::DataType::QuantisedAsymm8>(
104         0,
105         {{"inputTensor1", { 0, 1, 2, 3 }},
106         {"inputTensor2", { 4, 5, 6, 7 }}},
107         {{"outputTensor", { 0, 1, 2, 3, 4, 5, 6, 7 }}});
108 }
109
110 struct ConcatenationFixtureNCHW : ConcatenationFixture
111 {
112     ConcatenationFixtureNCHW() : ConcatenationFixture("[ 1, 1, 2, 2 ]", "[ 1, 1, 2, 2 ]", "[ 1, 2, 2, 2 ]", "1" ) {}
113 };
114
115 BOOST_FIXTURE_TEST_CASE(ParseConcatenationNCHW, ConcatenationFixtureNCHW)
116 {
117     RunTest<4, armnn::DataType::QuantisedAsymm8>(
118         0,
119         {{"inputTensor1", { 0, 1, 2, 3 }},
120         {"inputTensor2", { 4, 5, 6, 7 }}},
121         {{"outputTensor", { 0, 1, 2, 3, 4, 5, 6, 7 }}});
122 }
123
124 struct ConcatenationFixtureNHWC : ConcatenationFixture
125 {
126     ConcatenationFixtureNHWC() : ConcatenationFixture("[ 1, 1, 2, 2 ]", "[ 1, 1, 2, 2 ]", "[ 1, 1, 2, 4 ]", "3" ) {}
127 };
128
129 BOOST_FIXTURE_TEST_CASE(ParseConcatenationNHWC, ConcatenationFixtureNHWC)
130 {
131     RunTest<4, armnn::DataType::QuantisedAsymm8>(
132         0,
133         {{"inputTensor1", { 0, 1, 2, 3 }},
134         {"inputTensor2", { 4, 5, 6, 7 }}},
135         {{"outputTensor", { 0, 1, 4, 5, 2, 3, 6, 7 }}});
136 }
137
138 struct ConcatenationFixtureDim1 : ConcatenationFixture
139 {
140     ConcatenationFixtureDim1() : ConcatenationFixture("[ 1, 2, 3, 4 ]", "[ 1, 2, 3, 4 ]", "[ 1, 4, 3, 4 ]", "1" ) {}
141 };
142
143 BOOST_FIXTURE_TEST_CASE(ParseConcatenationDim1, ConcatenationFixtureDim1)
144 {
145     RunTest<4, armnn::DataType::QuantisedAsymm8>(
146         0,
147         { { "inputTensor1", {  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11,
148                                12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23 } },
149         { "inputTensor2", {  50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,
150                              62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73 } } },
151         { { "outputTensor", {  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11,
152                                12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
153                                50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,
154                                62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73 } } });
155 }
156
157 struct ConcatenationFixtureDim3 : ConcatenationFixture
158 {
159     ConcatenationFixtureDim3() : ConcatenationFixture("[ 1, 2, 3, 4 ]", "[ 1, 2, 3, 4 ]", "[ 1, 2, 3, 8 ]", "3" ) {}
160 };
161
162 BOOST_FIXTURE_TEST_CASE(ParseConcatenationDim3, ConcatenationFixtureDim3)
163 {
164     RunTest<4, armnn::DataType::QuantisedAsymm8>(
165         0,
166         { { "inputTensor1", {  0,  1,  2,  3,
167                                4,  5,  6,  7,
168                                8,  9, 10, 11,
169                                12, 13, 14, 15,
170                                16, 17, 18, 19,
171                                20, 21, 22, 23 } },
172         { "inputTensor2", {  50, 51, 52, 53,
173                              54, 55, 56, 57,
174                              58, 59, 60, 61,
175                              62, 63, 64, 65,
176                              66, 67, 68, 69,
177                              70, 71, 72, 73 } } },
178         { { "outputTensor", {  0,  1,  2,  3,
179                                50, 51, 52, 53,
180                                4,  5,  6,  7,
181                                54, 55, 56, 57,
182                                8,  9,  10, 11,
183                                58, 59, 60, 61,
184                                12, 13, 14, 15,
185                                62, 63, 64, 65,
186                                16, 17, 18, 19,
187                                66, 67, 68, 69,
188                                20, 21, 22, 23,
189                                70, 71, 72, 73 } } });
190 }
191
192 BOOST_AUTO_TEST_SUITE_END()