554b867db759b42433c019b7b533aeab9e0877a7
[platform/upstream/armnn.git] / src / armnnDeserializer / test / DeserializeReshape.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "ParserFlatbuffersSerializeFixture.hpp"
8 #include "../Deserializer.hpp"
9
10 #include <string>
11 #include <iostream>
12
13 BOOST_AUTO_TEST_SUITE(Deserializer)
14
15 struct ReshapeFixture : public ParserFlatbuffersSerializeFixture
16 {
17     explicit ReshapeFixture(const std::string &inputShape,
18                             const std::string &targetShape,
19                             const std::string &outputShape,
20                             const std::string &dataType)
21     {
22         m_JsonString = R"(
23         {
24                 inputIds: [0],
25                 outputIds: [2],
26                 layers: [
27                 {
28                     layer_type: "InputLayer",
29                     layer: {
30                           base: {
31                                 layerBindingId: 0,
32                                 base: {
33                                     index: 0,
34                                     layerName: "InputLayer",
35                                     layerType: "Input",
36                                     inputSlots: [{
37                                         index: 0,
38                                         connection: {sourceLayerIndex:0, outputSlotIndex:0 },
39                                     }],
40                                     outputSlots: [ {
41                                         index: 0,
42                                         tensorInfo: {
43                                             dimensions: )" + inputShape + R"(,
44                                             dataType: )" + dataType + R"(
45                                             }}]
46                                     }
47                     }}},
48                     {
49                     layer_type: "ReshapeLayer",
50                     layer: {
51                           base: {
52                                index: 1,
53                                layerName: "ReshapeLayer",
54                                layerType: "Reshape",
55                                inputSlots: [{
56                                       index: 0,
57                                       connection: {sourceLayerIndex:0, outputSlotIndex:0 },
58                                }],
59                                outputSlots: [ {
60                                       index: 0,
61                                       tensorInfo: {
62                                            dimensions: )" + inputShape + R"(,
63                                            dataType: )" + dataType + R"(
64
65                                }}]},
66                           descriptor: {
67                                targetShape: )" + targetShape + R"(,
68                                }
69
70                     }},
71                     {
72                     layer_type: "OutputLayer",
73                     layer: {
74                         base:{
75                               layerBindingId: 2,
76                               base: {
77                                     index: 2,
78                                     layerName: "OutputLayer",
79                                     layerType: "Output",
80                                     inputSlots: [{
81                                         index: 0,
82                                         connection: {sourceLayerIndex:0, outputSlotIndex:0 },
83                                     }],
84                                     outputSlots: [ {
85                                         index: 0,
86                                         tensorInfo: {
87                                             dimensions: )" + outputShape + R"(,
88                                             dataType: )" + dataType + R"(
89                                         },
90                                 }],
91                             }}},
92                 }]
93          }
94      )";
95      SetupSingleInputSingleOutput("InputLayer", "OutputLayer");
96     }
97 };
98
99 struct SimpleReshapeFixture : ReshapeFixture
100 {
101     SimpleReshapeFixture() : ReshapeFixture("[ 1, 9 ]", "[ 3, 3 ]", "[ 3, 3 ]",
102                                             "QuantisedAsymm8") {}
103 };
104
105 struct SimpleReshapeFixture2 : ReshapeFixture
106 {
107     SimpleReshapeFixture2() : ReshapeFixture("[ 2, 2, 1, 1 ]",
108                                              "[ 2, 2, 1, 1 ]",
109                                              "[ 2, 2, 1, 1 ]",
110                                              "Float32") {}
111 };
112
113 BOOST_FIXTURE_TEST_CASE(ReshapeQuantisedAsymm8, SimpleReshapeFixture)
114 {
115     RunTest<2, armnn::DataType::QAsymmU8>(0,
116                                                 { 1, 2, 3, 4, 5, 6, 7, 8, 9 },
117                                                 { 1, 2, 3, 4, 5, 6, 7, 8, 9 });
118 }
119
120 BOOST_FIXTURE_TEST_CASE(ReshapeFloat32, SimpleReshapeFixture2)
121 {
122     RunTest<4, armnn::DataType::Float32>(0,
123                                         { 111, 85, 226, 3 },
124                                         { 111, 85, 226, 3 });
125 }
126
127
128 BOOST_AUTO_TEST_SUITE_END()