Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / luci / partition / src / PartitionPGroups.cpp
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "PartitionPGroups.h"
18 #include "PartitionIR.h"
19 #include "CircleOpCode.h"
20
21 #include "luci/Partition.h"
22 #include "luci/Log.h"
23 #include "luci/LogHelper.h"
24
25 #include <luci/IR/CircleNodes.h>
26 #include <luci/IR/CircleNodeVisitor.h>
27
28 #include <loco.h>
29
30 namespace
31 {
32
33 class IsVirtualNode final : public luci::CircleNodeVisitor<bool>
34 {
35 public:
36   bool visit(const luci::CircleInput *) final { return true; }
37   bool visit(const luci::CircleOutput *) final { return true; }
38   // TODO add all virtual nodes
39
40   // default is false
41   bool visit(const luci::CircleNode *) final { return false; }
42 };
43
44 bool check_allocate_partition(const luci::CircleNode *node)
45 {
46   IsVirtualNode query;
47   if (node->accept(&query))
48     return false;
49   /**
50    * @note About CircleConst
51    *       CirleConst acts like a part of some CircleNode and managing mulitiple
52    *       used(referenced) CircleConst is a bit difficult if it's used across
53    *       different PGroup. So we treat this different to other types.
54    *       https://github.com/Samsung/ONE/issues/6230#issuecomment-809802813
55    */
56   if (dynamic_cast<const luci::CircleConst *>(node) != nullptr)
57     return false;
58   return true;
59 }
60
61 } // namespace
62
63 namespace luci
64 {
65
66 std::unique_ptr<luci::PGroups> produce_pgroups(const luci::Module *source,
67                                                const luci::PartitionTable &partition)
68 {
69   assert(source != nullptr);
70   // NOTE Only main graph (subgraph index 0) will be partitioned.
71   // Other subgraphs will follow the owner (IF/WHILE/...) group
72
73   LOGGER(l);
74
75   auto pgroups = std::make_unique<luci::PGroups>();
76
77   pgroups->default_group = partition.default_group;
78
79   // Create a PGroup per CircleNode: each PGroup will have one CircleNode
80   auto graph = source->graph();
81   auto nodes = graph->nodes();
82   for (uint32_t idx = 0; idx < nodes->size(); ++idx)
83   {
84     auto node = loco::must_cast<luci::CircleNode *>(nodes->at(idx));
85
86     // check if node is normal node that we are interested
87     if (check_allocate_partition(node))
88     {
89       auto group = partition.default_group;
90
91       std::string opcodename; // opcodename or opname
92
93       switch (partition.comply)
94       {
95         case luci::PartitionTable::COMPLY::OPCODE:
96         {
97           opcodename = luci::opcode_name(node);
98           assert(!opcodename.empty());
99
100           auto it = partition.byopcodes.find(opcodename);
101           if (it != partition.byopcodes.end())
102             group = it->second;
103           break;
104         }
105         case luci::PartitionTable::COMPLY::OPNAME:
106         {
107           opcodename = node->name();
108           assert(!opcodename.empty());
109
110           auto it = partition.byopnames.find(opcodename);
111           if (it != partition.byopnames.end())
112             group = it->second;
113           break;
114         }
115
116         default:
117           throw std::runtime_error("Unsupported partition.comply");
118       }
119
120       INFO(l) << "Op: " << node->name() << ": " << opcodename << ", " << node << ", " << group
121               << std::endl;
122
123       auto pgroup = std::make_unique<luci::PGroup>();
124       pgroup->group = group;
125       pgroup->id = idx + 1;
126
127       auto pnode = std::make_unique<luci::PNode>();
128       pnode->node = node;
129       pnode->group = group;
130       pnode->pgroup = pgroup.get();
131
132       pgroup->pnodes.push_back(std::move(pnode));
133
134       // Set input of PGroup
135       for (uint32_t in = 0; in < node->arity(); ++in)
136       {
137         auto input = loco::must_cast<luci::CircleNode *>(node->arg(in));
138         // this input maybe CircleInput in source graph
139         // --> not confident this is safe
140         pgroup->inputs.push_back(input);
141       }
142       // Set output of PGroup: node itself or multiple virtual outputs
143       // TODO support multiple virtual outputs
144       pgroup->outputs.push_back(node);
145
146       pgroups->node2group[node] = group;
147       pgroups->id2pgroup[pgroup->id] = pgroup.get();
148
149       pgroups->pgroups.push_back(std::move(pgroup));
150     }
151     else
152     {
153       INFO(l) << "Skip Op: " << node->name() << std::endl;
154       // record as default group
155       pgroups->node2group[node] = partition.default_group;
156     }
157   }
158
159   return std::move(pgroups);
160 }
161
162 } // namespace luci