Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / luci / partition / src / PartitionMerge.cpp
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "PartitionMerge.h"
18
19 #include <algorithm>
20
21 namespace
22 {
23
24 /**
25  * @brief return true if pgroup_i output is one of the inputs of pgroup
26  */
27 bool is_input_of(const luci::PGroup *pgroup_i, const luci::PGroup *pgroup)
28 {
29   for (auto *output : pgroup_i->outputs)
30   {
31     for (auto *input : pgroup->inputs)
32     {
33       if (input == output)
34         return true;
35     }
36   }
37   return false;
38 }
39
40 /**
41  * @brief return true if there is only one input or all the inputs have same group
42  * @note  pgroups is used to find group of pgroup
43  */
44 bool is_input_same(const luci::PGroup *pgroup, const luci::PGroups *pgroups)
45 {
46   assert(pgroups != nullptr);
47   assert(pgroup != nullptr);
48
49   const luci::PGroup *input_pgroup = nullptr;
50   std::string group;
51   for (auto &input : pgroup->inputs)
52   {
53     // We ignore below logic for CircleConst.
54     // CircleConst will be cloned if they are not found in pgroup as an input.
55     // Refer build_graph(), "add CircleConst for inputs"
56     // Reason: CircleConst can be shared as input to multiple nodes
57     //         where each node can be placed in different groups. For this case
58     //         we need to clone this CircleConst for each graph of the group.
59     if (dynamic_cast<const luci::CircleConst *>(input) != nullptr)
60       continue;
61
62     auto input_group = pgroups->group_of(input);
63     // NOTE: all the nodes should be registered and return should be valid group.
64     // produce_pgroups() should ensure this, except CircleConst, Input, Outputs.
65     // assert here to find if there is any problem with this.
66     assert(not input_group.empty());
67     if (input_group.empty())
68       input_group = pgroups->default_group;
69
70     if (group.empty())
71       group = input_group;
72     else
73     {
74       if (group != input_group)
75         return false;
76     }
77     // if there are multiple inputs, all the inputs should be in same pgroup
78     // https://github.com/Samsung/ONE/issues/6230#issuecomment-801618150
79     // https://github.com/Samsung/ONE/issues/6230#issuecomment-801680531
80     auto pgroup_input = pgroups->pgroup_of(input);
81     if (pgroup_input != nullptr)
82     {
83       if (input_pgroup == nullptr)
84         input_pgroup = pgroup_input;
85       else
86       {
87         if (input_pgroup != pgroup_input)
88           return false;
89       }
90     }
91   }
92   return true;
93 }
94
95 /**
96  * @brief merge pgroup into pgroup_i
97  * @note  output of pgroup_i should be input of pgroup
98  */
99 void merge_into(luci::PGroup *pgroup, luci::PGroup *pgroup_i)
100 {
101   for (auto &pnode : pgroup->pnodes)
102   {
103     // update pgroup for this pnode
104     pnode->pgroup = pgroup_i;
105     assert(pnode->group == pgroup_i->group);
106
107     // we don't need to add this in topological order:
108     // all the nodes will be created first then connection will be held
109     pgroup_i->pnodes.push_back(std::move(pnode));
110     // note: pnode is now nullptr as it's moved into pgroup_i->pnodes
111   }
112
113   for (auto &input : pgroup->inputs)
114   {
115     // add inputs of pgroup to pgroup_i if not member of pgroup_i
116     bool found_in_pgroup_i = false;
117     for (auto &pnode : pgroup_i->pnodes)
118     {
119       if (input == pnode->node)
120       {
121         found_in_pgroup_i = true;
122         break;
123       }
124     }
125     // skip if this input is already in the inputs
126     auto fit = std::find(pgroup_i->inputs.begin(), pgroup_i->inputs.end(), input);
127     if (fit != pgroup_i->inputs.end())
128     {
129       found_in_pgroup_i = true;
130     }
131     // note: if we force found_in_pgroup_i to false, for testing there will be
132     // unnecessary inputs
133     if (not found_in_pgroup_i)
134     {
135       // node input maybe in another pgroup
136       pgroup_i->inputs.push_back(input);
137     }
138   }
139   // add outputs of pgroup to pgroup_i outputs if not exist
140   for (auto &output : pgroup->outputs)
141   {
142     auto it = std::find(pgroup_i->outputs.begin(), pgroup_i->outputs.end(), output);
143     if (it == pgroup_i->outputs.end())
144     {
145       pgroup_i->outputs.push_back(output);
146     }
147   }
148 }
149
150 } // namespace
151
152 namespace luci
153 {
154
155 /**
156  * @brief This will merge pgroups with same group values in topological order
157  */
158 std::unique_ptr<luci::PGroups> merge_pgroups(const luci::PGroups *s_pgroups)
159 {
160   // Make a copy of pgroups to apply merge action
161   // Q) do we really need a copy?
162   auto d_pgroups = s_pgroups->make_copy();
163
164   // Merge partition graphs
165   // - This is initial implementation that works for limited networks
166   // - if A and B is same group -> if A is input of B -> ... -> merge B into A
167   auto &pgroups = d_pgroups->pgroups;
168   bool changed;
169   do
170   {
171     changed = false;
172     for (auto &pgroup_i : pgroups)
173     {
174       bool merged = false;
175       for (auto it = pgroups.begin(); it != pgroups.end(); ++it)
176       {
177         auto &pgroup = *it;
178
179         // skip if same object
180         if (pgroup->id == pgroup_i->id)
181           continue;
182         // skip if different group
183         if (pgroup->group != pgroup_i->group)
184           continue;
185         // skip if not connected
186         if (!is_input_of(pgroup_i.get(), pgroup.get()))
187           continue;
188         // skip if there are multiple inputs but inputs differ in group
189         if (!is_input_same(pgroup.get(), d_pgroups.get()))
190           continue;
191         // TODO add more condition may be needed
192
193         merge_into(pgroup.get(), pgroup_i.get());
194
195         auto eit = d_pgroups->id2pgroup.find(pgroup->id);
196         assert(eit != d_pgroups->id2pgroup.end());
197         d_pgroups->id2pgroup.erase(eit);
198
199         // remove merged pgroup from pgroups
200         pgroups.erase(it);
201
202         merged = true;
203         break;
204       }
205       if (merged)
206       {
207         changed = true;
208         break;
209       }
210     }
211   } while (changed);
212
213   return std::move(d_pgroups);
214 }
215
216 } // namespace luci