4361e5048581ed95c9df331bd3260771156108aa
[platform/upstream/armnn.git] / src / armnnDeserializer / test / DeserializeRsqrt.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "ParserFlatbuffersSerializeFixture.hpp"
8 #include "../Deserializer.hpp"
9
10 #include <string>
11 #include <iostream>
12
13 BOOST_AUTO_TEST_SUITE(Deserializer)
14
15 struct RsqrtFixture : public ParserFlatbuffersSerializeFixture
16 {
17     explicit RsqrtFixture(const std::string & inputShape,
18                           const std::string & outputShape,
19                           const std::string & dataType)
20     {
21         m_JsonString = R"(
22         {
23                 inputIds: [0],
24                 outputIds: [2],
25                 layers: [
26                 {
27                     layer_type: "InputLayer",
28                     layer: {
29                           base: {
30                                 layerBindingId: 0,
31                                 base: {
32                                     index: 0,
33                                     layerName: "InputLayer",
34                                     layerType: "Input",
35                                     inputSlots: [{
36                                         index: 0,
37                                         connection: {sourceLayerIndex:0, outputSlotIndex:0 },
38                                     }],
39                                     outputSlots: [ {
40                                         index: 0,
41                                         tensorInfo: {
42                                             dimensions: )" + inputShape + R"(,
43                                             dataType: )" + dataType + R"(
44                                         },
45                                     }],
46                                  },}},
47                 },
48                 {
49                 layer_type: "RsqrtLayer",
50                 layer : {
51                         base: {
52                              index:1,
53                              layerName: "RsqrtLayer",
54                              layerType: "Rsqrt",
55                              inputSlots: [
56                                             {
57                                              index: 0,
58                                              connection: {sourceLayerIndex:0, outputSlotIndex:0 },
59                                             }
60                              ],
61                              outputSlots: [ {
62                                  index: 0,
63                                  tensorInfo: {
64                                      dimensions: )" + outputShape + R"(,
65                                      dataType: )" + dataType + R"(
66                                  },
67                              }],
68                             }},
69                 },
70                 {
71                 layer_type: "OutputLayer",
72                 layer: {
73                         base:{
74                               layerBindingId: 0,
75                               base: {
76                                     index: 2,
77                                     layerName: "OutputLayer",
78                                     layerType: "Output",
79                                     inputSlots: [{
80                                         index: 0,
81                                         connection: {sourceLayerIndex:1, outputSlotIndex:0 },
82                                     }],
83                                     outputSlots: [ {
84                                         index: 0,
85                                         tensorInfo: {
86                                             dimensions: )" + outputShape + R"(,
87                                             dataType: )" + dataType + R"(
88                                         },
89                                 }],
90                             }}},
91                 }]
92          }
93         )";
94         Setup();
95     }
96 };
97
98
99 struct Rsqrt2dFixture : RsqrtFixture
100 {
101     Rsqrt2dFixture() : RsqrtFixture("[ 2, 2 ]",
102                                     "[ 2, 2 ]",
103                                     "Float32") {}
104 };
105
106 BOOST_FIXTURE_TEST_CASE(Rsqrt2d, Rsqrt2dFixture)
107 {
108   RunTest<2, armnn::DataType::Float32>(
109       0,
110       {{"InputLayer", { 1.0f,  4.0f,
111                         16.0f, 25.0f }}},
112       {{"OutputLayer",{ 1.0f,  0.5f,
113                         0.25f, 0.2f }}});
114 }
115
116
117 BOOST_AUTO_TEST_SUITE_END()