/*
 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "luci/Pass/PropagateQParamBackwardPass.h"

#include <luci/IR/CircleNodes.h>

#include <gtest/gtest.h>

#include <memory>
28 void set_qparam(luci::CircleNode *node, float scale, int64_t zp)
30 auto qparam = std::make_unique<luci::CircleQuantParam>();
31 qparam->scale.emplace_back(scale);
32 qparam->zerop.emplace_back(zp);
34 node->quantparam(std::move(qparam));
namespace
{

/**
 * @brief Base Test Graph
 *
 * Interface for test graphs whose nodes are built by init() instead of the constructor.
 */
struct TestGraph
{
public:
  // Polymorphic base: derived graphs must be destroyed correctly through a TestGraph pointer
  virtual ~TestGraph() = default;

  /// @brief Create and connect the graph's nodes
  virtual void init(void) = 0;
};

} // namespace
47 * Graph with two concats
49 * [CircleInput] [CircleConst]
51 * [CircleConcatenation] [CircleConst]
53 * [CircleConcatenation]
58 * - Concat1 and Concat 2 have different qparams
61 * - All Ops have the same qparam
63 struct SubsequentConcatGraph : public TestGraph
68 // graph input and output
69 auto graph_input = g.inputs()->create();
70 auto graph_output = g.outputs()->create();
73 input = g.nodes()->create<luci::CircleInput>();
74 input->index(graph_input->index());
75 input->shape({1, 4, 4, 3});
76 input->dtype(loco::DataType::U8);
77 set_qparam(input, 1.0, 1);
80 const1 = g.nodes()->create<luci::CircleConst>();
81 const1->shape({1, 4, 4, 3});
82 const1->dtype(loco::DataType::FLOAT32);
83 const1->size<loco::DataType::FLOAT32>(48);
84 for (uint32_t i = 0; i < 48; i++)
85 const1->at<loco::DataType::FLOAT32>(i) = i;
88 concat1 = g.nodes()->create<luci::CircleConcatenation>(2);
89 concat1->shape({1, 4, 4, 6});
90 concat1->dtype(loco::DataType::U8);
91 set_qparam(concat1, 2.0, 2);
92 concat1->values(0, input);
93 concat1->values(1, const1);
94 concat1->fusedActivationFunction(luci::FusedActFunc::NONE);
97 const2 = g.nodes()->create<luci::CircleConst>();
98 const2->shape({1, 4, 4, 3});
99 const2->dtype(loco::DataType::FLOAT32);
100 const2->size<loco::DataType::FLOAT32>(48);
101 for (uint32_t i = 0; i < 48; i++)
102 const2->at<loco::DataType::FLOAT32>(i) = i;
105 concat2 = g.nodes()->create<luci::CircleConcatenation>(2);
106 concat2->shape({1, 4, 4, 9});
107 concat2->dtype(loco::DataType::U8);
108 set_qparam(concat2, 3.0, 3);
109 concat2->values(0, concat1);
110 concat2->values(1, const2);
111 concat2->fusedActivationFunction(luci::FusedActFunc::NONE);
114 output = g.nodes()->create<luci::CircleOutput>();
115 output->index(graph_output->index());
116 output->from(concat2);
117 output->shape({1, 4, 4, 9});
118 output->dtype(loco::DataType::U8);
119 set_qparam(output, 3.0, 3);
124 CircleInput *input = nullptr;
125 CircleConcatenation *concat1 = nullptr;
126 CircleConcatenation *concat2 = nullptr;
127 CircleConst *const1 = nullptr;
128 CircleConst *const2 = nullptr;
129 CircleOutput *output = nullptr;
139 * [Reshape] (qparam 2)
149 * [Reshape] (qparam 2)
153 class ConvReshapeGraph
158 input = g.nodes()->create<luci::CircleInput>();
159 conv = g.nodes()->create<luci::CircleConv2D>();
160 reshape = g.nodes()->create<luci::CircleReshape>();
161 output = g.nodes()->create<luci::CircleOutput>();
163 auto graph_input = g.inputs()->create();
164 input->index(graph_input->index());
165 auto graph_output = g.outputs()->create();
166 output->index(graph_output->index());
168 set_qparam(conv, 2.0, 2);
169 set_qparam(reshape, 1.0, 1);
172 reshape->tensor(conv);
173 output->from(reshape);
178 luci::CircleInput *input = nullptr;
179 luci::CircleConv2D *conv = nullptr;
180 luci::CircleReshape *reshape = nullptr;
181 luci::CircleOutput *output = nullptr;
191 * +---------------------+
193 * [Reshape] (qparam 2) [Output]
197 * AFTER (qparam is not propagated as Conv has multiple users)
203 * +---------------------+
205 * [Reshape] (qparam 2) [Output]
209 class ConvReshapeMultiOutGraph
212 ConvReshapeMultiOutGraph()
214 input = g.nodes()->create<luci::CircleInput>();
215 conv = g.nodes()->create<luci::CircleConv2D>();
216 reshape = g.nodes()->create<luci::CircleReshape>();
217 output1 = g.nodes()->create<luci::CircleOutput>();
218 output2 = g.nodes()->create<luci::CircleOutput>();
220 auto graph_input = g.inputs()->create();
221 input->index(graph_input->index());
222 auto graph_output1 = g.outputs()->create();
223 output1->index(graph_output1->index());
224 auto graph_output2 = g.outputs()->create();
225 output2->index(graph_output2->index());
227 set_qparam(conv, 2.0, 2);
228 set_qparam(reshape, 1.0, 1);
231 reshape->tensor(conv);
232 output1->from(reshape);
238 luci::CircleInput *input = nullptr;
239 luci::CircleConv2D *conv = nullptr;
240 luci::CircleReshape *reshape = nullptr;
241 luci::CircleOutput *output1 = nullptr;
242 luci::CircleOutput *output2 = nullptr;
247 TEST(PropagateQParamBackwardPassTest, name)
249 luci::PropagateQParamBackwardPass pass(loco::DataType::U8);
250 auto const name = pass.name();
251 ASSERT_NE(nullptr, name);
254 TEST(PropagateQParamBackwardPassTest, subsequent_propagation)
256 SubsequentConcatGraph graph;
260 luci::PropagateQParamBackwardPass pass(loco::DataType::U8);
264 EXPECT_EQ(3.0, graph.concat2->quantparam()->scale[0]);
265 EXPECT_EQ(3, graph.concat2->quantparam()->zerop[0]);
267 auto const2 = loco::must_cast<CircleNode *>(graph.concat2->values(1));
268 EXPECT_EQ(3.0, const2->quantparam()->scale[0]);
269 EXPECT_EQ(3, const2->quantparam()->zerop[0]);
271 EXPECT_EQ(3.0, graph.concat1->quantparam()->scale[0]);
272 EXPECT_EQ(3, graph.concat1->quantparam()->zerop[0]);
274 auto const1 = loco::must_cast<CircleNode *>(graph.concat1->values(1));
275 EXPECT_EQ(3.0, const1->quantparam()->scale[0]);
276 EXPECT_EQ(3, const1->quantparam()->zerop[0]);
278 EXPECT_EQ(3.0, graph.input->quantparam()->scale[0]);
279 EXPECT_EQ(3, graph.input->quantparam()->zerop[0]);
282 TEST(PropagateQParamBackwardPassTest, reshape)
284 ConvReshapeGraph graph;
286 EXPECT_NE(graph.conv->quantparam()->scale, graph.reshape->quantparam()->scale);
287 EXPECT_NE(graph.conv->quantparam()->zerop, graph.reshape->quantparam()->zerop);
289 luci::PropagateQParamBackwardPass pass(loco::DataType::U8);
293 EXPECT_EQ(graph.conv->quantparam()->scale, graph.reshape->quantparam()->scale);
294 EXPECT_EQ(graph.conv->quantparam()->zerop, graph.reshape->quantparam()->zerop);
297 TEST(PropagateQParamBackwardPassTest, reshape_multi_use_NEG)
299 ConvReshapeMultiOutGraph graph;
301 EXPECT_NE(graph.conv->quantparam()->scale, graph.reshape->quantparam()->scale);
302 EXPECT_NE(graph.conv->quantparam()->zerop, graph.reshape->quantparam()->zerop);
304 luci::PropagateQParamBackwardPass pass(loco::DataType::U8);
308 EXPECT_NE(graph.conv->quantparam()->scale, graph.reshape->quantparam()->scale);
309 EXPECT_NE(graph.conv->quantparam()->zerop, graph.reshape->quantparam()->zerop);