2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Pass/PropagateQParamBackwardPass.h"
18 #include "QuantizationUtils.h"
20 #include <luci/IR/CircleNodes.h>
21 #include <luci/IR/CircleNodeVisitor.h>
22 #include <luci/Service/Nodes/CircleConst.h>
31 // Return true if node is a virtual node
32 bool virtual_op(const luci::CircleOpcode opcode)
36 #define CIRCLE_NODE(OPCODE, CIRCLE_CLASS) \
37 case luci::CircleOpcode::OPCODE: \
39 #define CIRCLE_VNODE(OPCODE, CIRCLE_CLASS) \
40 case luci::CircleOpcode::OPCODE: \
42 #include <luci/IR/CircleNodes.lst>
46 throw std::runtime_error("Unknown opcode detected");
50 void quant_const_values(luci::CircleConst *const_node, float scaling_factor, float zerop,
51 loco::DataType quant_type)
53 uint32_t size = const_node->size<loco::DataType::FLOAT32>();
55 const float scaling_factor_inv = 1.0 / scaling_factor;
56 std::vector<int32_t> quantized_values(size);
57 for (uint32_t i = 0; i < size; ++i)
59 auto data = static_cast<double>(const_node->at<loco::DataType::FLOAT32>(i));
60 double quantized_data = std::round(data * scaling_factor_inv) + zerop;
61 constexpr double int_max = static_cast<double>(std::numeric_limits<int32_t>::max());
62 constexpr double int_min = static_cast<double>(std::numeric_limits<int32_t>::min());
63 quantized_data = std::min(int_max, std::max(int_min, quantized_data));
65 quantized_values[i] = static_cast<int32_t>(quantized_data);
70 case loco::DataType::U8:
71 const_node->dtype(loco::DataType::U8); // change the type of tensor
72 const_node->size<loco::DataType::U8>(size); // resize tensor
73 for (uint32_t i = 0; i < size; ++i)
74 const_node->at<loco::DataType::U8>(i) = std::min(255, std::max(0, quantized_values[i]));
76 case loco::DataType::S16:
78 const_node->dtype(loco::DataType::S16); // change the type of tensor
79 const_node->size<loco::DataType::S16>(size); // resize tensor
80 for (uint32_t i = 0; i < size; ++i)
81 const_node->at<loco::DataType::S16>(i) =
82 std::min(32767, std::max(-32767, quantized_values[i]));
85 throw std::runtime_error("Unsupported data type");
89 void overwrite_quantparam(const luci::CircleNode *source, luci::CircleNode *target)
91 auto source_qparam = source->quantparam();
92 if (source_qparam == nullptr)
93 throw std::runtime_error("source quantparam is not found during overwrite");
95 auto target_qparam = target->quantparam();
96 if (target_qparam == nullptr)
98 auto quantparam = std::make_unique<luci::CircleQuantParam>();
99 target->quantparam(std::move(quantparam));
100 target_qparam = target->quantparam();
102 if (target_qparam == nullptr)
103 throw std::runtime_error("Creating new quant param failed");
105 target_qparam->min = source_qparam->min;
106 target_qparam->max = source_qparam->max;
107 target_qparam->scale = source_qparam->scale;
108 target_qparam->zerop = source_qparam->zerop;
109 target_qparam->quantized_dimension = source_qparam->quantized_dimension;
113 * Tells if pad_v2 quantization should ignore padding value
114 * In that case padding const will be quantized with input parameters, and probably clipped
116 bool ignore_pad_v2_const_quantization(const luci::CirclePadV2 *pad)
118 // This is a workaround to quantize pad generated from MaxPoolWithArgmax operation properly
119 // TODO use metadata hints to detect this case
120 auto const_value_node = dynamic_cast<const luci::CircleConst *>(pad->arg(2));
121 if (!const_value_node)
123 if (const_value_node->dtype() == loco::DataType::FLOAT32)
125 float const_value = const_value_node->at<loco::DataType::FLOAT32>(0);
126 if (const_value == std::numeric_limits<float>::lowest())
136 * [CircleNode] [CircleConst]
145 * [CircleNode] [CircleConst] [CircleConst] <- Dead node
146 * (qparam2) (qparam2) (FP32)
152 * NOTE Quantization parameter of CirclePack (qparam2) is propagated to the inputs.
154 void propagate_pack_quantparam(luci::CirclePack *pack)
156 assert(pack->quantparam() != nullptr);
158 const auto num_inputs = pack->values_count();
160 for (uint32_t i = 0; i < num_inputs; i++)
162 auto node = loco::must_cast<luci::CircleNode *>(pack->arg(i));
164 // Quantize constant values
165 if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
167 luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
168 if (const_node->dtype() != loco::DataType::FLOAT32)
169 throw std::runtime_error("Unsupported data type for constant input of pack Op");
171 const auto pack_qparam = pack->quantparam();
172 if (pack_qparam == nullptr)
173 throw std::runtime_error("quantparam of pack is not found during propagation");
175 assert(pack_qparam->scale.size() == 1);
176 assert(pack_qparam->zerop.size() == 1);
177 const auto scaling_factor = pack_qparam->scale[0];
178 const auto zerop = pack_qparam->zerop[0];
180 auto new_const = luci::clone(const_node);
181 quant_const_values(new_const, scaling_factor, zerop, pack->dtype());
182 pack->values(i, new_const);
183 overwrite_quantparam(pack, new_const);
187 const auto succs = loco::succs(node);
188 if (succs.size() > 1)
191 // Non-const input must have been quantized
192 assert(node->quantparam() != nullptr);
193 overwrite_quantparam(pack, node);
204 * [CircleNode] [CircleConst] [CircleConst] [CircleNode]
205 * (S32) (S32) (FP32) (U8 qparam1)
209 * -------[CircleOneHot]-------
214 * [CircleNode] [CircleConst] [CircleConst] [CircleNode] [CircleConst] <- Dead node
215 * (S32) (S32) (U8 qparam2) (U8 qparam2) (FP32)
219 * -------[CircleOneHot]-------
222 * NOTE Quantization parameter of CircleOneHot (qparam2) is propagated to on_value/off_value.
224 void propagate_one_hot_quantparam(luci::CircleOneHot *one_hot)
226 assert(one_hot->quantparam() != nullptr);
228 // Propagate quantization parameters from output to inputs,
229 // to fit both input and counstant_value in one quant range.
230 auto quant_input = [one_hot](void (luci::CircleOneHot::*arg_setter)(loco::Node *),
231 loco::Node *(luci::CircleOneHot::*arg_getter)() const) {
232 auto node = loco::must_cast<luci::CircleNode *>((one_hot->*arg_getter)());
234 // Quantize constant values
235 if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
237 luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
238 if (is_quantized(const_node))
241 if (const_node->dtype() != loco::DataType::FLOAT32)
242 throw std::runtime_error("Unsupported data type for constant input of OneHot Op");
244 const auto qparam = one_hot->quantparam();
245 if (qparam == nullptr)
246 throw std::runtime_error("quantparam of OneHot is not found during propagation");
248 assert(qparam->scale.size() == 1);
249 const auto scaling_factor = qparam->scale.at(0);
250 const auto zerop = qparam->zerop.at(0);
252 auto new_const = luci::clone(const_node);
253 quant_const_values(new_const, scaling_factor, zerop, one_hot->dtype());
254 overwrite_quantparam(one_hot, new_const);
255 (one_hot->*arg_setter)(new_const);
259 const auto succs = loco::succs(node);
260 if (succs.size() > 1)
263 // Non-const input must have been quantized
264 assert(node->quantparam() != nullptr);
265 overwrite_quantparam(one_hot, node);
269 quant_input(&luci::CircleOneHot::on_value, &luci::CircleOneHot::on_value);
270 quant_input(&luci::CircleOneHot::off_value, &luci::CircleOneHot::off_value);
280 * [CircleNode] [CircleConst]
281 * (U8 qparam1) (FP32)
284 * [CircleConcatenation]
288 * [CircleNode] [CircleConst] [CircleConst] <- Dead node
289 * (U8 qparam2) (U8 qparam2) (FP32)
292 * [CircleConcatenation]
295 void propagate_concat_quantparam(luci::CircleConcatenation *concat)
297 assert(concat->quantparam() != nullptr);
299 const auto num_inputs = concat->numValues();
301 // Quantize const inputs using their values if concat has fused act function
302 if (concat->fusedActivationFunction() != luci::FusedActFunc::NONE)
304 for (uint32_t i = 0; i < num_inputs; i++)
306 auto node = concat->arg(i);
307 auto const_node = dynamic_cast<luci::CircleConst *>(node);
308 if (const_node != nullptr)
310 auto new_const = luci::clone(const_node);
311 quant_const(new_const, concat->dtype());
312 concat->values(i, new_const);
318 for (uint32_t i = 0; i < num_inputs; i++)
320 auto node = loco::must_cast<luci::CircleNode *>(concat->arg(i));
322 // Quantize constant values
323 if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
325 luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
327 const auto concat_qparam = concat->quantparam();
328 assert(concat_qparam->scale.size() == 1);
329 const auto scaling_factor = concat_qparam->scale[0];
330 const auto zerop = concat_qparam->zerop[0];
332 auto new_const = luci::clone(const_node);
333 quant_const_values(new_const, scaling_factor, zerop, concat->dtype());
334 concat->values(i, new_const);
335 overwrite_quantparam(concat, new_const);
339 const auto succs = loco::succs(node);
340 if (succs.size() > 1)
343 // Non-const input must have been quantized
344 assert(node->quantparam() != nullptr);
345 overwrite_quantparam(concat, node);
352 * [CircleNode] [CircleConst] [CircleConst]
353 * (U8 qparam1) (S32) (FP32)
361 * By default qparam is propagated from output to inputs to meet backend requirements.
363 * [CircleNode] [CircleConst] [CircleConst] [CircleConst] <- Dead node
364 * (U8 qparam2) (S32) (U8 qparam2) (FP32)
372 * In case padded value is the lowest float value
373 * Qparam is propagated from input to output and constant.
375 * This is a special case for optimization constructed pad, needed to guarantee that
376 * extremely large negative constant do not stretch output quantization range.
378 * [CircleNode] [CircleConst] [CircleConst] [CircleConst] <- Dead node
379 * (U8 qparam1) (S32) (U8 qparam1) (FP32)
385 void propagate_pad_v2_quantparam(luci::CirclePadV2 *pad_v2)
387 if (ignore_pad_v2_const_quantization(pad_v2))
389 // propagate input quantization paramters from input to output and padding const value
390 auto pad_v2_input = loco::must_cast<luci::CircleNode *>(pad_v2->arg(0));
391 overwrite_quantparam(pad_v2_input, pad_v2);
393 auto const_value_node = loco::must_cast<luci::CircleConst *>(
394 pad_v2->arg(2)); // FIX ignore_pad_v2_const_quantization UNLESS
395 auto new_const = luci::clone(const_value_node);
397 const auto pad_v2_input_qparam = pad_v2_input->quantparam();
398 assert(pad_v2_input_qparam != nullptr);
399 assert(pad_v2_input_qparam->scale.size() == 1);
400 const auto scaling_factor = pad_v2_input_qparam->scale.at(0);
401 const auto zerop = pad_v2_input_qparam->zerop.at(0);
403 quant_const_values(new_const, scaling_factor, zerop, pad_v2->dtype());
404 overwrite_quantparam(pad_v2_input, new_const);
405 pad_v2->constant_values(new_const);
409 // Propagate quantization paramters from output to inputs,
410 // to fit both input and counstant_value in one quant range.
411 auto quant_input = [pad_v2](void (CirclePadV2::*arg_setter)(loco::Node *), uint32_t arg) {
412 auto node = loco::must_cast<luci::CircleNode *>(pad_v2->arg(arg));
414 // Quantize constant values
415 if (node->opcode() == luci::CircleOpcode::CIRCLECONST)
417 luci::CircleConst *const_node = loco::must_cast<luci::CircleConst *>(node);
418 if (is_quantized(const_node))
421 if (const_node->dtype() != loco::DataType::FLOAT32)
422 throw std::runtime_error("Unsupported data type for constant input of PadV2 Op");
424 const auto pad_v2_qparam = pad_v2->quantparam();
425 if (pad_v2_qparam == nullptr)
426 throw std::runtime_error("quantparam of PadV2 is not found during propagation");
428 assert(pad_v2_qparam->scale.size() == 1);
429 const auto scaling_factor = pad_v2_qparam->scale.at(0);
430 const auto zerop = pad_v2_qparam->zerop.at(0);
432 auto new_const = luci::clone(const_node);
433 quant_const_values(new_const, scaling_factor, zerop, pad_v2->dtype());
434 overwrite_quantparam(pad_v2, new_const);
435 (pad_v2->*arg_setter)(new_const);
439 const auto succs = loco::succs(node);
440 if (succs.size() > 1)
443 // Non-const input must have been quantized
444 assert(node->quantparam() != nullptr);
445 overwrite_quantparam(pad_v2, node);
449 quant_input(&CirclePadV2::input, 0);
450 quant_input(&CirclePadV2::constant_values, 2);
458 // Visitor to propagate quantization parameters backwards
459 struct PropagateQParamBackward final : public luci::CircleNodeMutableVisitor<void>
461 void visit(luci::CircleNode *) {}
463 void visit(luci::CircleConcatenation *node) { propagate_concat_quantparam(node); }
465 void visit(luci::CircleOneHot *node) { propagate_one_hot_quantparam(node); }
467 void visit(luci::CirclePack *node) { propagate_pack_quantparam(node); }
469 void visit(luci::CirclePadV2 *node) { propagate_pad_v2_quantparam(node); }
471 // Propagate qparam for non-value changing Ops
472 // (ex: Reshape, Transpose, etc.)
475 void visit(luci::CircleReshape *node)
477 auto input_node = loco::must_cast<luci::CircleNode *>(node->tensor());
479 // Do not propagate qparam if input node has multiple users
480 if (loco::succs(input_node).size() > 1)
483 const auto input_opcode = input_node->opcode();
485 // Do not propagate qparam if input node is virtual Op (except CIRCLEINPUT)
486 // Why? It is not safe to propagate qparam to some virtual nodes. For example,
487 // const node, multi-out nodes. Let's block them for now.
488 // TODO Revisit this condition
489 if (virtual_op(input_opcode) and input_opcode != luci::CircleOpcode::CIRCLEINPUT)
492 overwrite_quantparam(node, input_node);
495 void visit(luci::CircleTranspose *node)
497 auto input_node = loco::must_cast<luci::CircleNode *>(node->a());
499 // Do not propagate qparam if input node has multiple users
500 if (loco::succs(input_node).size() > 1)
503 const auto input_opcode = input_node->opcode();
505 // Do not propagate qparam if input node is virtual Op (except CIRCLEINPUT)
506 // Why? It is not safe to propagate qparam to some virtual nodes. For example,
507 // const node, multi-out nodes. Let's block them for now.
508 // TODO Revisit this condition
509 if (virtual_op(input_opcode) and input_opcode != luci::CircleOpcode::CIRCLEINPUT)
512 overwrite_quantparam(node, input_node);
521 bool PropagateQParamBackwardPass::run(loco::Graph *g)
525 // We use reverse post-order traversal as qparam is propagated backward
526 auto nodes = loco::postorder_traversal(loco::output_nodes(g));
527 std::reverse(nodes.begin(), nodes.end());
528 for (auto node : nodes)
530 auto circle_node = loco::must_cast<luci::CircleNode *>(node);
531 INFO(l) << "PropagateQParamBackwardPass visit node: " << circle_node->name() << std::endl;
533 // We can't propagate non-existent qparam
534 if (circle_node->quantparam() == nullptr)
537 PropagateQParamBackward pqb;
538 circle_node->accept(&pqb);
541 // This pass is only run once, so return false
542 // TODO Refactoring not to return meaningless value