Imported Upstream version 1.19.0
[platform/core/ml/nnfw.git] / compiler / luci / partition / src / PartitionPGroups.cpp
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "PartitionPGroups.h"
18 #include "PartitionIR.h"
19 #include "CircleOpCode.h"
20
21 #include "luci/Partition.h"
22 #include "luci/Log.h"
23 #include "luci/LogHelper.h"
24
25 #include <luci/IR/CircleNodes.h>
26 #include <luci/IR/CircleNodeVisitor.h>
27
28 #include <loco.h>
29
30 namespace
31 {
32
33 class IsVirtualNode final : public luci::CircleNodeVisitor<bool>
34 {
35 public:
36   bool visit(const luci::CircleInput *) final { return true; }
37   bool visit(const luci::CircleOutput *) final { return true; }
38   // For multiple outputs
39   bool visit(const luci::CircleCustomOut *) final { return true; }
40   bool visit(const luci::CircleIfOut *) final { return true; }
41   bool visit(const luci::CircleNonMaxSuppressionV4Out *) final { return true; }
42   bool visit(const luci::CircleNonMaxSuppressionV5Out *) final { return true; }
43   bool visit(const luci::CircleSplitOut *) final { return true; }
44   bool visit(const luci::CircleSplitVOut *) final { return true; }
45   bool visit(const luci::CircleTopKV2Out *) final { return true; }
46   bool visit(const luci::CircleUniqueOut *) final { return true; }
47   bool visit(const luci::CircleUnpackOut *) final { return true; }
48   bool visit(const luci::CircleWhileOut *) final { return true; }
49   // TODO add all virtual nodes
50
51   // default is false
52   bool visit(const luci::CircleNode *) final { return false; }
53 };
54
55 bool check_allocate_partition(const luci::CircleNode *node)
56 {
57   IsVirtualNode query;
58   if (node->accept(&query))
59     return false;
60   /**
61    * @note About CircleConst
62    *       CirleConst acts like a part of some CircleNode and managing mulitiple
63    *       used(referenced) CircleConst is a bit difficult if it's used across
64    *       different PGroup. So we treat this different to other types.
65    *       https://github.com/Samsung/ONE/issues/6230#issuecomment-809802813
66    */
67   if (dynamic_cast<const luci::CircleConst *>(node) != nullptr)
68     return false;
69   return true;
70 }
71
72 class FindGroupToFollow final : public luci::CircleNodeVisitor<const std::string &>
73 {
74 public:
75   FindGroupToFollow(const luci::PartitionTable &partition, luci::PGroups *pgroups)
76     : _partition(partition), _pgroups(pgroups)
77   {
78     // NOTHING TODO
79   }
80
81 private:
82   const std::string &groupof(const luci::CircleNode *input) const
83   {
84     auto group = _pgroups->node2group[input];
85     assert(not group.empty());
86     if (group.empty())
87       return _partition.default_group;
88     return _pgroups->node2group[input];
89   }
90
91 public:
92 #define IMPLEMENT(CLASS)                                             \
93   const std::string &visit(const luci::CLASS *node) final            \
94   {                                                                  \
95     auto input = loco::must_cast<luci::CircleNode *>(node->input()); \
96     return groupof(input);                                           \
97   }
98
99   IMPLEMENT(CircleCustomOut);
100   IMPLEMENT(CircleIfOut);
101   IMPLEMENT(CircleNonMaxSuppressionV4Out);
102   IMPLEMENT(CircleNonMaxSuppressionV5Out);
103   IMPLEMENT(CircleSplitOut);
104   IMPLEMENT(CircleSplitVOut);
105   IMPLEMENT(CircleTopKV2Out);
106   IMPLEMENT(CircleUniqueOut);
107   IMPLEMENT(CircleUnpackOut);
108   IMPLEMENT(CircleWhileOut);
109
110 #undef IMPLEMENT
111
112   // return empty for nothing to do
113   const std::string &visit(const luci::CircleNode *) final { return _empty_str; }
114
115 private:
116   const luci::PartitionTable &_partition;
117   luci::PGroups *_pgroups = nullptr;
118   std::string _empty_str;
119 };
120
121 } // namespace
122
123 namespace
124 {
125
126 void append(luci::CircleNode *node, luci::PGroups *pgroups, const std::string &group, uint32_t idx)
127 {
128   auto pgroup = std::make_unique<luci::PGroup>();
129   pgroup->group = group;
130   pgroup->id = idx + 1;
131
132   auto pnode = std::make_unique<luci::PNode>();
133   pnode->node = node;
134   pnode->group = group;
135   pnode->pgroup = pgroup.get();
136
137   pgroup->pnodes.push_back(std::move(pnode));
138
139   // Set input of PGroup
140   for (uint32_t in = 0; in < node->arity(); ++in)
141   {
142     auto input = loco::must_cast<luci::CircleNode *>(node->arg(in));
143     // this input maybe CircleInput in source graph
144     // --> not confident this is safe
145     pgroup->inputs.push_back(input);
146   }
147   // Set output of PGroup: node itself or multiple virtual outputs
148   // TODO support multiple virtual outputs
149   pgroup->outputs.push_back(node);
150
151   pgroups->node2group[node] = group;
152   pgroups->id2pgroup[pgroup->id] = pgroup.get();
153
154   pgroups->pgroups.push_back(std::move(pgroup));
155 }
156
157 } // namespace
158
159 namespace luci
160 {
161
162 std::unique_ptr<luci::PGroups> produce_pgroups(const luci::Module *source,
163                                                const luci::PartitionTable &partition)
164 {
165   assert(source != nullptr);
166   // NOTE Only main graph (subgraph index 0) will be partitioned.
167   // Other subgraphs will follow the owner (IF/WHILE/...) group
168
169   LOGGER(l);
170
171   auto pgroups = std::make_unique<luci::PGroups>();
172
173   pgroups->default_group = partition.default_group;
174
175   // Create a PGroup per CircleNode: each PGroup will have one CircleNode
176   auto graph = source->graph();
177   auto nodes = graph->nodes();
178   for (uint32_t idx = 0; idx < nodes->size(); ++idx)
179   {
180     auto node = loco::must_cast<luci::CircleNode *>(nodes->at(idx));
181
182     // check if node is normal node that we are interested
183     if (check_allocate_partition(node))
184     {
185       auto group = partition.default_group;
186
187       std::string opcodename; // opcodename or opname
188
189       switch (partition.comply)
190       {
191         case luci::PartitionTable::COMPLY::OPCODE:
192         {
193           opcodename = luci::opcode_name(node);
194           assert(!opcodename.empty());
195
196           auto it = partition.byopcodes.find(opcodename);
197           if (it != partition.byopcodes.end())
198             group = it->second;
199           break;
200         }
201         case luci::PartitionTable::COMPLY::OPNAME:
202         {
203           opcodename = node->name();
204           assert(!opcodename.empty());
205
206           auto it = partition.byopnames.find(opcodename);
207           if (it != partition.byopnames.end())
208             group = it->second;
209           break;
210         }
211
212         default:
213           throw std::runtime_error("Unsupported partition.comply");
214       }
215
216       INFO(l) << "Op: " << node->name() << ": " << opcodename << ", " << node << ", " << group
217               << std::endl;
218
219       append(node, pgroups.get(), group, idx);
220 #if 0
221       auto pgroup = std::make_unique<luci::PGroup>();
222       pgroup->group = group;
223       pgroup->id = idx + 1;
224
225       auto pnode = std::make_unique<luci::PNode>();
226       pnode->node = node;
227       pnode->group = group;
228       pnode->pgroup = pgroup.get();
229
230       pgroup->pnodes.push_back(std::move(pnode));
231
232       // Set input of PGroup
233       for (uint32_t in = 0; in < node->arity(); ++in)
234       {
235         auto input = loco::must_cast<luci::CircleNode *>(node->arg(in));
236         // this input maybe CircleInput in source graph
237         // --> not confident this is safe
238         pgroup->inputs.push_back(input);
239       }
240       // Set output of PGroup: node itself or multiple virtual outputs
241       // TODO support multiple virtual outputs
242       pgroup->outputs.push_back(node);
243
244       pgroups->node2group[node] = group;
245       pgroups->id2pgroup[pgroup->id] = pgroup.get();
246
247       pgroups->pgroups.push_back(std::move(pgroup));
248 #endif
249     }
250     else
251     {
252       INFO(l) << "Skip Op: " << node->name() << std::endl;
253       // record as default group
254       pgroups->node2group[node] = partition.default_group;
255     }
256   }
257
258   // handle for virtual nodes like multiple outputs
259   // these nodes should follow group of the input
260   for (uint32_t idx = 0; idx < nodes->size(); ++idx)
261   {
262     auto node = loco::must_cast<luci::CircleNode *>(nodes->at(idx));
263
264     // for virtual nodes like CircleUnpackOut should follow it's input (owner)
265     // or just set to default
266     FindGroupToFollow query(partition, pgroups.get());
267     const auto &group = node->accept(&query);
268     if (not group.empty())
269     {
270       append(node, pgroups.get(), group, idx);
271     }
272   }
273
274   return std::move(pgroups);
275 }
276
277 } // namespace luci