/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 #include "luci/Pass/PropagateQuantParamPass.h"
19 #include <luci/IR/CircleNodes.h>
20 #include <luci/IR/CircleNodeVisitor.h>
28 bool copy_qparam(luci::CircleQuantParam *src, luci::CircleQuantParam *dst)
30 assert(src->scale.size() == dst->scale.size());
31 assert(src->zerop.size() == dst->zerop.size());
33 // src and dst have the same qparam
34 if (std::equal(src->scale.begin(), src->scale.end(), dst->scale.begin()) &&
35 std::equal(src->zerop.begin(), src->zerop.end(), dst->zerop.begin()) &&
36 src->quantized_dimension == dst->quantized_dimension)
39 dst->scale.assign(src->scale.begin(), src->scale.end());
40 dst->zerop.assign(src->zerop.begin(), src->zerop.end());
41 dst->quantized_dimension = src->quantized_dimension;
45 bool copy_qparam(luci::CircleNode *src, luci::CircleNode *dst)
47 // Skip nodes that do not have quantparams
48 auto src_qparam = src->quantparam();
52 auto dst_qparam = dst->quantparam();
56 return copy_qparam(src_qparam, dst_qparam);
59 // Visitor to propagate quantization parameters
60 struct PropagateQuantParam final : public luci::CircleNodeMutableVisitor<bool>
62 PropagateQuantParam() = default;
64 bool visit(luci::CircleNode *) { return false; }
66 bool visit(luci::CircleReshape *node)
68 auto input = node->tensor();
69 if (loco::succs(input).size() != 1)
72 auto input_node = loco::must_cast<luci::CircleNode *>(input);
73 return copy_qparam(input_node, node);
76 bool visit(luci::CircleTranspose *node)
78 auto input_node = loco::must_cast<luci::CircleNode *>(node->a());
79 return copy_qparam(input_node, node);
82 // TODO : Add more Ops (e.g., layout-changing Ops)
90 bool PropagateQuantParamPass::run(loco::Graph *g)
94 for (auto node : loco::active_nodes(loco::output_nodes(g)))
96 auto circle_node = loco::must_cast<luci::CircleNode *>(node);
97 INFO(l) << "PropagateQuantParamPass visit node: " << circle_node->name() << std::endl;
99 PropagateQuantParam pqp;
100 if (circle_node->accept(&pqp))