Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / PropagateQuantParamPass.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/PropagateQuantParamPass.h"
18
19 #include <luci/IR/CircleNodes.h>
20 #include <luci/IR/CircleNodeVisitor.h>
21 #include <luci/Log.h>
22
23 #include <iostream>
24
25 namespace
26 {
27
28 bool copy_qparam(luci::CircleQuantParam *src, luci::CircleQuantParam *dst)
29 {
30   assert(src->scale.size() == dst->scale.size());
31   assert(src->zerop.size() == dst->zerop.size());
32
33   // src and dst have the same qparam
34   if (std::equal(src->scale.begin(), src->scale.end(), dst->scale.begin()) &&
35       std::equal(src->zerop.begin(), src->zerop.end(), dst->zerop.begin()) &&
36       src->quantized_dimension == dst->quantized_dimension)
37     return false;
38
39   dst->scale.assign(src->scale.begin(), src->scale.end());
40   dst->zerop.assign(src->zerop.begin(), src->zerop.end());
41   dst->quantized_dimension = src->quantized_dimension;
42   return true;
43 }
44
45 bool copy_qparam(luci::CircleNode *src, luci::CircleNode *dst)
46 {
47   // Skip nodes that do not have quantparams
48   auto src_qparam = src->quantparam();
49   if (not src_qparam)
50     return false;
51
52   auto dst_qparam = dst->quantparam();
53   if (not dst_qparam)
54     return false;
55
56   return copy_qparam(src_qparam, dst_qparam);
57 }
58
59 //  Visitor to propagate quantization parameters
60 struct PropagateQuantParam final : public luci::CircleNodeMutableVisitor<bool>
61 {
62   PropagateQuantParam() = default;
63
64   bool visit(luci::CircleNode *) { return false; }
65
66   bool visit(luci::CircleReshape *node)
67   {
68     auto input = node->tensor();
69     if (loco::succs(input).size() != 1)
70       return false;
71
72     auto input_node = loco::must_cast<luci::CircleNode *>(input);
73     return copy_qparam(node, input_node);
74   }
75
76   // TODO : Add more Ops (e.g., Transpose)
77 };
78
79 } // namespace
80
81 namespace luci
82 {
83
84 bool PropagateQuantParamPass::run(loco::Graph *g)
85 {
86   bool changed = false;
87   LOGGER(l);
88   for (auto node : loco::active_nodes(loco::output_nodes(g)))
89   {
90     auto circle_node = loco::must_cast<luci::CircleNode *>(node);
91     INFO(l) << "PropagateQuantParamPass visit node: " << circle_node->name() << std::endl;
92
93     PropagateQuantParam pqp;
94     changed = circle_node->accept(&pqp);
95     if (changed)
96       break;
97   }
98
99   return changed;
100 }
101
102 } // namespace luci