IVGCVSW-5593 Implement Pimpl Idiom for serialization classes
[platform/upstream/armnn.git] / src / armnnDeserializer / test / DeserializeStridedSlice.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "ParserFlatbuffersSerializeFixture.hpp"
8 #include <armnnDeserializer/IDeserializer.hpp>
9
10 #include <string>
11
12 BOOST_AUTO_TEST_SUITE(Deserializer)
13
14 struct StridedSliceFixture : public ParserFlatbuffersSerializeFixture
15 {
16     explicit StridedSliceFixture(const std::string& inputShape,
17                                  const std::string& begin,
18                                  const std::string& end,
19                                  const std::string& stride,
20                                  const std::string& beginMask,
21                                  const std::string& endMask,
22                                  const std::string& shrinkAxisMask,
23                                  const std::string& ellipsisMask,
24                                  const std::string& newAxisMask,
25                                  const std::string& dataLayout,
26                                  const std::string& outputShape,
27                                  const std::string& dataType)
28     {
29         m_JsonString = R"(
30             {
31                 inputIds: [0],
32                 outputIds: [2],
33                 layers: [
34                     {
35                         layer_type: "InputLayer",
36                         layer: {
37                             base: {
38                                 layerBindingId: 0,
39                                 base: {
40                                     index: 0,
41                                     layerName: "InputLayer",
42                                     layerType: "Input",
43                                     inputSlots: [{
44                                         index: 0,
45                                         connection: {sourceLayerIndex:0, outputSlotIndex:0 },
46                                     }],
47                                     outputSlots: [{
48                                         index: 0,
49                                         tensorInfo: {
50                                             dimensions: )" + inputShape + R"(,
51                                             dataType: )" + dataType + R"(
52                                         }
53                                     }]
54                                 }
55                             }
56                         }
57                     },
58                     {
59                         layer_type: "StridedSliceLayer",
60                         layer: {
61                             base: {
62                                 index: 1,
63                                 layerName: "StridedSliceLayer",
64                                 layerType: "StridedSlice",
65                                 inputSlots: [{
66                                     index: 0,
67                                     connection: {sourceLayerIndex:0, outputSlotIndex:0 },
68                                 }],
69                                 outputSlots: [{
70                                     index: 0,
71                                     tensorInfo: {
72                                         dimensions: )" + outputShape + R"(,
73                                         dataType: )" + dataType + R"(
74                                     }
75                                 }]
76                             },
77                             descriptor: {
78                                 begin: )" + begin + R"(,
79                                 end: )" + end + R"(,
80                                 stride: )" + stride + R"(,
81                                 beginMask: )" + beginMask + R"(,
82                                 endMask: )" + endMask + R"(,
83                                 shrinkAxisMask: )" + shrinkAxisMask + R"(,
84                                 ellipsisMask: )" + ellipsisMask + R"(,
85                                 newAxisMask: )" + newAxisMask + R"(,
86                                 dataLayout: )" + dataLayout + R"(,
87                             }
88                         }
89                     },
90                     {
91                         layer_type: "OutputLayer",
92                         layer: {
93                             base:{
94                                 layerBindingId: 2,
95                                 base: {
96                                     index: 2,
97                                     layerName: "OutputLayer",
98                                     layerType: "Output",
99                                     inputSlots: [{
100                                         index: 0,
101                                         connection: {sourceLayerIndex:1, outputSlotIndex:0 },
102                                     }],
103                                     outputSlots: [{
104                                         index: 0,
105                                         tensorInfo: {
106                                             dimensions: )" + outputShape + R"(,
107                                             dataType: )" + dataType + R"(
108                                         },
109                                     }],
110                                 }
111                             }
112                         },
113                     }
114                 ]
115             }
116         )";
117         SetupSingleInputSingleOutput("InputLayer", "OutputLayer");
118     }
119 };
120
121 struct SimpleStridedSliceFixture : StridedSliceFixture
122 {
123     SimpleStridedSliceFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]",
124                                                       "[ 0, 0, 0, 0 ]",
125                                                       "[ 3, 2, 3, 1 ]",
126                                                       "[ 2, 2, 2, 1 ]",
127                                                       "0",
128                                                       "0",
129                                                       "0",
130                                                       "0",
131                                                       "0",
132                                                       "NCHW",
133                                                       "[ 2, 1, 2, 1 ]",
134                                                       "Float32") {}
135 };
136
137 BOOST_FIXTURE_TEST_CASE(SimpleStridedSliceFloat32, SimpleStridedSliceFixture)
138 {
139     RunTest<4, armnn::DataType::Float32>(0,
140                                          {
141                                              1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
142                                              3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
143                                              5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f
144                                          },
145                                          {
146                                              1.0f, 1.0f, 5.0f, 5.0f
147                                          });
148 }
149
150 struct StridedSliceMaskFixture : StridedSliceFixture
151 {
152     StridedSliceMaskFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]",
153                                                     "[ 1, 1, 1, 1 ]",
154                                                     "[ 1, 1, 1, 1 ]",
155                                                     "[ 1, 1, 1, 1 ]",
156                                                     "15",
157                                                     "15",
158                                                     "0",
159                                                     "0",
160                                                     "0",
161                                                     "NCHW",
162                                                     "[ 3, 2, 3, 1 ]",
163                                                     "Float32") {}
164 };
165
166 BOOST_FIXTURE_TEST_CASE(StridedSliceMaskFloat32, StridedSliceMaskFixture)
167 {
168     RunTest<4, armnn::DataType::Float32>(0,
169                                          {
170                                              1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
171                                              3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
172                                              5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f
173                                          },
174                                          {
175                                              1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f,
176                                              3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f,
177                                              5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f
178                                          });
179 }
180
181 BOOST_AUTO_TEST_SUITE_END()