2 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "PartitionPGroups.h"
18 #include "PartitionIR.h"
19 #include "CircleOpCode.h"
21 #include "luci/Partition.h"
23 #include "luci/LogHelper.h"
25 #include <luci/IR/CircleNodes.h>
26 #include <luci/IR/CircleNodeVisitor.h>
33 class IsVirtualNode final : public luci::CircleNodeVisitor<bool>
36 bool visit(const luci::CircleInput *) final { return true; }
37 bool visit(const luci::CircleOutput *) final { return true; }
38 // For multiple outputs
39 bool visit(const luci::CircleCustomOut *) final { return true; }
40 bool visit(const luci::CircleIfOut *) final { return true; }
41 bool visit(const luci::CircleNonMaxSuppressionV4Out *) final { return true; }
42 bool visit(const luci::CircleNonMaxSuppressionV5Out *) final { return true; }
43 bool visit(const luci::CircleSplitOut *) final { return true; }
44 bool visit(const luci::CircleSplitVOut *) final { return true; }
45 bool visit(const luci::CircleTopKV2Out *) final { return true; }
46 bool visit(const luci::CircleUniqueOut *) final { return true; }
47 bool visit(const luci::CircleUnpackOut *) final { return true; }
48 bool visit(const luci::CircleWhileOut *) final { return true; }
49 // TODO add all virtual nodes
52 bool visit(const luci::CircleNode *) final { return false; }
55 bool check_allocate_partition(const luci::CircleNode *node)
58 if (node->accept(&query))
61 * @note About CircleConst
62 * CirleConst acts like a part of some CircleNode and managing mulitiple
63 * used(referenced) CircleConst is a bit difficult if it's used across
64 * different PGroup. So we treat this different to other types.
65 * https://github.com/Samsung/ONE/issues/6230#issuecomment-809802813
67 if (dynamic_cast<const luci::CircleConst *>(node) != nullptr)
72 class FindGroupToFollow final : public luci::CircleNodeVisitor<const std::string &>
75 FindGroupToFollow(const luci::PartitionTable &partition, luci::PGroups *pgroups)
76 : _partition(partition), _pgroups(pgroups)
82 const std::string &groupof(const luci::CircleNode *input) const
84 auto group = _pgroups->node2group[input];
85 assert(not group.empty());
87 return _partition.default_group;
88 return _pgroups->node2group[input];
92 #define IMPLEMENT(CLASS) \
93 const std::string &visit(const luci::CLASS *node) final \
95 auto input = loco::must_cast<luci::CircleNode *>(node->input()); \
96 return groupof(input); \
99 IMPLEMENT(CircleCustomOut);
100 IMPLEMENT(CircleIfOut);
101 IMPLEMENT(CircleNonMaxSuppressionV4Out);
102 IMPLEMENT(CircleNonMaxSuppressionV5Out);
103 IMPLEMENT(CircleSplitOut);
104 IMPLEMENT(CircleSplitVOut);
105 IMPLEMENT(CircleTopKV2Out);
106 IMPLEMENT(CircleUniqueOut);
107 IMPLEMENT(CircleUnpackOut);
108 IMPLEMENT(CircleWhileOut);
112 // return empty for nothing to do
113 const std::string &visit(const luci::CircleNode *) final { return _empty_str; }
116 const luci::PartitionTable &_partition;
117 luci::PGroups *_pgroups = nullptr;
118 std::string _empty_str;
126 void append(luci::CircleNode *node, luci::PGroups *pgroups, const std::string &group, uint32_t idx)
128 auto pgroup = std::make_unique<luci::PGroup>();
129 pgroup->group = group;
130 pgroup->id = idx + 1;
132 auto pnode = std::make_unique<luci::PNode>();
134 pnode->group = group;
135 pnode->pgroup = pgroup.get();
137 pgroup->pnodes.push_back(std::move(pnode));
139 // Set input of PGroup
140 for (uint32_t in = 0; in < node->arity(); ++in)
142 auto input = loco::must_cast<luci::CircleNode *>(node->arg(in));
143 // this input maybe CircleInput in source graph
144 // --> not confident this is safe
145 pgroup->inputs.push_back(input);
147 // Set output of PGroup: node itself or multiple virtual outputs
148 // TODO support multiple virtual outputs
149 pgroup->outputs.push_back(node);
151 pgroups->node2group[node] = group;
152 pgroups->id2pgroup[pgroup->id] = pgroup.get();
154 pgroups->pgroups.push_back(std::move(pgroup));
162 std::unique_ptr<luci::PGroups> produce_pgroups(const luci::Module *source,
163 const luci::PartitionTable &partition)
165 assert(source != nullptr);
166 // NOTE Only main graph (subgraph index 0) will be partitioned.
167 // Other subgraphs will follow the owner (IF/WHILE/...) group
171 auto pgroups = std::make_unique<luci::PGroups>();
173 pgroups->default_group = partition.default_group;
175 // Create a PGroup per CircleNode: each PGroup will have one CircleNode
176 auto graph = source->graph();
177 auto nodes = graph->nodes();
178 for (uint32_t idx = 0; idx < nodes->size(); ++idx)
180 auto node = loco::must_cast<luci::CircleNode *>(nodes->at(idx));
182 // check if node is normal node that we are interested
183 if (check_allocate_partition(node))
185 auto group = partition.default_group;
187 std::string opcodename; // opcodename or opname
189 switch (partition.comply)
191 case luci::PartitionTable::COMPLY::OPCODE:
193 opcodename = luci::opcode_name(node);
194 assert(!opcodename.empty());
196 auto it = partition.byopcodes.find(opcodename);
197 if (it != partition.byopcodes.end())
201 case luci::PartitionTable::COMPLY::OPNAME:
203 opcodename = node->name();
204 assert(!opcodename.empty());
206 auto it = partition.byopnames.find(opcodename);
207 if (it != partition.byopnames.end())
213 throw std::runtime_error("Unsupported partition.comply");
216 INFO(l) << "Op: " << node->name() << ": " << opcodename << ", " << node << ", " << group
219 append(node, pgroups.get(), group, idx);
221 auto pgroup = std::make_unique<luci::PGroup>();
222 pgroup->group = group;
223 pgroup->id = idx + 1;
225 auto pnode = std::make_unique<luci::PNode>();
227 pnode->group = group;
228 pnode->pgroup = pgroup.get();
230 pgroup->pnodes.push_back(std::move(pnode));
232 // Set input of PGroup
233 for (uint32_t in = 0; in < node->arity(); ++in)
235 auto input = loco::must_cast<luci::CircleNode *>(node->arg(in));
236 // this input maybe CircleInput in source graph
237 // --> not confident this is safe
238 pgroup->inputs.push_back(input);
240 // Set output of PGroup: node itself or multiple virtual outputs
241 // TODO support multiple virtual outputs
242 pgroup->outputs.push_back(node);
244 pgroups->node2group[node] = group;
245 pgroups->id2pgroup[pgroup->id] = pgroup.get();
247 pgroups->pgroups.push_back(std::move(pgroup));
252 INFO(l) << "Skip Op: " << node->name() << std::endl;
253 // record as default group
254 pgroups->node2group[node] = partition.default_group;
258 // handle for virtual nodes like multiple outputs
259 // these nodes should follow group of the input
260 for (uint32_t idx = 0; idx < nodes->size(); ++idx)
262 auto node = loco::must_cast<luci::CircleNode *>(nodes->at(idx));
264 // for virtual nodes like CircleUnpackOut should follow it's input (owner)
265 // or just set to default
266 FindGroupToFollow query(partition, pgroups.get());
267 const auto &group = node->accept(&query);
268 if (not group.empty())
270 append(node, pgroups.get(), group, idx);
274 return std::move(pgroups);