2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "ir/LoweredGraph.h"
21 #include "util/logging.h"
22 #include "pass/ConstantInsertionPass.h"
23 #include "pass/ConstantLoweringPass.h"
24 #include "pass/PermutationOperationPass.h"
25 #include "pass/PermutationInsertionPass.h"
26 #include "pass/PermutationEliminationPass.h"
27 #include "ir/GraphIterator.h"
28 #include "verifier/Verifier.h"
29 #include "backend/Backend.h"
30 #include "backend/IConfig.h"
31 #include "compiler/BackendResolver.h"
32 #include "compiler/ManualScheduler.h"
33 #include "compiler/HEScheduler.h"
// Lowers `graph` according to `options`: creates a backend context per backend, runs a scheduler
// to obtain a backend resolver, groups operations into op_seqs, sets up the constant and
// permutation passes, and finally verifies the graph.
// NOTE(review): every line of this listing has a source line number fused to its front, and the
// gaps in that numbering suggest lines were dropped (braces, the member-initializer list, some
// statements) - confirm against the upstream file before relying on the control flow shown here.
40 LoweredGraph::LoweredGraph(const Graph &graph, const compiler::CompilerOptions &options)
// The flag is forwarded unchanged to every newContext() call below.
43 bool linear_executor = (options.executor == "Linear");
45 // Build backend contexts
46 auto &backend_manager = compiler::BackendManager::get();
48 // Always create Controlflow backend context
49 auto cf_backend = backend_manager.getControlflow();
50 _backend_contexts.emplace(
51 cf_backend, cf_backend->newContext(_graph, _graph.getKernelBuilder(), linear_executor));
53 // Create contexts for other backends
54 for (auto backend_str : options.backend_list)
56 backend_manager.loadBackend(backend_str);
57 auto backend = backend_manager.get(backend_str);
59 // TODO As the default value of backend list contains "cpu", "acl_cl" and "acl_neon", and some
60 // are not available on x64 or some other platforms. So this may be a workaround for x64 and
61 // we should change it back(throw if backend is not loaded) later.
// NOTE(review): the condition guarding this log message is not visible in this listing;
// presumably it is taken only when `backend` could not be obtained - confirm.
64 VERBOSE(LoweredGraph) << "Cannot load backend - " << backend_str;
68 _backend_contexts.emplace(
69 backend, backend->newContext(_graph, _graph.getKernelBuilder(), linear_executor));
// Lowering cannot proceed without at least one backend, so fail loudly here.
71 if (backend_manager.num_backends() == 0)
72 throw std::runtime_error{"No available backends loaded."};
74 // TODO Move "schedule" phase out of here
// The scheduler is chosen by options.he_scheduler; only the HEScheduler path stores indexed
// ranks in _indexed_ranks.
76 std::unique_ptr<compiler::BackendResolver> backend_resolver;
77 if (options.he_scheduler)
79 auto scheduler = compiler::HEScheduler(_backend_contexts, options);
80 backend_resolver = scheduler.schedule(_graph);
81 _indexed_ranks = scheduler.getIndexedRanks();
// NOTE(review): presumably the else-branch of the he_scheduler check - the `else` is not visible.
85 auto scheduler = compiler::ManualScheduler(_backend_contexts, options);
86 backend_resolver = scheduler.schedule(_graph);
90 // operand::LowerInfo holder
91 OperandIndexMap<std::unique_ptr<operand::LowerInfo>> operands_lower_info;
// Start every operand with an empty LowerInfo; makeOpSequences() fills in the permute factors.
93 _graph.operands().iterate([&](const OperandIndex &index, const Operand &) {
94 operands_lower_info[index] = std::make_unique<operand::LowerInfo>();
97 // Make op_seqs while checking whether a node can be merged into a op_seq.
98 makeOpSequences(operands_lower_info, options, *backend_resolver);
// Each op_seq's operation list is reversed in place after construction.
100 _op_seqs.iterate([&](const OpSequenceIndex &, OpSequence &op_seq) {
101 assert(op_seq.operations().size() > 0);
102 std::reverse(std::begin(op_seq.operations()), std::end(op_seq.operations()));
105 _op_seqs.dump("merged and sorted operations without permutation", _graph.operations());
// NOTE(review): only the construction of the pass objects is visible in this listing; the calls
// that execute them are presumably on the dropped lines - confirm.
107 pass::ConstantInsertionPass ci_pass(*this);
110 pass::ConstantLoweringPass cl_pass(*this);
113 // Set LowerInfo for each operand from the operand::LowerInfo holder
114 manipulateLowerInfo(operands_lower_info, options.is_primary_subgraph);
119 // Run Permutation Passes
121 pass::PermutationOperationPass po_pass(*this);
124 pass::PermutationInsertionPass pi_pass(*this);
127 pass::PermutationEliminationPass pe_pass(*this);
130 _op_seqs.dump("merged and sorted operations with permutation", _graph.operations());
133 // Graph verifications
// Both verifier calls sit inside assert(), so they are not evaluated when NDEBUG is defined.
135 assert(verifier::DAGChecker().verify(_graph));
136 assert(verifier::EdgeConsistencyChecker().verify(_graph));
140 const operation::LowerInfo *LoweredGraph::getLowerInfo(const OpSequenceIndex &op_seq_index) const
142 auto itr = _lower_info_map.op_seq.find(op_seq_index);
143 if (itr == _lower_info_map.op_seq.end())
145 return itr->second.get();
148 void LoweredGraph::setLowerInfo(const OpSequenceIndex &op_seq_index,
149 std::unique_ptr<operation::LowerInfo> &&lower_info)
151 _lower_info_map.op_seq.insert(std::make_pair(op_seq_index, std::move(lower_info)));
154 void LoweredGraph::removeLowerInfo(const OpSequenceIndex &op_seq_index)
156 auto &op_seq_lower_info = _lower_info_map.op_seq;
157 assert(op_seq_lower_info.find(op_seq_index) != op_seq_lower_info.end());
158 for (auto it = op_seq_lower_info.begin(); it != op_seq_lower_info.end(); ++it)
160 if (it->first == op_seq_index)
162 op_seq_lower_info.erase(it);
168 const operand::LowerInfo *LoweredGraph::getLowerInfo(const OperandIndex &index) const
170 auto itr = _lower_info_map.operand.find(index);
171 if (itr == _lower_info_map.operand.end())
173 return itr->second.get();
176 operand::LowerInfo *LoweredGraph::getLowerInfo(const OperandIndex &index)
178 auto itr = _lower_info_map.operand.find(index);
179 if (itr == _lower_info_map.operand.end())
181 return itr->second.get();
184 void LoweredGraph::setLowerInfo(const OperandIndex &index,
185 std::unique_ptr<operand::LowerInfo> &&lower_info)
187 _lower_info_map.operand.insert(std::make_pair(index, std::move(lower_info)));
190 void LoweredGraph::removeLowerInfo(const OperandIndex &index)
192 _lower_info_map.operand.erase(index);
195 void LoweredGraph::iterateTopolOpSeqs(
196 const std::function<void(const OpSequenceIndex &, const OpSequence &)> &fn) const
198 // Topological Sorting for OpSequences
199 std::vector<OpSequenceIndex> topol_sorted;
200 PostDfsIterator<true>{}.iterateOpSeqs(
202 [&](const OpSequenceIndex &index, const OpSequence &) { topol_sorted.emplace_back(index); });
203 std::reverse(topol_sorted.begin(), topol_sorted.end());
204 for (const auto op_seq_idx : topol_sorted)
206 const auto &op_seq = _op_seqs.at(op_seq_idx);
207 fn(op_seq_idx, op_seq);
211 void LoweredGraph::iterateTopolOpSeqs(
212 const std::function<void(const OpSequenceIndex &, OpSequence &)> &fn)
214 // Topological Sorting for OpSequences
215 std::vector<OpSequenceIndex> topol_sorted;
216 PostDfsIterator<false>{}.iterateOpSeqs(
217 *this, [&](const OpSequenceIndex &index, OpSequence &) { topol_sorted.emplace_back(index); });
218 std::reverse(topol_sorted.begin(), topol_sorted.end());
219 for (const auto op_seq_idx : topol_sorted)
221 auto &op_seq = _op_seqs.at(op_seq_idx);
222 fn(op_seq_idx, op_seq);
226 OpSequenceIndex LoweredGraph::appendFreshSingleOpSequence(const OperationIndex &node_index,
227 const Operation &node)
229 // Create a fresh op_seq with one operation, and append it to op_seqs
230 // Create a fresh op_seq
231 auto op_seq = std::make_unique<OpSequence>(_graph.layout());
234 op_seq->appendOperation(node_index);
236 // Update input/output
237 op_seq->setOutputs(node.getOutputs());
238 op_seq->setInputs(node.getInputs());
240 return _op_seqs.emplace(std::move(op_seq));
// Walks the operations of _graph with PostDfsConstIterator. For each operation it records the
// (backend, layout) permute factor on the operation's input and output operands, then either
// starts a new op_seq for the operation or appends the operation to the current op_seq.
//   operands_lower_info : per-operand LowerInfo holder; updated in place
//   options             : supplies op_seq_max_node and he_profiling_mode
//   backend_resolver    : maps an operation index to the backend chosen for it
// NOTE(review): every line of this listing has a source line number fused to its front, and
// brace/else lines appear to have been dropped - confirm the branch structure against upstream.
243 void LoweredGraph::makeOpSequences(
244 OperandIndexMap<std::unique_ptr<operand::LowerInfo>> &operands_lower_info,
245 const compiler::CompilerOptions &options, const compiler::BackendResolver &backend_resolver)
247 // if SUBG_MAX_NODE == 0, no limit on nodes of a op_seq
248 const int op_seq_max_node = options.op_seq_max_node;
249 assert(op_seq_max_node >= 0);
251 bool is_profiling = options.he_profiling_mode;
// op_seq / op_seq_index track the op_seq most recently created; nullptr until the first one.
252 OpSequence *op_seq = nullptr;
253 OpSequenceIndex op_seq_index;
255 // NOTE: The below method appends nodes while making one op_seq if needed. If something better
256 // ways, happy to update this code.
257 PostDfsConstIterator{}.iterate(
258 _graph, [&](const OperationIndex &node_index, const Operation &node) {
259 // LowerInfo for in/output operands
260 auto backend = backend_resolver.getBackend(node_index);
262 // Get frontend's layout
263 auto frontend_layout = _graph.layout();
265 // The layout of each backend should be set at another place
266 // TODO Change setting layout of each backend at another place
267 auto backend_layout = backend->config()->supportLayout(node, frontend_layout);
// Inputs are recorded as "use" factors, outputs as "def" factors.
269 for (auto operand : node.getInputs() | ir::Remove::UNDEFINED)
271 auto &&lower_info = operands_lower_info.at(operand);
272 lower_info->addUsePermuteFactor(operand::PermuteFactor{backend, backend_layout});
274 for (auto operand : node.getOutputs())
276 auto &&lower_info = operands_lower_info.at(operand);
277 lower_info->addDefPermuteFactor(operand::PermuteFactor{backend, backend_layout});
// A new op_seq is forced when none exists yet, or when a non-zero size limit has been reached.
280 bool new_op_seq = (op_seq == nullptr ||
281 (op_seq_max_node != 0 &&
282 op_seq->operations().size() >= static_cast<size_t>(op_seq_max_node)));
284 // for profiling each op_seq must contain just one node,
285 // so that we can measure a node separately
// mergeable() is evaluated last, so it is only reached when an op_seq already exists.
286 if (new_op_seq || is_profiling ||
287 !mergeable(op_seq_index, node_index, backend_layout, backend_resolver))
289 auto new_op_seq_index = appendFreshSingleOpSequence(node_index, node);
291 // OpSequence LowerInfo
292 setLowerInfo(new_op_seq_index,
293 std::make_unique<operation::LowerInfo>(backend, backend_layout));
// The freshly created op_seq becomes the current one for the following operations.
295 op_seq_index = new_op_seq_index;
296 op_seq = &(_op_seqs.at(new_op_seq_index));
298 VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " is created for "
299 << "NODE#" << node_index.value() << "(" << node.name() << ")" << std::endl;
// NOTE(review): the statements below are presumably the else-branch of the condition above
// (merge into the current op_seq); the `else` itself is not visible in this listing.
303 op_seq->appendOperation(node_index);
// The merged input list is the node's inputs followed by those existing op_seq inputs that are
// not outputs of the node.
305 auto new_inputs = node.getInputs();
306 // Add inputs except outputs of the previous node
307 for (auto ind : op_seq->getInputs())
309 if (!node.getOutputs().contains(ind))
310 new_inputs.append(ind);
312 op_seq->setInputs(new_inputs);
314 VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " merges "
315 << "NODE#" << node_index.value() << "(" << node.name() << ")" << std::endl;
// Adds permute factors for the graph's own inputs and outputs to the per-operand LowerInfo
// holder, then moves every holder entry into this object's lower-info map via setLowerInfo().
//   operands_lower_info : per-operand LowerInfo holder; entries are moved out at the end
//   is_primary          : whether this is the primary graph
// NOTE(review): every line of this listing has a source line number fused to its front, and the
// statements that branch on `is_primary` are not visible - presumably the first pair of loops and
// the second pair of loops are the two alternatives; confirm against upstream.
320 void LoweredGraph::manipulateLowerInfo(
321 OperandIndexMap<std::unique_ptr<operand::LowerInfo>> &operands_lower_info, bool is_primary)
323 const auto controlflow_backend = compiler::BackendManager::get().getControlflow();
325 // TODO Rather than handling primary graph specially,
326 // let the permute inserted and remove it later
329 // TODO Rather than using NHWC Get frontend layout of this node from IR
// Here graph inputs receive the controlflow/NHWC factor as their def factor, and graph outputs
// receive it as a use factor.
330 auto factor = operand::PermuteFactor{controlflow_backend, Layout::NHWC};
331 for (auto index : _graph.getInputs() | ir::Remove::UNDEFINED)
333 auto &&lower_info = operands_lower_info.at(index);
334 assert(lower_info->def_factors().empty());
335 lower_info->addDefPermuteFactor(factor);
337 for (auto index : _graph.getOutputs())
339 auto &&lower_info = operands_lower_info.at(index);
340 lower_info->addUsePermuteFactor(factor);
345 for (auto index : _graph.getInputs() | ir::Remove::UNDEFINED)
347 auto &&lower_info = operands_lower_info.at(index);
348 if (!(lower_info->def_factors().size() == 0 && lower_info->use_factors().size() == 0))
350 // This graph input already has at least one def or use factor recorded,
352 // i.e. it is not an unused input of the graph.
// NOTE(review): this dereferences use_factors().begin(), which is only valid when use_factors()
// is non-empty; the condition above also admits "def factors only" - confirm that cannot happen.
353 lower_info->addDefPermuteFactor(*lower_info->use_factors().begin());
357 // In case of that an operand is Graph's input and not input or output of any operation
358 lower_info->addDefPermuteFactor(operand::PermuteFactor{
360 Layout::NHWC // TODO Get frontend layout of this node from IR
365 for (auto index : _graph.getOutputs())
367 auto &&lower_info = operands_lower_info.at(index);
368 if (lower_info->def_factors().size() == 0)
370 // In case of that an operand is Graph's output and not input or output of any operation
371 lower_info->addDefPermuteFactor(operand::PermuteFactor{
373 Layout::NHWC // TODO Get frontend layout of this node from IR
378 // Set LowerInfo for each operand from the operand::LowerInfo holder
// std::move leaves the holder's entries moved-from; the holder must not be read afterwards.
379 _graph.operands().iterate([&](const OperandIndex &index, Operand &) {
380 setLowerInfo(index, std::move(operands_lower_info[index]));
// Writes a textual description of each operand's LowerInfo (shape, defining operation, using
// operations, def/use backends with layouts) through VERBOSE(Lower).
// NOTE(review): every line of this listing has a source line number fused to its front, and some
// lines (braces, early return, local string declarations) appear to have been dropped - confirm
// against upstream.
384 void LoweredGraph::dumpLowerInfo()
// NOTE(review): presumably an early exit when logging is disabled; the guarded statement is not
// visible in this listing.
386 if (::onert::util::logging::ctx.enabled() == false)
// Keyed by operand index value; std::map keeps the keys sorted, so the final loop emits the
// descriptions in ascending index order.
389 std::map<uint32_t, std::string> dumps;
391 _graph.operands().iterate([&](const OperandIndex &index, Operand &object) {
392 std::stringstream sstream;
// Details are formatted only for operands that have at least one def or use factor.
393 if (!getLowerInfo(index)->def_factors().empty() || !getLowerInfo(index)->use_factors().empty())
// Helper: renders a factor set as "{ <backend id>(<layout>) ... }".
395 auto factors_to_string = [](const operand::PermuteFactorSet &factors) {
397 for (auto factor : factors)
399 str += factor.backend()->config()->id();
400 str += "(" + to_string(factor.layout()) + ")";
403 return "{ " + str + "}";
// Helper: renders a set of operation indices as "{ <index> ... }".
406 auto operation_index_to_string = [](const OperationIndexSet &operations) {
408 for (auto op : operations)
410 str += std::to_string(op.value());
413 return "{ " + str + "}";
416 const auto lower_info = getLowerInfo(index);
417 const auto &shape = object.shape();
// An operand without a valid defining operation is reported as "N/A".
418 std::string def_ops =
419 object.getDef().valid() ? std::to_string(object.getDef().value()) : "N/A";
420 std::string use_ops = operation_index_to_string(object.getUses());
421 std::string def_layouts = factors_to_string(lower_info->def_factors());
422 std::string use_layouts = factors_to_string(lower_info->use_factors());
423 sstream << "Operand #" << index.value() << " LowerInfo" << std::endl;
424 sstream << " - Shape : { ";
425 for (auto i = 0; i < shape.rank(); ++i)
427 sstream << (shape.dim(i)) << " ";
429 sstream << "}" << std::endl;
430 sstream << " - Def Operations : " << def_ops << std::endl;
431 sstream << " - Use Operations : " << use_ops << std::endl;
432 sstream << " - Lower Info" << std::endl;
433 sstream << " - Def Backends : " << def_layouts << std::endl;
434 sstream << " - Use Backends : " << use_layouts << std::endl;
436 dumps.emplace(index.value(), sstream.str());
// Empty descriptions are skipped so that nothing is logged for them.
439 for (const auto &e : dumps)
441 if (!e.second.empty())
443 VERBOSE(Lower) << e.second;
// Decides whether the operation `node_index` may be merged into the op_seq `op_seq_index`.
//   op_seq_index     : op_seq that would receive the operation
//   node_index       : operation considered for merging
//   layout           : backend layout chosen for the operation
//   backend_resolver : maps an operation index to the backend chosen for it
// The checks run in the order listed in the comment below: backend id and layout, branching,
// then connectivity.
// NOTE(review): every line of this listing has a source line number fused to its front, and none
// of the `return` statements (nor the braces) are visible; the outcome attached to each check
// below is therefore not shown here - confirm against upstream.
448 bool LoweredGraph::mergeable(const OpSequenceIndex &op_seq_index, const OperationIndex &node_index,
449 Layout layout, const compiler::BackendResolver &backend_resolver)
451 // Are they mergeable?
452 // 1. the same backend id and layout?
453 // 2. Is op_seq or node branched?
454 // 3. if 1 is true, the op_seq and a node are connected?
455 const auto &op_seq = _op_seqs.at(op_seq_index);
456 const auto &node = _graph.operations().at(node_index);
458 // The same backend id and layout?
460 const auto op_seq_backend_layout = getLowerInfo(op_seq_index)->layout();
461 const auto &op_seq_backend_id = getLowerInfo(op_seq_index)->backend()->config()->id();
462 const auto &node_backend_id = backend_resolver.getBackend(node_index)->config()->id();
463 VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " { " << op_seq_backend_id << "("
464 << to_string(op_seq_backend_layout) << ") } "
465 << " NODE#" << node_index.value() << " (" << node.name() << ") { "
466 << node_backend_id << "(" << to_string(layout) << ") } " << std::endl;
// A mismatch in either backend id or layout is detected here.
467 if (op_seq_backend_id != node_backend_id || op_seq_backend_layout != layout)
// Collects distinct operation indices; more than one entry is what the size checks below detect.
473 std::unordered_set<OperationIndex> branched_set;
475 // Check for branching up
// Gathers the defining operations of the op_seq's inputs.
476 for (const auto &input : op_seq.getInputs() | Remove::DUPLICATED | ir::Remove::UNDEFINED)
478 const auto &input_obj = _graph.operands().at(input);
479 auto def = input_obj.getDef();
482 branched_set.insert(def);
483 if (branched_set.size() > 1)
// The set is reused for the downward check, so it is emptied first.
489 branched_set.clear();
491 // Check for branching down
// Gathers the operations that use the node's outputs.
492 for (const auto &output : node.getOutputs() | Remove::DUPLICATED)
494 // TODO Fix this workaround for the case of model outputs that are used by another operation
495 // This is needed since the branching is decided by operation, but for model outputs,
496 // there is controlflow backend(use backend) but no actual use operation exists
497 if (_graph.getOutputs().contains(output))
500 const auto &output_obj = _graph.operands().at(output);
501 for (const auto &use : output_obj.getUses())
503 branched_set.insert(use);
504 if (branched_set.size() > 1)
513 // an input of one node is an output of the other node? or vice-versa?
515 const auto &node_inputs = node.getInputs();
516 const auto &node_outputs = node.getOutputs();
518 // op_seq's operations are in order so that we just check the first and the last
// For a single-operation op_seq only that one operation is examined.
519 std::vector<OperationIndex> op_seq_ops{op_seq.operations()[0]};
520 if (op_seq.operations().size() > 1)
521 op_seq_ops.emplace_back(op_seq.operations()[op_seq.operations().size() - 1]);
523 for (const auto &n_index : op_seq_ops)
525 const auto &n = _graph.operations().at(n_index);
527 // node's output == op_seq's input?
528 for (const auto input : n.getInputs() | ir::Remove::UNDEFINED)
530 if (node_outputs.contains(input))
532 VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " 's NODE#" << n_index.value()
533 << "(" << n.name() << ") is connected to NODE#" << node_index.value()
534 << "(" << node.name() << ")" << std::endl;
539 // node's input == op_seq's output?
540 for (const auto output : n.getOutputs())
542 if (node_inputs.contains(output))
544 VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " 's NODE#" << n_index.value()
545 << " (" << n.name() << ") is connected to NODE#" << node_index.value()
// Logged when neither connectivity test above matched for the examined operations.
552 VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " is not connected to NODE#"
553 << node_index.value() << "(" << node.name() << ")" << std::endl;