Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / PropagateQuantParamPass.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/PropagateQuantParamPass.h"
18
19 #include <luci/IR/CircleNodes.h>
20 #include <luci/IR/CircleNodeVisitor.h>
21 #include <luci/Log.h>
22
23 #include <iostream>
24
25 namespace
26 {
27
28 bool copy_qparam(luci::CircleQuantParam *src, luci::CircleQuantParam *dst)
29 {
30   assert(src->scale.size() == dst->scale.size());
31   assert(src->zerop.size() == dst->zerop.size());
32
33   // src and dst have the same qparam
34   if (std::equal(src->scale.begin(), src->scale.end(), dst->scale.begin()) &&
35       std::equal(src->zerop.begin(), src->zerop.end(), dst->zerop.begin()) &&
36       src->quantized_dimension == dst->quantized_dimension)
37     return false;
38
39   dst->scale.assign(src->scale.begin(), src->scale.end());
40   dst->zerop.assign(src->zerop.begin(), src->zerop.end());
41   dst->quantized_dimension = src->quantized_dimension;
42   return true;
43 }
44
45 bool copy_qparam(luci::CircleNode *src, luci::CircleNode *dst)
46 {
47   // Skip nodes that do not have quantparams
48   auto src_qparam = src->quantparam();
49   if (not src_qparam)
50     return false;
51
52   auto dst_qparam = dst->quantparam();
53   if (not dst_qparam)
54     return false;
55
56   return copy_qparam(src_qparam, dst_qparam);
57 }
58
59 //  Visitor to propagate quantization parameters
60 struct PropagateQuantParam final : public luci::CircleNodeMutableVisitor<bool>
61 {
62   PropagateQuantParam() = default;
63
64   bool visit(luci::CircleNode *) { return false; }
65
66   bool visit(luci::CircleReshape *node)
67   {
68     auto input = node->tensor();
69     if (loco::succs(input).size() != 1)
70       return false;
71
72     auto input_node = loco::must_cast<luci::CircleNode *>(input);
73     return copy_qparam(input_node, node);
74   }
75
76   bool visit(luci::CircleTranspose *node)
77   {
78     auto input_node = loco::must_cast<luci::CircleNode *>(node->a());
79     return copy_qparam(input_node, node);
80   }
81
82   // TODO : Add more Ops (e.g., layout-changing Ops)
83 };
84
85 } // namespace
86
87 namespace luci
88 {
89
90 bool PropagateQuantParamPass::run(loco::Graph *g)
91 {
92   bool changed = false;
93   LOGGER(l);
94   for (auto node : loco::active_nodes(loco::output_nodes(g)))
95   {
96     auto circle_node = loco::must_cast<luci::CircleNode *>(node);
97     INFO(l) << "PropagateQuantParamPass visit node: " << circle_node->name() << std::endl;
98
99     PropagateQuantParam pqp;
100     if (circle_node->accept(&pqp))
101       changed = true;
102   }
103
104   return changed;
105 }
106
107 } // namespace luci