Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / luci / pass / src / ResolveCustomOpAddPass.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Pass/ResolveCustomOpAddPass.h"
18
19 #include <luci/IR/CircleNodes.h>
20 #include <luci/IR/AttrFusedActFunc.h>
21 #include <luci/Profile/CircleNodeOrigin.h>
22
23 #include <flatbuffers/flexbuffers.h>
24
25 namespace
26 {
27
28 /// @brief Returns the index of BroadcastTo node among cop's inputs.
29 // NOTE This function assumes there is only one BroadcastTo node among its inputs.
30 int32_t get_broadcastTo_index_among_inputs_of(luci::CircleCustom *cop)
31 {
32   for (uint32_t idx = 0; idx < cop->numInputs(); idx++)
33   {
34     auto input = dynamic_cast<const luci::CircleCustomOut *>(cop->inputs(idx));
35     if (input)
36     {
37       auto broadcastTo = loco::must_cast<luci::CircleCustom *>(input->input());
38       if (broadcastTo->custom_code() == "BroadcastTo")
39         return idx;
40     }
41   }
42
43   return -1;
44 }
45
46 /** BEFORE
47  *                                  [CircleConst]
48  *                                        |
49  *        [CircleNode]         [BroadcastTo(CircleCustom)]
50  *              \                         |
51  *               \                [CircleCustomOUt]
52  *                \                   /
53  *               [AddV2(CircleCustom)]
54  *  AFTER
55  *
56  *         [CircleConst]         [CircleNode]
57  *                   \           /
58  *                    \         /
59  *                    [CircleAdd]
60  */
61 bool resolve_with_BroadcastTo(luci::CircleCustom *addv2)
62 {
63   int32_t broadcastTo_idx = get_broadcastTo_index_among_inputs_of(addv2);
64
65   if (broadcastTo_idx == -1)
66     return false;
67
68   auto input = loco::must_cast<const luci::CircleCustomOut *>(addv2->inputs(broadcastTo_idx));
69   auto broadcastTo = loco::must_cast<luci::CircleCustom *>(input->input());
70
71   auto name = addv2->name();
72   assert(name.length() > 0);
73
74   auto add = addv2->graph()->nodes()->create<luci::CircleAdd>();
75   add->fusedActivationFunction(luci::FusedActFunc::NONE);
76   add->x(addv2->inputs(1 - broadcastTo_idx));
77   add->y(broadcastTo->inputs(0));
78   add->name(name + "/Add");
79   luci::add_origin(
80     add, luci::composite_origin({luci::get_origin(broadcastTo), luci::get_origin(addv2)}));
81
82   auto customOut = loco::succs(addv2);
83   assert(customOut.size() == 1);
84   replace(*customOut.begin()).with(add);
85
86   return true;
87 }
88
89 bool resolve_custom_op(luci::CircleCustom *addv2)
90 {
91   const std::string custom_code = addv2->custom_code();
92   const std::vector<uint8_t> custom_options = addv2->custom_options();
93
94   if (custom_code != "AddV2")
95     return false;
96
97   if (addv2->numInputs() != 2)
98     return false;
99
100   // check if inputs are suppport data types
101   for (uint32_t i = 0; i < addv2->numInputs(); i++)
102   {
103     auto input = loco::must_cast<luci::CircleNode *>(addv2->inputs(i));
104     switch (input->dtype())
105     {
106       case loco::DataType::U8:
107       case loco::DataType::S8:
108       case loco::DataType::S16:
109       case loco::DataType::S32:
110       case loco::DataType::FLOAT32:
111         break;
112       default:
113         return false;
114     }
115   }
116
117   if (resolve_with_BroadcastTo(addv2))
118     return true;
119
120   auto name = addv2->name();
121   assert(name.length() > 0);
122
123   auto add = addv2->graph()->nodes()->create<luci::CircleAdd>();
124   add->fusedActivationFunction(luci::FusedActFunc::NONE);
125   add->x(addv2->inputs(0));
126   add->y(addv2->inputs(1));
127   add->name(name + "/Add");
128   luci::add_origin(add, luci::get_origin(addv2));
129
130   auto customOut = loco::succs(addv2);
131   assert(customOut.size() == 1);
132   replace(*customOut.begin()).with(add);
133
134   return true;
135 }
136
137 } // namespace
138
139 namespace luci
140 {
141
142 bool ResolveCustomOpAddPass::run(loco::Graph *g)
143 {
144   bool changed = false;
145
146   for (auto node : loco::active_nodes(loco::output_nodes(g)))
147   {
148     auto cop = dynamic_cast<luci::CircleCustom *>(node);
149     if (not cop)
150       continue;
151
152     if (resolve_custom_op(cop))
153       changed = true;
154   }
155
156   return changed;
157 }
158
159 } // namespace luci