2 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "PartitionPGroups.h"
18 #include "PartitionIR.h"
19 #include "CircleOpCode.h"
21 #include "luci/Partition.h"
23 #include "luci/LogHelper.h"
25 #include <luci/IR/CircleNodes.h>
26 #include <luci/IR/CircleNodeVisitor.h>
33 class IsVirtualNode final : public luci::CircleNodeVisitor<bool>
36 bool visit(const luci::CircleInput *) final { return true; }
37 bool visit(const luci::CircleOutput *) final { return true; }
38 // TODO add all virtual nodes
41 bool visit(const luci::CircleNode *) final { return false; }
44 bool check_allocate_partition(const luci::CircleNode *node)
47 if (node->accept(&query))
50 * @note About CircleConst
51 * CirleConst acts like a part of some CircleNode and managing mulitiple
52 * used(referenced) CircleConst is a bit difficult if it's used across
53 * different PGroup. So we treat this different to other types.
54 * https://github.com/Samsung/ONE/issues/6230#issuecomment-809802813
56 if (dynamic_cast<const luci::CircleConst *>(node) != nullptr)
66 std::unique_ptr<luci::PGroups> produce_pgroups(const luci::Module *source,
67 const luci::PartitionTable &partition)
69 assert(source != nullptr);
70 // NOTE Only main graph (subgraph index 0) will be partitioned.
71 // Other subgraphs will follow the owner (IF/WHILE/...) group
75 auto pgroups = std::make_unique<luci::PGroups>();
77 pgroups->default_group = partition.default_group;
79 // Create a PGroup per CircleNode: each PGroup will have one CircleNode
80 auto graph = source->graph();
81 auto nodes = graph->nodes();
82 for (uint32_t idx = 0; idx < nodes->size(); ++idx)
84 auto node = loco::must_cast<luci::CircleNode *>(nodes->at(idx));
86 // check if node is normal node that we are interested
87 if (check_allocate_partition(node))
89 auto group = partition.default_group;
91 std::string opcodename; // opcodename or opname
93 switch (partition.comply)
95 case luci::PartitionTable::COMPLY::OPCODE:
97 opcodename = luci::opcode_name(node);
98 assert(!opcodename.empty());
100 auto it = partition.byopcodes.find(opcodename);
101 if (it != partition.byopcodes.end())
105 case luci::PartitionTable::COMPLY::OPNAME:
107 opcodename = node->name();
108 assert(!opcodename.empty());
110 auto it = partition.byopnames.find(opcodename);
111 if (it != partition.byopnames.end())
117 throw std::runtime_error("Unsupported partition.comply");
120 INFO(l) << "Op: " << node->name() << ": " << opcodename << ", " << node << ", " << group
123 auto pgroup = std::make_unique<luci::PGroup>();
124 pgroup->group = group;
125 pgroup->id = idx + 1;
127 auto pnode = std::make_unique<luci::PNode>();
129 pnode->group = group;
130 pnode->pgroup = pgroup.get();
132 pgroup->pnodes.push_back(std::move(pnode));
134 // Set input of PGroup
135 for (uint32_t in = 0; in < node->arity(); ++in)
137 auto input = loco::must_cast<luci::CircleNode *>(node->arg(in));
138 // this input maybe CircleInput in source graph
139 // --> not confident this is safe
140 pgroup->inputs.push_back(input);
142 // Set output of PGroup: node itself or multiple virtual outputs
143 // TODO support multiple virtual outputs
144 pgroup->outputs.push_back(node);
146 pgroups->node2group[node] = group;
147 pgroups->id2pgroup[pgroup->id] = pgroup.get();
149 pgroups->pgroups.push_back(std::move(pgroup));
153 INFO(l) << "Skip Op: " << node->name() << std::endl;
154 // record as default group
155 pgroups->node2group[node] = partition.default_group;
159 return std::move(pgroups);