2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Pass/ResolveCustomOpAddPass.h"
19 #include <luci/IR/CircleNodes.h>
20 #include <luci/IR/AttrFusedActFunc.h>
21 #include <luci/Profile/CircleNodeOrigin.h>
23 #include <flatbuffers/flexbuffers.h>
28 /// @brief Returns the index of BroadcastTo node among cop's inputs.
/// @param cop  Custom op whose inputs are inspected.
/// @return Index of the matching input. The caller in this file compares the result against -1,
///         so -1 is presumably the "no BroadcastTo input" value - confirm.
29 // NOTE This function assumes there is only one BroadcastTo node among its inputs.
30 int32_t get_broadcastTo_index_among_inputs_of(luci::CircleCustom *cop)
// Visit every input slot of the custom op, in index order.
32 for (uint32_t idx = 0; idx < cop->numInputs(); idx++)
// Only inputs that are CircleCustomOut nodes are of interest; dynamic_cast yields nullptr for
// any other node type.
// NOTE(review): `input` is dereferenced below, so it must be null-checked first - confirm the
// guard is in place.
34 auto input = dynamic_cast<const luci::CircleCustomOut *>(cop->inputs(idx));
// Step from the CircleCustomOut to the node feeding it, which is required to be a CircleCustom
// (must_cast presumably fails hard on any other type - confirm).
37 auto broadcastTo = loco::must_cast<luci::CircleCustom *>(input->input());
// The feeding custom op is recognised purely by its custom code string.
38 if (broadcastTo->custom_code() == "BroadcastTo")
49 * [CircleNode] [BroadcastTo(CircleCustom)]
53 * [AddV2(CircleCustom)]
56 * [CircleConst] [CircleNode]
/// @brief Replaces an AddV2 custom op that has a BroadcastTo custom op on one input with a
///        CircleAdd that reads BroadcastTo's first input directly.
/// @param addv2  Custom op to rewrite. The `1 - broadcastTo_idx` arithmetic below only makes
///               sense when it has exactly two inputs.
/// @return Used as a condition by the caller in this file; presumably true when the rewrite
///         was applied and false when no BroadcastTo input exists - confirm.
61 bool resolve_with_BroadcastTo(luci::CircleCustom *addv2)
// Locate the input slot (if any) that is fed through a BroadcastTo custom op.
63 int32_t broadcastTo_idx = get_broadcastTo_index_among_inputs_of(addv2);
// -1 is not a valid input index; it is tested here as the "nothing to rewrite" case.
65 if (broadcastTo_idx == -1)
// Fetch the CircleCustomOut sitting in that slot and the BroadcastTo custom op behind it.
68 auto input = loco::must_cast<const luci::CircleCustomOut *>(addv2->inputs(broadcastTo_idx));
69 auto broadcastTo = loco::must_cast<luci::CircleCustom *>(input->input());
// The new node's name is derived from addv2's name, which is expected to be non-empty.
// This is only enforced by assert, i.e. not checked when NDEBUG is defined.
71 auto name = addv2->name();
72 assert(name.length() > 0);
// Create the replacement CircleAdd in the same graph that owns addv2.
74 auto add = addv2->graph()->nodes()->create<luci::CircleAdd>();
// No fused activation is attached to the new add.
75 add->fusedActivationFunction(luci::FusedActFunc::NONE);
// x: the other input of addv2 (index 1 when BroadcastTo is at 0, and vice versa).
76 add->x(addv2->inputs(1 - broadcastTo_idx));
// y: BroadcastTo's input 0 is wired in directly, so the BroadcastTo op itself is bypassed.
77 add->y(broadcastTo->inputs(0));
78 add->name(name + "/Add");
// The origin recorded for the new node combines the origins of both nodes it replaces
// (the BroadcastTo op and the AddV2 op).
80 add, luci::composite_origin({luci::get_origin(broadcastTo), luci::get_origin(addv2)}));
// addv2 is expected to have exactly one successor (asserted, debug builds only); that
// successor is replaced with the new add.
// NOTE(review): replace(...).with(...) presumably reroutes every user of the successor to
// `add` - confirm against the loco API.
82 auto customOut = loco::succs(addv2);
83 assert(customOut.size() == 1);
84 replace(*customOut.begin()).with(add);
/// @brief Converts a custom op with code "AddV2" and two inputs into a CircleAdd.
///        The BroadcastTo-aware rewrite is attempted first; otherwise the two inputs of the
///        custom op are connected to the new add unchanged.
/// @param addv2  Candidate custom op; ops with another custom code or input count are
///               filtered by the checks below.
/// @return Used as a condition by the caller in this file; presumably true when the graph
///         was modified - confirm.
89 bool resolve_custom_op(luci::CircleCustom *addv2)
91 const std::string custom_code = addv2->custom_code();
// NOTE(review): custom_options is copied here but is not referenced by any other statement
// visible in this function - confirm whether it is needed.
92 const std::vector<uint8_t> custom_options = addv2->custom_options();
// Only the "AddV2" custom op is handled by this function.
94 if (custom_code != "AddV2")
// The rewrite below reads inputs 0 and 1, so exactly two inputs are required.
97 if (addv2->numInputs() != 2)
100 // check if inputs are of supported data types
101 for (uint32_t i = 0; i < addv2->numInputs(); i++)
103 auto input = loco::must_cast<luci::CircleNode *>(addv2->inputs(i));
104 switch (input->dtype())
// Data types explicitly listed for this check: U8, S8, S16, S32 and FLOAT32.
106 case loco::DataType::U8:
107 case loco::DataType::S8:
108 case loco::DataType::S16:
109 case loco::DataType::S32:
110 case loco::DataType::FLOAT32:
// First try the variant where one input is produced by a BroadcastTo custom op.
117 if (resolve_with_BroadcastTo(addv2))
// Plain case: build a CircleAdd that takes addv2's two inputs as they are.
// The non-empty name requirement is enforced by assert only (not checked under NDEBUG).
120 auto name = addv2->name();
121 assert(name.length() > 0);
123 auto add = addv2->graph()->nodes()->create<luci::CircleAdd>();
// No fused activation is attached to the new add.
124 add->fusedActivationFunction(luci::FusedActFunc::NONE);
125 add->x(addv2->inputs(0));
126 add->y(addv2->inputs(1));
127 add->name(name + "/Add");
// The new node's origin is taken from addv2 alone.
128 luci::add_origin(add, luci::get_origin(addv2));
// addv2 is expected to have exactly one successor (asserted, debug builds only); that
// successor is replaced with the new add.
130 auto customOut = loco::succs(addv2);
131 assert(customOut.size() == 1);
132 replace(*customOut.begin()).with(add);
/// @brief Pass entry point: offers every CircleCustom node found among the graph's active
///        nodes to resolve_custom_op.
/// @param g  Graph to process.
/// @return NOTE(review): presumably the `changed` flag, i.e. whether any node was rewritten -
///         confirm.
142 bool ResolveCustomOpAddPass::run(loco::Graph *g)
// Starts out false; presumably set once a rewrite succeeds - confirm.
144 bool changed = false;
// Iterate over the nodes that loco::active_nodes reports for the graph's output nodes.
146 for (auto node : loco::active_nodes(loco::output_nodes(g)))
// dynamic_cast yields nullptr for nodes that are not CircleCustom.
// NOTE(review): `cop` is passed on below and dereferenced there, so nullptr must be
// filtered out first - confirm the guard is in place.
148 auto cop = dynamic_cast<luci::CircleCustom *>(node);
// Attempt the AddV2 -> CircleAdd rewrite on this custom op.
152 if (resolve_custom_op(cop))