2 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "PartitionMerge.h"
25 * @brief return true if pgroup_i output is one of the inputs of pgroup
27 bool is_input_of(const luci::PGroup *pgroup_i, const luci::PGroup *pgroup)
29 for (auto *output : pgroup_i->outputs)
31 for (auto *input : pgroup->inputs)
41 * @brief return true if there is only one input or all the inputs have same group
42 * @note pgroups is used to find group of pgroup
44 bool is_input_same(const luci::PGroup *pgroup, const luci::PGroups *pgroups)
46 assert(pgroups != nullptr);
47 assert(pgroup != nullptr);
49 const luci::PGroup *input_pgroup = nullptr;
51 for (auto &input : pgroup->inputs)
53 // We ignore below logic for CircleConst.
54 // CircleConst will be cloned if they are not found in pgroup as an input.
55 // Refer build_graph(), "add CircleConst for inputs"
56 // Reason: CircleConst can be shared as input to multiple nodes
57 // where each node can be placed in different groups. For this case
58 // we need to clone this CircleConst for each graph of the group.
59 if (dynamic_cast<const luci::CircleConst *>(input) != nullptr)
62 auto input_group = pgroups->group_of(input);
63 // NOTE: all the nodes should be registered and return should be valid group.
64 // produce_pgroups() should ensure this, except CircleConst, Input, Outputs.
65 // assert here to find if there is any problem with this.
66 assert(not input_group.empty());
67 if (input_group.empty())
68 input_group = pgroups->default_group;
74 if (group != input_group)
77 // if there are multiple inputs, all the inputs should be in same pgroup
78 // https://github.com/Samsung/ONE/issues/6230#issuecomment-801618150
79 // https://github.com/Samsung/ONE/issues/6230#issuecomment-801680531
80 auto pgroup_input = pgroups->pgroup_of(input);
81 if (pgroup_input != nullptr)
83 if (input_pgroup == nullptr)
84 input_pgroup = pgroup_input;
87 if (input_pgroup != pgroup_input)
96 * @brief merge pgroup into pgroup_i
97 * @note output of pgroup_i should be input of pgroup
99 void merge_into(luci::PGroup *pgroup, luci::PGroup *pgroup_i)
101 for (auto &pnode : pgroup->pnodes)
103 // update pgroup for this pnode
104 pnode->pgroup = pgroup_i;
105 assert(pnode->group == pgroup_i->group);
107 // we don't need to add this in topological order:
108 // all the nodes will be created first then connection will be held
109 pgroup_i->pnodes.push_back(std::move(pnode));
110 // note: pnode is now nullptr as it's moved into pgroup_i->pnodes
113 for (auto &input : pgroup->inputs)
115 // add inputs of pgroup to pgroup_i if not member of pgroup_i
116 bool found_in_pgroup_i = false;
117 for (auto &pnode : pgroup_i->pnodes)
119 if (input == pnode->node)
121 found_in_pgroup_i = true;
125 // skip if this input is already in the inputs
126 auto fit = std::find(pgroup_i->inputs.begin(), pgroup_i->inputs.end(), input);
127 if (fit != pgroup_i->inputs.end())
129 found_in_pgroup_i = true;
131 // note: if we force found_in_pgroup_i to false, for testing there will be
132 // unnecessary inputs
133 if (not found_in_pgroup_i)
135 // node input maybe in another pgroup
136 pgroup_i->inputs.push_back(input);
139 // add outputs of pgroup to pgroup_i outputs if not exist
140 for (auto &output : pgroup->outputs)
142 auto it = std::find(pgroup_i->outputs.begin(), pgroup_i->outputs.end(), output);
143 if (it == pgroup_i->outputs.end())
145 pgroup_i->outputs.push_back(output);
156 * @brief This will merge pgroups with same group values in topological order
158 std::unique_ptr<luci::PGroups> merge_pgroups(const luci::PGroups *s_pgroups)
160 // Make a copy of pgroups to apply merge action
161 // Q) do we really need a copy?
162 auto d_pgroups = s_pgroups->make_copy();
164 // Merge partition graphs
165 // - This is initial implementation that works for limited networks
166 // - if A and B is same group -> if A is input of B -> ... -> merge B into A
167 auto &pgroups = d_pgroups->pgroups;
172 for (auto &pgroup_i : pgroups)
175 for (auto it = pgroups.begin(); it != pgroups.end(); ++it)
179 // skip if same object
180 if (pgroup->id == pgroup_i->id)
182 // skip if different group
183 if (pgroup->group != pgroup_i->group)
185 // skip if not connected
186 if (!is_input_of(pgroup_i.get(), pgroup.get()))
188 // skip if there are multiple inputs but inputs differ in group
189 if (!is_input_same(pgroup.get(), d_pgroups.get()))
191 // TODO add more condition may be needed
193 merge_into(pgroup.get(), pgroup_i.get());
195 auto eit = d_pgroups->id2pgroup.find(pgroup->id);
196 assert(eit != d_pgroups->id2pgroup.end());
197 d_pgroups->id2pgroup.erase(eit);
199 // remove merged pgroup from pgroups
213 return std::move(d_pgroups);