8f14af150b04abbf7c9c924b01a6962338ef3d88
[platform/upstream/armnn.git] / src / armnnDeserializer / test / DeserializeRank.cpp
1 //
2 // Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "ParserFlatbuffersSerializeFixture.hpp"
8 #include "../Deserializer.hpp"
9
10 #include <string>
11
12 BOOST_AUTO_TEST_SUITE(Deserializer)
13
/// Test fixture that assembles a JSON description of a three-layer graph
/// (InputLayer -> RankLayer -> OutputLayer), stores it in m_JsonString and
/// then calls Setup().
///
/// The Rank layer's output tensor is hard-coded in the JSON as
/// dimensions [ 1 ], dataType "Signed32"; only the input tensor is configurable.
struct RankFixture : public ParserFlatbuffersSerializeFixture
{
    /// @param inputShape Text spliced verbatim into the JSON as the input tensor's
    ///                   `dimensions` value, e.g. "[ 2, 2, 1 ]".
    /// @param dataType   Text spliced verbatim into the JSON as the input tensor's
    ///                   `dataType` value. Note it is inserted without surrounding
    ///                   quotes, unlike the quoted "Signed32" used for the Rank output.
    explicit RankFixture(const std::string &inputShape,
                         const std::string &dataType)
    {
        // The raw string below is runtime data: its content (including whitespace)
        // is what gets stored, so no comments can be placed inside it.
        // NOTE(review): `dimensionality: 2` on the Rank output presumably marks the
        // tensor as a scalar - confirm against the serializer schema.
        m_JsonString = R"(
        {
            inputIds: [0],
            outputIds: [2],
              layers: [
               {
                 layer_type: "InputLayer",
                 layer: {
                   base: {
                     base: {
                       layerName: "",
                       layerType: "Input",
                       inputSlots: [

                       ],
                       outputSlots: [
                         {
                           tensorInfo: {
                             dimensions: )" + inputShape + R"(,
                             dataType: )" + dataType + R"(,
                             quantizationScale: 0.0
                           }
                         }
                       ]
                     }
                   }
                 }
               },
               {
                 layer_type: "RankLayer",
                 layer: {
                   base: {
                     index: 1,
                     layerName: "rank",
                     layerType: "Rank",
                     inputSlots: [
                       {
                         connection: {
                           sourceLayerIndex: 0,
                           outputSlotIndex: 0
                         }
                       }
                     ],
                     outputSlots: [
                       {
                         tensorInfo: {
                           dimensions: [ 1 ],
                           dataType: "Signed32",
                           quantizationScale: 0.0,
                           dimensionality: 2
                         }
                       }
                     ]
                   }
                 }
               },
               {
                 layer_type: "OutputLayer",
                 layer: {
                   base: {
                     base: {
                       index: 2,
                       layerName: "",
                       layerType: "Output",
                       inputSlots: [
                         {
                           connection: {
                             sourceLayerIndex: 1,
                             outputSlotIndex: 0
                           }
                         }
                       ],
                       outputSlots: []
                     }
                   }
                 }
               }
             ],
         }
     )";
        // Setup() is not defined in this file; presumably inherited from
        // ParserFlatbuffersSerializeFixture and consumes m_JsonString - verify there.
        Setup();
    }
};
102
103 struct SimpleRankDimSize1Fixture : RankFixture
104 {
105     SimpleRankDimSize1Fixture() : RankFixture("[ 8 ]", "QSymmS16") {}
106 };
107
108 struct SimpleRankDimSize2Fixture : RankFixture
109 {
110     SimpleRankDimSize2Fixture() : RankFixture("[ 3, 3 ]", "QSymmS8") {}
111 };
112
113 struct SimpleRankDimSize3Fixture : RankFixture
114 {
115     SimpleRankDimSize3Fixture() : RankFixture("[ 2, 2, 1 ]", "Signed32") {}
116 };
117
118 struct SimpleRankDimSize4Fixture : RankFixture
119 {
120     SimpleRankDimSize4Fixture() : RankFixture("[ 2, 2, 1, 1 ]", "Float32") {}
121 };
122
123 BOOST_FIXTURE_TEST_CASE(RankDimSize1Float16, SimpleRankDimSize1Fixture)
124 {
125     RunTest<1, armnn::DataType::QSymmS16, armnn::DataType::Signed32>( 0,
126                                                                       { 1, 2, 3, 4, 5, 6, 7, 8 },
127                                                                       { 1 });
128 }
129
130 BOOST_FIXTURE_TEST_CASE(RankDimSize2QAsymmU8, SimpleRankDimSize2Fixture)
131 {
132     RunTest<1, armnn::DataType::QSymmS8, armnn::DataType::Signed32>( 0,
133                                                                     { 1, 2, 3, 4, 5, 6, 7, 8, 9 },
134                                                                     { 2 });
135 }
136
137 BOOST_FIXTURE_TEST_CASE(RankDimSize3Signed32, SimpleRankDimSize3Fixture)
138 {
139     RunTest<1, armnn::DataType::Signed32, armnn::DataType::Signed32>( 0,
140                                                                     { 111, 85, 226, 3 },
141                                                                     { 3 });
142 }
143
144 BOOST_FIXTURE_TEST_CASE(RankDimSize4Float32, SimpleRankDimSize4Fixture)
145 {
146     RunTest<1, armnn::DataType::Float32, armnn::DataType::Signed32>( 0,
147                                                                    { 111, 85, 226, 3 },
148                                                                    { 4 });
149 }
150
151 BOOST_AUTO_TEST_SUITE_END()