Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / PropagateQParamBackwardPass.test.cpp
1 /*
2  * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/PropagateQParamBackwardPass.h"
18
19 #include <luci/IR/CircleNodes.h>
20
21 #include <gtest/gtest.h>
22
23 using namespace luci;
24
25 namespace
26 {
27
28 void set_qparam(luci::CircleNode *node, float scale, int64_t zp)
29 {
30   auto qparam = std::make_unique<luci::CircleQuantParam>();
31   qparam->scale.emplace_back(scale);
32   qparam->zerop.emplace_back(zp);
33
34   node->quantparam(std::move(qparam));
35 }
36
37 /**
38  * @brief Base Test Graph
39  */
40 struct TestGraph
41 {
42 public:
43   virtual void init(void) = 0;
44 };
45
46 /**
47  *  Graph with two concats
48  *
49  *  [CircleInput]  [CircleConst]
50  *         \         /
51  *  [CircleConcatenation]  [CircleConst]
52  *           |                |
53  *          [CircleConcatenation]
54  *                  |
55  *            [CircleOutput]
56  *
57  *  BEFORE
58  *  - Concat1 and Concat 2 have different qparams
59  *
60  *  AFTER
61  *  - All Ops have the same qparam
62  */
63 struct SubsequentConcatGraph : public TestGraph
64 {
65 public:
66   void init(void) final
67   {
68     // graph input and output
69     auto graph_input = g.inputs()->create();
70     auto graph_output = g.outputs()->create();
71
72     // input
73     input = g.nodes()->create<luci::CircleInput>();
74     input->index(graph_input->index());
75     input->shape({1, 4, 4, 3});
76     input->dtype(loco::DataType::U8);
77     set_qparam(input, 1.0, 1);
78
79     // const1
80     const1 = g.nodes()->create<luci::CircleConst>();
81     const1->shape({1, 4, 4, 3});
82     const1->dtype(loco::DataType::FLOAT32);
83     const1->size<loco::DataType::FLOAT32>(48);
84     for (uint32_t i = 0; i < 48; i++)
85       const1->at<loco::DataType::FLOAT32>(i) = i;
86
87     // concat1
88     concat1 = g.nodes()->create<luci::CircleConcatenation>(2);
89     concat1->shape({1, 4, 4, 6});
90     concat1->dtype(loco::DataType::U8);
91     set_qparam(concat1, 2.0, 2);
92     concat1->values(0, input);
93     concat1->values(1, const1);
94     concat1->fusedActivationFunction(luci::FusedActFunc::NONE);
95
96     // const2
97     const2 = g.nodes()->create<luci::CircleConst>();
98     const2->shape({1, 4, 4, 3});
99     const2->dtype(loco::DataType::FLOAT32);
100     const2->size<loco::DataType::FLOAT32>(48);
101     for (uint32_t i = 0; i < 48; i++)
102       const2->at<loco::DataType::FLOAT32>(i) = i;
103
104     // concat2
105     concat2 = g.nodes()->create<luci::CircleConcatenation>(2);
106     concat2->shape({1, 4, 4, 9});
107     concat2->dtype(loco::DataType::U8);
108     set_qparam(concat2, 3.0, 3);
109     concat2->values(0, concat1);
110     concat2->values(1, const2);
111     concat2->fusedActivationFunction(luci::FusedActFunc::NONE);
112
113     // output
114     output = g.nodes()->create<luci::CircleOutput>();
115     output->index(graph_output->index());
116     output->from(concat2);
117     output->shape({1, 4, 4, 9});
118     output->dtype(loco::DataType::U8);
119     set_qparam(output, 3.0, 3);
120   }
121
122 public:
123   loco::Graph g;
124   CircleInput *input = nullptr;
125   CircleConcatenation *concat1 = nullptr;
126   CircleConcatenation *concat2 = nullptr;
127   CircleConst *const1 = nullptr;
128   CircleConst *const2 = nullptr;
129   CircleOutput *output = nullptr;
130 };
131
132 /**
133  *  BEFORE
134  *
135  *        [Input]
136  *           |
137  *        [Conv] (qparam 1)
138  *           |
139  *       [Reshape] (qparam 2)
140  *           |
141  *       [Output]
142  *
143  *  AFTER
144  *
145  *        [Input]
146  *           |
147  *        [Conv] (qparam 2)
148  *           |
149  *       [Reshape] (qparam 2)
150  *           |
151  *       [Output]
152  */
153 class ConvReshapeGraph
154 {
155 public:
156   ConvReshapeGraph()
157   {
158     input = g.nodes()->create<luci::CircleInput>();
159     conv = g.nodes()->create<luci::CircleConv2D>();
160     reshape = g.nodes()->create<luci::CircleReshape>();
161     output = g.nodes()->create<luci::CircleOutput>();
162
163     auto graph_input = g.inputs()->create();
164     input->index(graph_input->index());
165     auto graph_output = g.outputs()->create();
166     output->index(graph_output->index());
167
168     set_qparam(conv, 2.0, 2);
169     set_qparam(reshape, 1.0, 1);
170
171     conv->input(input);
172     reshape->tensor(conv);
173     output->from(reshape);
174   }
175
176 public:
177   loco::Graph g;
178   luci::CircleInput *input = nullptr;
179   luci::CircleConv2D *conv = nullptr;
180   luci::CircleReshape *reshape = nullptr;
181   luci::CircleOutput *output = nullptr;
182 };
183
184 /**
185  *  BEFORE
186  *
187  *        [Input]
188  *           |
189  *        [Conv] (qparam 1)
190  *           |
191  *           +---------------------+
192  *           |                     |
193  *       [Reshape] (qparam 2)   [Output]
194  *           |
195  *       [Output]
196  *
197  *  AFTER (qparam is not propagated as Conv has multiple users)
198  *
199  *        [Input]
200  *           |
201  *        [Conv] (qparam 1)
202  *           |
203  *           +---------------------+
204  *           |                     |
205  *       [Reshape] (qparam 2)   [Output]
206  *           |
207  *       [Output]
208  */
209 class ConvReshapeMultiOutGraph
210 {
211 public:
212   ConvReshapeMultiOutGraph()
213   {
214     input = g.nodes()->create<luci::CircleInput>();
215     conv = g.nodes()->create<luci::CircleConv2D>();
216     reshape = g.nodes()->create<luci::CircleReshape>();
217     output1 = g.nodes()->create<luci::CircleOutput>();
218     output2 = g.nodes()->create<luci::CircleOutput>();
219
220     auto graph_input = g.inputs()->create();
221     input->index(graph_input->index());
222     auto graph_output1 = g.outputs()->create();
223     output1->index(graph_output1->index());
224     auto graph_output2 = g.outputs()->create();
225     output2->index(graph_output2->index());
226
227     set_qparam(conv, 2.0, 2);
228     set_qparam(reshape, 1.0, 1);
229
230     conv->input(input);
231     reshape->tensor(conv);
232     output1->from(reshape);
233     output2->from(conv);
234   }
235
236 public:
237   loco::Graph g;
238   luci::CircleInput *input = nullptr;
239   luci::CircleConv2D *conv = nullptr;
240   luci::CircleReshape *reshape = nullptr;
241   luci::CircleOutput *output1 = nullptr;
242   luci::CircleOutput *output2 = nullptr;
243 };
244
245 } // namespace
246
247 TEST(PropagateQParamBackwardPassTest, name)
248 {
249   luci::PropagateQParamBackwardPass pass(loco::DataType::U8);
250   auto const name = pass.name();
251   ASSERT_NE(nullptr, name);
252 }
253
254 TEST(PropagateQParamBackwardPassTest, subsequent_propagation)
255 {
256   SubsequentConcatGraph graph;
257
258   graph.init();
259
260   luci::PropagateQParamBackwardPass pass(loco::DataType::U8);
261
262   pass.run(&graph.g);
263
264   EXPECT_EQ(3.0, graph.concat2->quantparam()->scale[0]);
265   EXPECT_EQ(3, graph.concat2->quantparam()->zerop[0]);
266
267   auto const2 = loco::must_cast<CircleNode *>(graph.concat2->values(1));
268   EXPECT_EQ(3.0, const2->quantparam()->scale[0]);
269   EXPECT_EQ(3, const2->quantparam()->zerop[0]);
270
271   EXPECT_EQ(3.0, graph.concat1->quantparam()->scale[0]);
272   EXPECT_EQ(3, graph.concat1->quantparam()->zerop[0]);
273
274   auto const1 = loco::must_cast<CircleNode *>(graph.concat1->values(1));
275   EXPECT_EQ(3.0, const1->quantparam()->scale[0]);
276   EXPECT_EQ(3, const1->quantparam()->zerop[0]);
277
278   EXPECT_EQ(3.0, graph.input->quantparam()->scale[0]);
279   EXPECT_EQ(3, graph.input->quantparam()->zerop[0]);
280 }
281
282 TEST(PropagateQParamBackwardPassTest, reshape)
283 {
284   ConvReshapeGraph graph;
285
286   EXPECT_NE(graph.conv->quantparam()->scale, graph.reshape->quantparam()->scale);
287   EXPECT_NE(graph.conv->quantparam()->zerop, graph.reshape->quantparam()->zerop);
288
289   luci::PropagateQParamBackwardPass pass(loco::DataType::U8);
290
291   pass.run(&graph.g);
292
293   EXPECT_EQ(graph.conv->quantparam()->scale, graph.reshape->quantparam()->scale);
294   EXPECT_EQ(graph.conv->quantparam()->zerop, graph.reshape->quantparam()->zerop);
295 }
296
297 TEST(PropagateQParamBackwardPassTest, reshape_multi_use_NEG)
298 {
299   ConvReshapeMultiOutGraph graph;
300
301   EXPECT_NE(graph.conv->quantparam()->scale, graph.reshape->quantparam()->scale);
302   EXPECT_NE(graph.conv->quantparam()->zerop, graph.reshape->quantparam()->zerop);
303
304   luci::PropagateQParamBackwardPass pass(loco::DataType::U8);
305
306   pass.run(&graph.g);
307
308   EXPECT_NE(graph.conv->quantparam()->scale, graph.reshape->quantparam()->scale);
309   EXPECT_NE(graph.conv->quantparam()->zerop, graph.reshape->quantparam()->zerop);
310 }