4ea745628ceb43e4ba654d9834b42a84bf1bea62
[platform/upstream/armnn.git] / src / armnnDeserializer / test / DeserializeMean.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "ParserFlatbuffersSerializeFixture.hpp"
8 #include "../Deserializer.hpp"
9
10 #include <string>
11 #include <iostream>
12
13 BOOST_AUTO_TEST_SUITE(Deserializer)
14
15 struct MeanFixture : public ParserFlatbuffersSerializeFixture
16 {
17     explicit MeanFixture(const std::string &inputShape,
18                          const std::string &outputShape,
19                          const std::string &axis,
20                          const std::string &dataType)
21     {
22         m_JsonString = R"(
23             {
24                 inputIds: [0],
25                 outputIds: [2],
26                 layers: [
27                     {
28                         layer_type: "InputLayer",
29                         layer: {
30                             base: {
31                                 layerBindingId: 0,
32                                 base: {
33                                     index: 0,
34                                     layerName: "InputLayer",
35                                     layerType: "Input",
36                                     inputSlots: [{
37                                         index: 0,
38                                         connection: {sourceLayerIndex:0, outputSlotIndex:0 },
39                                     }],
40                                     outputSlots: [{
41                                         index: 0,
42                                         tensorInfo: {
43                                             dimensions: )" + inputShape + R"(,
44                                             dataType: )" + dataType + R"(
45                                         }
46                                     }]
47                                 }
48                             }
49                         }
50                     },
51                     {
52                         layer_type: "MeanLayer",
53                         layer: {
54                             base: {
55                                 index: 1,
56                                 layerName: "MeanLayer",
57                                 layerType: "Mean",
58                                 inputSlots: [{
59                                     index: 0,
60                                     connection: {sourceLayerIndex:0, outputSlotIndex:0 },
61                                 }],
62                                 outputSlots: [{
63                                     index: 0,
64                                     tensorInfo: {
65                                         dimensions: )" + outputShape + R"(,
66                                         dataType: )" + dataType + R"(
67                                     }
68                                 }]
69                             },
70                             descriptor: {
71                                 axis: )" + axis + R"(,
72                                 keepDims: true
73                             }
74                         }
75                     },
76                     {
77                         layer_type: "OutputLayer",
78                         layer: {
79                             base:{
80                                 layerBindingId: 2,
81                                 base: {
82                                     index: 2,
83                                     layerName: "OutputLayer",
84                                     layerType: "Output",
85                                     inputSlots: [{
86                                         index: 0,
87                                         connection: {sourceLayerIndex:1, outputSlotIndex:0 },
88                                     }],
89                                     outputSlots: [{
90                                         index: 0,
91                                         tensorInfo: {
92                                             dimensions: )" + outputShape + R"(,
93                                             dataType: )" + dataType + R"(
94                                         },
95                                     }],
96                                 }
97                             }
98                         },
99                     }
100                 ]
101             }
102         )";
103         Setup();
104     }
105 };
106
107 struct SimpleMeanFixture : MeanFixture
108 {
109     SimpleMeanFixture()
110         : MeanFixture("[ 1, 1, 3, 2 ]",     // inputShape
111                       "[ 1, 1, 1, 2 ]",     // outputShape
112                       "[ 2 ]",              // axis
113                       "Float32")            // dataType
114     {}
115 };
116
117 BOOST_FIXTURE_TEST_CASE(SimpleMean, SimpleMeanFixture)
118 {
119     RunTest<4, armnn::DataType::Float32>(
120          0,
121          {{"InputLayer",  { 1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f }}},
122          {{"OutputLayer", { 2.0f, 2.0f }}});
123 }
124
125 BOOST_AUTO_TEST_SUITE_END()