0f75db4abb27f3364dc0da9e6284f4353b14d214
[platform/upstream/armnn.git] / src / armnnDeserializer / test / DeserializeGather.cpp
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "ParserFlatbuffersSerializeFixture.hpp"
8 #include "../Deserializer.hpp"
9
10 #include <string>
11 #include <iostream>
12
13 BOOST_AUTO_TEST_SUITE(Deserializer)
14
15 struct GatherFixture : public ParserFlatbuffersSerializeFixture
16 {
17     explicit GatherFixture(const std::string& inputShape,
18                            const std::string& indicesShape,
19                            const std::string& input1Content,
20                            const std::string& outputShape,
21                            const std::string& axis,
22                            const std::string dataType,
23                            const std::string constDataType)
24     {
25         m_JsonString = R"(
26         {
27                 inputIds: [0],
28                 outputIds: [3],
29                 layers: [
30                 {
31                     layer_type: "InputLayer",
32                     layer: {
33                           base: {
34                                 layerBindingId: 0,
35                                 base: {
36                                     index: 0,
37                                     layerName: "InputLayer",
38                                     layerType: "Input",
39                                     inputSlots: [{
40                                         index: 0,
41                                         connection: {sourceLayerIndex:0, outputSlotIndex:0 },
42                                     }],
43                                     outputSlots: [ {
44                                         index: 0,
45                                         tensorInfo: {
46                                             dimensions: )" + inputShape + R"(,
47                                             dataType: )" + dataType + R"(
48                                             }}]
49                                     }
50                     }}},
51                     {
52                     layer_type: "ConstantLayer",
53                         layer: {
54                                base: {
55                                   index:1,
56                                   layerName: "ConstantLayer",
57                                   layerType: "Constant",
58                                    outputSlots: [ {
59                                     index: 0,
60                                     tensorInfo: {
61                                         dimensions: )" + indicesShape + R"(,
62                                         dataType: "Signed32",
63                                     },
64                                   }],
65                               },
66                               input: {
67                               info: {
68                                        dimensions: )" + indicesShape + R"(,
69                                        dataType: )" + dataType + R"(
70                                    },
71                               data_type: )" + constDataType + R"(,
72                               data: {
73                                   data: )" + input1Content + R"(,
74                                     } }
75                                 },},
76                     {
77                     layer_type: "GatherLayer",
78                         layer: {
79                               base: {
80                                    index: 2,
81                                    layerName: "GatherLayer",
82                                    layerType: "Gather",
83                                    inputSlots: [
84                                    {
85                                        index: 0,
86                                        connection: {sourceLayerIndex:0, outputSlotIndex:0 },
87                                    },
88                                    {
89                                         index: 1,
90                                         connection: {sourceLayerIndex:1, outputSlotIndex:0 }
91                                    }],
92                                    outputSlots: [ {
93                                           index: 0,
94                                           tensorInfo: {
95                                                dimensions: )" + outputShape + R"(,
96                                                dataType: )" + dataType + R"(
97
98                                    }}]},
99                                    descriptor: {
100                                        axis: )" + axis + R"(
101                                    }
102                         }},
103                     {
104                     layer_type: "OutputLayer",
105                     layer: {
106                         base:{
107                               layerBindingId: 0,
108                               base: {
109                                     index: 3,
110                                     layerName: "OutputLayer",
111                                     layerType: "Output",
112                                     inputSlots: [{
113                                         index: 0,
114                                         connection: {sourceLayerIndex:2, outputSlotIndex:0 },
115                                     }],
116                                     outputSlots: [ {
117                                         index: 0,
118                                         tensorInfo: {
119                                             dimensions: )" + outputShape + R"(,
120                                             dataType: )" + dataType + R"(
121                                         },
122                                 }],
123                             }}},
124                 }]
125                  } )";
126
127         Setup();
128     }
129 };
130
131 struct SimpleGatherFixtureFloat32 : GatherFixture
132 {
133     SimpleGatherFixtureFloat32() : GatherFixture("[ 3, 2, 3 ]", "[ 2, 3 ]", "[1, 2, 1, 2, 1, 0]",
134                                                  "[ 2, 3, 2, 3 ]", "0", "Float32", "IntData") {}
135 };
136
137 BOOST_FIXTURE_TEST_CASE(GatherFloat32, SimpleGatherFixtureFloat32)
138 {
139     RunTest<4, armnn::DataType::Float32>(0,
140                                          {{"InputLayer", {  1,  2,  3,
141                                                             4,  5,  6,
142                                                             7,  8,  9,
143                                                             10, 11, 12,
144                                                             13, 14, 15,
145                                                             16, 17, 18 }}},
146                                          {{"OutputLayer", { 7,  8,  9,
147                                                             10, 11, 12,
148                                                             13, 14, 15,
149                                                             16, 17, 18,
150                                                             7,  8,  9,
151                                                             10, 11, 12,
152                                                             13, 14, 15,
153                                                             16, 17, 18,
154                                                             7,  8,  9,
155                                                             10, 11, 12,
156                                                             1,  2,  3,
157                                                             4,  5,  6 }}});
158 }
159
160 BOOST_AUTO_TEST_SUITE_END()
161