IVGCVSW-5593 Implement Pimpl Idiom for serialization classes
[platform/upstream/armnn.git] / src / armnnDeserializer / test / DeserializeRsqrt.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include <boost/test/unit_test.hpp>
7 #include "ParserFlatbuffersSerializeFixture.hpp"
8 #include <armnnDeserializer/IDeserializer.hpp>
9
10 #include <string>
11
// Opens the Boost.Test suite "Deserializer"; every fixture and test case below belongs to it
// until the matching BOOST_AUTO_TEST_SUITE_END() at the bottom of this file.
BOOST_AUTO_TEST_SUITE(Deserializer)
13
14 struct RsqrtFixture : public ParserFlatbuffersSerializeFixture
15 {
16     explicit RsqrtFixture(const std::string & inputShape,
17                           const std::string & outputShape,
18                           const std::string & dataType)
19     {
20         m_JsonString = R"(
21         {
22                 inputIds: [0],
23                 outputIds: [2],
24                 layers: [
25                 {
26                     layer_type: "InputLayer",
27                     layer: {
28                           base: {
29                                 layerBindingId: 0,
30                                 base: {
31                                     index: 0,
32                                     layerName: "InputLayer",
33                                     layerType: "Input",
34                                     inputSlots: [{
35                                         index: 0,
36                                         connection: {sourceLayerIndex:0, outputSlotIndex:0 },
37                                     }],
38                                     outputSlots: [ {
39                                         index: 0,
40                                         tensorInfo: {
41                                             dimensions: )" + inputShape + R"(,
42                                             dataType: )" + dataType + R"(
43                                         },
44                                     }],
45                                  },}},
46                 },
47                 {
48                 layer_type: "RsqrtLayer",
49                 layer : {
50                         base: {
51                              index:1,
52                              layerName: "RsqrtLayer",
53                              layerType: "Rsqrt",
54                              inputSlots: [
55                                             {
56                                              index: 0,
57                                              connection: {sourceLayerIndex:0, outputSlotIndex:0 },
58                                             }
59                              ],
60                              outputSlots: [ {
61                                  index: 0,
62                                  tensorInfo: {
63                                      dimensions: )" + outputShape + R"(,
64                                      dataType: )" + dataType + R"(
65                                  },
66                              }],
67                             }},
68                 },
69                 {
70                 layer_type: "OutputLayer",
71                 layer: {
72                         base:{
73                               layerBindingId: 0,
74                               base: {
75                                     index: 2,
76                                     layerName: "OutputLayer",
77                                     layerType: "Output",
78                                     inputSlots: [{
79                                         index: 0,
80                                         connection: {sourceLayerIndex:1, outputSlotIndex:0 },
81                                     }],
82                                     outputSlots: [ {
83                                         index: 0,
84                                         tensorInfo: {
85                                             dimensions: )" + outputShape + R"(,
86                                             dataType: )" + dataType + R"(
87                                         },
88                                 }],
89                             }}},
90                 }]
91          }
92         )";
93         Setup();
94     }
95 };
96
97
98 struct Rsqrt2dFixture : RsqrtFixture
99 {
100     Rsqrt2dFixture() : RsqrtFixture("[ 2, 2 ]",
101                                     "[ 2, 2 ]",
102                                     "Float32") {}
103 };
104
105 BOOST_FIXTURE_TEST_CASE(Rsqrt2d, Rsqrt2dFixture)
106 {
107   RunTest<2, armnn::DataType::Float32>(
108       0,
109       {{"InputLayer", { 1.0f,  4.0f,
110                         16.0f, 25.0f }}},
111       {{"OutputLayer",{ 1.0f,  0.5f,
112                         0.25f, 0.2f }}});
113 }
114
115
// Closes the "Deserializer" Boost.Test suite opened by BOOST_AUTO_TEST_SUITE(Deserializer) near the top of this file.
BOOST_AUTO_TEST_SUITE_END()