Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / ForceQuantParamPass.cpp
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *    http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15
16 #include "luci/Pass/ForceQuantParamPass.h"
17 #include "luci/Profile/CircleNodeID.h"
18
19 #include <luci/IR/CircleNodes.h>
20 #include <luci/Log.h>
21
22 namespace luci
23 {
24
25 namespace
26 {
27
28 void set_qparam(luci::CircleNode *node, float scale, int64_t zp)
29 {
30   assert(node); // FIX_CALLER_UNLESS
31
32   auto quantparam = std::make_unique<CircleQuantParam>();
33   quantparam->scale.push_back(scale);
34   quantparam->zerop.push_back(zp);
35
36   node->quantparam(std::move(quantparam));
37 }
38
39 } // namespace
40
41 bool ForceQuantParamPass::run(loco::Graph *g)
42 {
43   LOGGER(l);
44   INFO(l) << "ForceQuantParamPass Start" << std::endl;
45
46   for (auto node : loco::active_nodes(loco::output_nodes(g)))
47   {
48     auto const cnode = loco::must_cast<CircleNode *>(node);
49     auto const name = cnode->name();
50     auto target = std::find(_tensors.begin(), _tensors.end(), name);
51     if (target == _tensors.end())
52       continue;
53
54     auto index = target - _tensors.begin();
55     auto scale = _scales[index];
56     auto zp = _zerops[index];
57     set_qparam(cnode, scale, zp);
58
59     _tensors.erase(_tensors.begin() + index);
60     _scales.erase(_scales.begin() + index);
61     _zerops.erase(_zerops.begin() + index);
62   }
63
64   if (_tensors.size() > 0)
65   {
66     std::string msg;
67     for (auto const &t : _tensors)
68       msg += "Tensor does not exist: " + t + ".\n";
69     msg += "Please check tensor name.\n";
70     throw std::runtime_error(msg);
71   }
72
73   INFO(l) << "ForceQuantParamPass End" << std::endl;
74   return false; // one time run
75 }
76
77 } // namespace luci