IVGCVSW-5593 Implement Pimpl Idiom for serialization classes
[platform/upstream/armnn.git] / src / armnnDeserializer / test / DeserializeGather.cpp
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "ParserFlatbuffersSerializeFixture.hpp"
8 #include <armnnDeserializer/IDeserializer.hpp>
9
10 #include <string>
11
12 BOOST_AUTO_TEST_SUITE(Deserializer)
13
14 struct GatherFixture : public ParserFlatbuffersSerializeFixture
15 {
16     explicit GatherFixture(const std::string& inputShape,
17                            const std::string& indicesShape,
18                            const std::string& input1Content,
19                            const std::string& outputShape,
20                            const std::string& axis,
21                            const std::string dataType,
22                            const std::string constDataType)
23     {
24         m_JsonString = R"(
25         {
26                 inputIds: [0],
27                 outputIds: [3],
28                 layers: [
29                 {
30                     layer_type: "InputLayer",
31                     layer: {
32                           base: {
33                                 layerBindingId: 0,
34                                 base: {
35                                     index: 0,
36                                     layerName: "InputLayer",
37                                     layerType: "Input",
38                                     inputSlots: [{
39                                         index: 0,
40                                         connection: {sourceLayerIndex:0, outputSlotIndex:0 },
41                                     }],
42                                     outputSlots: [ {
43                                         index: 0,
44                                         tensorInfo: {
45                                             dimensions: )" + inputShape + R"(,
46                                             dataType: )" + dataType + R"(
47                                             }}]
48                                     }
49                     }}},
50                     {
51                     layer_type: "ConstantLayer",
52                         layer: {
53                                base: {
54                                   index:1,
55                                   layerName: "ConstantLayer",
56                                   layerType: "Constant",
57                                    outputSlots: [ {
58                                     index: 0,
59                                     tensorInfo: {
60                                         dimensions: )" + indicesShape + R"(,
61                                         dataType: "Signed32",
62                                     },
63                                   }],
64                               },
65                               input: {
66                               info: {
67                                        dimensions: )" + indicesShape + R"(,
68                                        dataType: )" + dataType + R"(
69                                    },
70                               data_type: )" + constDataType + R"(,
71                               data: {
72                                   data: )" + input1Content + R"(,
73                                     } }
74                                 },},
75                     {
76                     layer_type: "GatherLayer",
77                         layer: {
78                               base: {
79                                    index: 2,
80                                    layerName: "GatherLayer",
81                                    layerType: "Gather",
82                                    inputSlots: [
83                                    {
84                                        index: 0,
85                                        connection: {sourceLayerIndex:0, outputSlotIndex:0 },
86                                    },
87                                    {
88                                         index: 1,
89                                         connection: {sourceLayerIndex:1, outputSlotIndex:0 }
90                                    }],
91                                    outputSlots: [ {
92                                           index: 0,
93                                           tensorInfo: {
94                                                dimensions: )" + outputShape + R"(,
95                                                dataType: )" + dataType + R"(
96
97                                    }}]},
98                                    descriptor: {
99                                        axis: )" + axis + R"(
100                                    }
101                         }},
102                     {
103                     layer_type: "OutputLayer",
104                     layer: {
105                         base:{
106                               layerBindingId: 0,
107                               base: {
108                                     index: 3,
109                                     layerName: "OutputLayer",
110                                     layerType: "Output",
111                                     inputSlots: [{
112                                         index: 0,
113                                         connection: {sourceLayerIndex:2, outputSlotIndex:0 },
114                                     }],
115                                     outputSlots: [ {
116                                         index: 0,
117                                         tensorInfo: {
118                                             dimensions: )" + outputShape + R"(,
119                                             dataType: )" + dataType + R"(
120                                         },
121                                 }],
122                             }}},
123                 }]
124                  } )";
125
126         Setup();
127     }
128 };
129
130 struct SimpleGatherFixtureFloat32 : GatherFixture
131 {
132     SimpleGatherFixtureFloat32() : GatherFixture("[ 3, 2, 3 ]", "[ 2, 3 ]", "[1, 2, 1, 2, 1, 0]",
133                                                  "[ 2, 3, 2, 3 ]", "0", "Float32", "IntData") {}
134 };
135
136 BOOST_FIXTURE_TEST_CASE(GatherFloat32, SimpleGatherFixtureFloat32)
137 {
138     RunTest<4, armnn::DataType::Float32>(0,
139                                          {{"InputLayer", {  1,  2,  3,
140                                                             4,  5,  6,
141                                                             7,  8,  9,
142                                                             10, 11, 12,
143                                                             13, 14, 15,
144                                                             16, 17, 18 }}},
145                                          {{"OutputLayer", { 7,  8,  9,
146                                                             10, 11, 12,
147                                                             13, 14, 15,
148                                                             16, 17, 18,
149                                                             7,  8,  9,
150                                                             10, 11, 12,
151                                                             13, 14, 15,
152                                                             16, 17, 18,
153                                                             7,  8,  9,
154                                                             10, 11, 12,
155                                                             1,  2,  3,
156                                                             4,  5,  6 }}});
157 }
158
159 BOOST_AUTO_TEST_SUITE_END()
160