4873fd1d2fab3d1395af12237f858b41d2056031
[platform/upstream/armnn.git] / src / armnnDeserializer / test / DeserializeInstanceNormalization.cpp
1 //
2 // Copyright © 2019 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "ParserFlatbuffersSerializeFixture.hpp"
7 #include "../Deserializer.hpp"
8
9 #include <string>
10
11 #include <boost/test/unit_test.hpp>
12
// Every test case up to the matching BOOST_AUTO_TEST_SUITE_END() is registered
// in the Boost.Test suite named "Deserializer".
BOOST_AUTO_TEST_SUITE(Deserializer)
14
15 struct InstanceNormalizationFixture : public ParserFlatbuffersSerializeFixture
16 {
17     explicit InstanceNormalizationFixture(const std::string &inputShape,
18                                           const std::string &outputShape,
19                                           const std::string &gamma,
20                                           const std::string &beta,
21                                           const std::string &epsilon,
22                                           const std::string &dataType,
23                                           const std::string &dataLayout)
24     {
25         m_JsonString = R"(
26     {
27         inputIds: [0],
28         outputIds: [2],
29         layers: [
30            {
31             layer_type: "InputLayer",
32             layer: {
33                 base: {
34                     layerBindingId: 0,
35                     base: {
36                         index: 0,
37                         layerName: "InputLayer",
38                         layerType: "Input",
39                         inputSlots: [{
40                             index: 0,
41                             connection: {sourceLayerIndex:0, outputSlotIndex:0 },
42                             }],
43                         outputSlots: [{
44                             index: 0,
45                             tensorInfo: {
46                                 dimensions: )" + inputShape + R"(,
47                                 dataType: ")" + dataType + R"(",
48                                 quantizationScale: 0.5,
49                                 quantizationOffset: 0
50                                 },
51                             }]
52                         },
53                     }
54                 },
55             },
56         {
57         layer_type: "InstanceNormalizationLayer",
58         layer : {
59             base: {
60                 index:1,
61                 layerName: "InstanceNormalizationLayer",
62                 layerType: "InstanceNormalization",
63                 inputSlots: [{
64                         index: 0,
65                         connection: {sourceLayerIndex:0, outputSlotIndex:0 },
66                    }],
67                 outputSlots: [{
68                     index: 0,
69                     tensorInfo: {
70                         dimensions: )" + outputShape + R"(,
71                         dataType: ")" + dataType + R"("
72                     },
73                     }],
74                 },
75             descriptor: {
76                 dataLayout: ")" + dataLayout + R"(",
77                 gamma: ")" + gamma + R"(",
78                 beta: ")" + beta + R"(",
79                 eps: )" + epsilon + R"(
80                 },
81             },
82         },
83         {
84         layer_type: "OutputLayer",
85         layer: {
86             base:{
87                 layerBindingId: 0,
88                 base: {
89                     index: 2,
90                     layerName: "OutputLayer",
91                     layerType: "Output",
92                     inputSlots: [{
93                         index: 0,
94                         connection: {sourceLayerIndex:1, outputSlotIndex:0 },
95                     }],
96                     outputSlots: [ {
97                         index: 0,
98                         tensorInfo: {
99                             dimensions: )" + outputShape + R"(,
100                             dataType: ")" + dataType + R"("
101                         },
102                     }],
103                 }
104             }},
105         }]
106     }
107 )";
108         SetupSingleInputSingleOutput("InputLayer", "OutputLayer");
109     }
110 };
111
112 struct InstanceNormalizationFloat32Fixture : InstanceNormalizationFixture
113 {
114     InstanceNormalizationFloat32Fixture():InstanceNormalizationFixture("[ 2, 2, 2, 2 ]",
115                                                                        "[ 2, 2, 2, 2 ]",
116                                                                        "1.0",
117                                                                        "0.0",
118                                                                        "0.0001",
119                                                                        "Float32",
120                                                                        "NHWC") {}
121 };
122
123 BOOST_FIXTURE_TEST_CASE(InstanceNormalizationFloat32, InstanceNormalizationFloat32Fixture)
124 {
125     RunTest<4, armnn::DataType::Float32>(
126         0,
127          {
128              0.f,  1.f,
129              0.f,  2.f,
130
131              0.f,  2.f,
132              0.f,  4.f,
133
134              1.f, -1.f,
135             -1.f,  2.f,
136
137             -1.f, -2.f,
138              1.f,  4.f
139         },
140         {
141              0.0000000f, -1.1470304f,
142              0.0000000f, -0.2294061f,
143
144              0.0000000f, -0.2294061f,
145              0.0000000f,  1.6058424f,
146
147              0.9999501f, -0.7337929f,
148             -0.9999501f,  0.5241377f,
149
150             -0.9999501f, -1.1531031f,
151              0.9999501f,  1.3627582f
152         });
153 }
154
// Closes the "Deserializer" Boost.Test suite opened near the top of this file.
BOOST_AUTO_TEST_SUITE_END()