2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "ir/LoweredGraph.h"
21 #include "util/logging.h"
22 #include "pass/ConstantInsertionPass.h"
23 #include "pass/ConstantLoweringPass.h"
24 #include "pass/PermutationOperationPass.h"
25 #include "pass/PermutationInsertionPass.h"
26 #include "ir/GraphIterator.h"
27 #include "verifier/Verifier.h"
28 #include "backend/Backend.h"
29 #include "backend/IConfig.h"
30 #include "compiler/BackendResolver.h"
31 #include "compiler/ManualScheduler.h"
32 #include "compiler/HEScheduler.h"
// Builds the lowered representation of a graph.
//
// The steps visible in this constructor are, in order:
//   1. create a backend context for the controlflow backend and for each backend named in
//      options.backend_list,
//   2. obtain a BackendResolver from HEScheduler or ManualScheduler,
//   3. allocate one operand::LowerInfo per operand and group operations into op sequences,
//   4. instantiate the constant insertion/lowering passes and the permutation passes,
//   5. verify the graph (inside assert(), so only when assertions are compiled in).
//
// graph   - graph to lower. NOTE(review): the code below works on the `_graph` member, which
//           is presumably initialized from this parameter - confirm.
// options - compiler options; executor, backend_list, he_scheduler and is_primary_subgraph are
//           read here, and the object is also forwarded to the schedulers and makeOpSequences().
//
// Throws std::runtime_error when the backend manager reports zero backends.
39 LoweredGraph::LoweredGraph(const Graph &graph, const compiler::CompilerOptions &options)
// Forwarded to every backend's newContext() call below.
42 bool linear_executor = (options.executor == "Linear");
44 // Build backend contexts
45 auto &backend_manager = compiler::BackendManager::get();
47 // Always create Controlflow backend context
48 auto cf_backend = backend_manager.getControlflow();
49 _backend_contexts.emplace(
50 cf_backend, cf_backend->newContext(_graph, _graph.getKernelBuilder(), linear_executor));
52 // Create contexts for other backends
53 for (auto backend_str : options.backend_list)
// Ask the manager to load the backend, then fetch the handle it holds for that name.
55 backend_manager.loadBackend(backend_str);
56 auto backend = backend_manager.get(backend_str);
58 // TODO As the default value of backend list contains "cpu", "acl_cl" and "acl_neon", and some
59 // are not available on x64 or some other platforms. So this may be a workaround for x64 and
60 // we should change it back(throw if backend is not loaded) later.
63 VERBOSE(LoweredGraph) << "Cannot load backend - " << backend_str;
// Same context creation as for the controlflow backend above.
67 _backend_contexts.emplace(
68 backend, backend->newContext(_graph, _graph.getKernelBuilder(), linear_executor));
// The check uses the manager's own backend count, not the size of _backend_contexts.
70 if (backend_manager.num_backends() == 0)
71 throw std::runtime_error{"No available backends loaded."};
73 // TODO Move "schedule" phase out of here
// Filled in by whichever scheduler is selected through options.he_scheduler.
75 std::unique_ptr<compiler::BackendResolver> backend_resolver;
76 if (options.he_scheduler)
78 auto scheduler = compiler::HEScheduler(_backend_contexts, options);
79 backend_resolver = scheduler.schedule(_graph);
// The HEScheduler additionally provides indexed ranks, which are kept in _indexed_ranks.
80 _indexed_ranks = scheduler.getIndexedRanks();
// ManualScheduler variant of the scheduling step.
84 auto scheduler = compiler::ManualScheduler(_backend_contexts, options);
85 backend_resolver = scheduler.schedule(_graph);
89 // operand::LowerInfo holder
90 OperandIndexMap<std::unique_ptr<operand::LowerInfo>> operands_lower_info;
// Give every operand of the graph an empty LowerInfo before the sequences are built.
92 _graph.operands().iterate([&](const OperandIndex &index, const Operand &) {
93 operands_lower_info[index] = std::make_unique<operand::LowerInfo>();
96 // Make op_seqs while checking whether a node can be merged into a op_seq.
97 makeOpSequences(operands_lower_info, options, *backend_resolver);
// makeOpSequences() appends operations while walking the graph with a post-DFS iterator;
// each sequence's operation list is reversed in place here - presumably to bring it into
// execution order (confirm).
99 _op_seqs.iterate([&](const OpSequenceIndex &, OpSequence &op_seq) {
100 assert(op_seq.operations().size() > 0);
101 std::reverse(std::begin(op_seq.operations()), std::end(op_seq.operations()));
104 _op_seqs.dump("merged and sorted operations without permutation", _graph.operations());
106 pass::ConstantInsertionPass ci_pass(*this);
109 pass::ConstantLoweringPass cl_pass(*this);
112 // Set LowerInfo for each operand from the operand::LowerInfo holder
113 manipulateLowerInfo(operands_lower_info, options.is_primary_subgraph);
118 // Run Permutation Passes
120 pass::PermutationOperationPass po_pass(*this);
123 pass::PermutationInsertionPass pi_pass(*this);
125 // Implemented code no longer works.
126 // pass::PermutationEliminationPass pe_pass(*this);
129 _op_seqs.dump("merged and sorted operations with permutation", _graph.operations());
132 // Graph verifications
// Both checks are wrapped in assert(), so they are not evaluated when NDEBUG is defined.
134 assert(verifier::DAGChecker().verify(_graph));
135 assert(verifier::EdgeConsistencyChecker().verify(_graph));
// Looks up the operation::LowerInfo stored for op sequence `op_seq_index`.
// The returned pointer is obtained with get() from the entry held in
// `_lower_info_map.op_seq`; the map entry keeps holding the object.
139 const operation::LowerInfo *LoweredGraph::getLowerInfo(const OpSequenceIndex &op_seq_index) const
141 auto itr = _lower_info_map.op_seq.find(op_seq_index);
// True when no entry is registered for this index.
142 if (itr == _lower_info_map.op_seq.end())
144 return itr->second.get();
147 void LoweredGraph::setLowerInfo(const OpSequenceIndex &op_seq_index,
148 std::unique_ptr<operation::LowerInfo> &&lower_info)
150 _lower_info_map.op_seq.insert(std::make_pair(op_seq_index, std::move(lower_info)));
153 void LoweredGraph::removeLowerInfo(const OpSequenceIndex &op_seq_index)
155 auto &op_seq_lower_info = _lower_info_map.op_seq;
156 assert(op_seq_lower_info.find(op_seq_index) != op_seq_lower_info.end());
157 for (auto it = op_seq_lower_info.begin(); it != op_seq_lower_info.end(); ++it)
159 if (it->first == op_seq_index)
161 op_seq_lower_info.erase(it);
// Looks up the operand::LowerInfo stored for operand `index` (read-only overload).
// The returned pointer is obtained with get() from the entry held in
// `_lower_info_map.operand`; the map entry keeps holding the object.
167 const operand::LowerInfo *LoweredGraph::getLowerInfo(const OperandIndex &index) const
169 auto itr = _lower_info_map.operand.find(index);
// True when no entry is registered for this operand.
170 if (itr == _lower_info_map.operand.end())
172 return itr->second.get();
// Looks up the operand::LowerInfo stored for operand `index` (mutable overload).
// The returned pointer is obtained with get() from the entry held in
// `_lower_info_map.operand`; the map entry keeps holding the object.
175 operand::LowerInfo *LoweredGraph::getLowerInfo(const OperandIndex &index)
177 auto itr = _lower_info_map.operand.find(index);
// True when no entry is registered for this operand.
178 if (itr == _lower_info_map.operand.end())
180 return itr->second.get();
183 void LoweredGraph::setLowerInfo(const OperandIndex &index,
184 std::unique_ptr<operand::LowerInfo> &&lower_info)
186 _lower_info_map.operand.insert(std::make_pair(index, std::move(lower_info)));
189 void LoweredGraph::removeLowerInfo(const OperandIndex &index)
191 _lower_info_map.operand.erase(index);
// Invokes `fn` once per op sequence in topological order (read-only overload).
//
// fn - callback receiving the op sequence index and a const reference to the sequence
//      fetched from `_op_seqs`.
194 void LoweredGraph::iterateTopolOpSeqs(
195 const std::function<void(const OpSequenceIndex &, const OpSequence &)> &fn) const
197 // Topological Sorting for OpSequences
198 std::vector<OpSequenceIndex> topol_sorted;
// The lambda only records each index in the iterator's visiting order; `fn` is not called yet.
199 PostDfsIterator<true>{}.iterateOpSeqs(
201 [&](const OpSequenceIndex &index, const OpSequence &) { topol_sorted.emplace_back(index); });
// The recorded post-DFS order is reversed before the callbacks are issued.
202 std::reverse(topol_sorted.begin(), topol_sorted.end());
203 for (const auto op_seq_idx : topol_sorted)
205 const auto &op_seq = _op_seqs.at(op_seq_idx);
206 fn(op_seq_idx, op_seq);
210 void LoweredGraph::iterateTopolOpSeqs(
211 const std::function<void(const OpSequenceIndex &, OpSequence &)> &fn)
213 // Topological Sorting for OpSequences
214 std::vector<OpSequenceIndex> topol_sorted;
215 PostDfsIterator<false>{}.iterateOpSeqs(
216 *this, [&](const OpSequenceIndex &index, OpSequence &) { topol_sorted.emplace_back(index); });
217 std::reverse(topol_sorted.begin(), topol_sorted.end());
218 for (const auto op_seq_idx : topol_sorted)
220 auto &op_seq = _op_seqs.at(op_seq_idx);
221 fn(op_seq_idx, op_seq);
// Creates an op sequence that contains a single operation and registers it in `_op_seqs`.
//
// node_index - index of the operation placed in the new sequence.
// node       - the operation itself; its inputs and outputs become those of the sequence.
//
// Returns the value produced by `_op_seqs.emplace()` for the new sequence.
225 OpSequenceIndex LoweredGraph::appendFreshSingleOpSequence(const OperationIndex &node_index,
226 const Operation &node)
228 // Create a fresh op_seq with one operation, and append it to op_seqs
229 // Create a fresh op_seq
// The new sequence is constructed with the graph's layout.
230 auto op_seq = std::make_unique<OpSequence>(_graph.layout());
233 op_seq->appendOperation(node_index);
235 // Update input/output
236 op_seq->setOutputs(node.getOutputs());
237 op_seq->setInputs(node.getInputs());
// Ownership of the sequence moves into `_op_seqs`.
239 return _op_seqs.emplace(std::move(op_seq));
// Groups the graph's operations into op sequences.
//
// Operations are visited with PostDfsConstIterator. For each one the backend chosen by
// `backend_resolver` and the layout that backend supports are recorded as permute factors
// on the operation's operands; the operation then either starts a new op sequence or is
// appended to the current one.
//
// operands_lower_info - per-operand LowerInfo holder; use/def permute factors are added to
//                       the entries of every visited operation's inputs/outputs.
// options             - op_seq_max_node and he_profiling_mode are read here.
// backend_resolver    - maps an operation index to its backend.
242 void LoweredGraph::makeOpSequences(
243 OperandIndexMap<std::unique_ptr<operand::LowerInfo>> &operands_lower_info,
244 const compiler::CompilerOptions &options, const compiler::BackendResolver &backend_resolver)
246 // if op_seq_max_node == 0, no limit on nodes of a op_seq
247 const int op_seq_max_node = options.op_seq_max_node;
248 assert(op_seq_max_node >= 0);
250 bool is_profiling = options.he_profiling_mode;
// Sequence currently being extended (nullptr until the first one is created) and its index.
251 OpSequence *op_seq = nullptr;
252 OpSequenceIndex op_seq_index;
254 // NOTE: The below method appends nodes while making one op_seq if needed. If there is a
255 // better way, feel free to update this code.
256 PostDfsConstIterator{}.iterate(
257 _graph, [&](const OperationIndex &node_index, const Operation &node) {
258 // LowerInfo for in/output operands
259 auto backend = backend_resolver.getBackend(node_index);
261 // Get frontend's layout
262 auto frontend_layout = _graph.layout();
264 // The layout of each backend should be set at another place
265 // TODO Change setting layout of each backend at another place
266 auto backend_layout = backend->config()->supportLayout(node, frontend_layout);
// Inputs (undefined ones filtered out) are marked as used with this backend/layout pair.
268 for (auto operand : node.getInputs() | ir::Remove::UNDEFINED)
270 auto &&lower_info = operands_lower_info.at(operand);
271 lower_info->addUsePermuteFactor(operand::PermuteFactor{backend, backend_layout});
// Outputs are marked as defined with this backend/layout pair.
273 for (auto operand : node.getOutputs())
275 auto &&lower_info = operands_lower_info.at(operand);
276 lower_info->addDefPermuteFactor(operand::PermuteFactor{backend, backend_layout});
// A new sequence is forced when none exists yet, or when a size limit is set
// (op_seq_max_node != 0) and the current sequence has reached it.
279 bool new_op_seq = (op_seq == nullptr ||
280 (op_seq_max_node != 0 &&
281 op_seq->operations().size() >= static_cast<size_t>(op_seq_max_node)));
283 // for profiling each op_seq must contain just one node,
284 // so that we can measure a node separately
// Because of short-circuit evaluation mergeable() only runs when new_op_seq is false,
// i.e. when op_seq is non-null and op_seq_index refers to an existing sequence.
285 if (new_op_seq || is_profiling ||
286 !mergeable(op_seq_index, node_index, backend_layout, backend_resolver))
288 auto new_op_seq_index = appendFreshSingleOpSequence(node_index, node);
290 // OpSequence LowerInfo
291 setLowerInfo(new_op_seq_index,
292 std::make_unique<operation::LowerInfo>(backend, backend_layout));
// The freshly created sequence becomes the one later operations may be merged into.
294 op_seq_index = new_op_seq_index;
295 op_seq = &(_op_seqs.at(new_op_seq_index));
297 VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " is created for "
298 << "NODE#" << node_index.value() << "(" << node.name() << ")" << std::endl;
// Merge case: the operation is appended to the current sequence.
302 op_seq->appendOperation(node_index);
// The sequence's new inputs start from this node's inputs ...
304 auto new_inputs = node.getInputs();
305 // Add inputs except outputs of the previous node
// ... followed by every existing sequence input that this node does not produce.
306 for (auto ind : op_seq->getInputs())
308 if (!node.getOutputs().contains(ind))
309 new_inputs.append(ind);
311 op_seq->setInputs(new_inputs);
313 VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " merges "
314 << "NODE#" << node_index.value() << "(" << node.name() << ")" << std::endl;
// Completes the LowerInfo of graph inputs/outputs and then transfers every operand's
// LowerInfo from the temporary holder into this object through setLowerInfo().
//
// operands_lower_info - holder filled while the op sequences were built; its entries are
//                       moved out at the end of this function.
// is_primary          - NOTE(review): not referenced on any line reproduced here; presumably
//                       it selects between the two input/output treatments below - confirm.
319 void LoweredGraph::manipulateLowerInfo(
320 OperandIndexMap<std::unique_ptr<operand::LowerInfo>> &operands_lower_info, bool is_primary)
322 const auto controlflow_backend = compiler::BackendManager::get().getControlflow();
324 // TODO Rather than handling primary graph specially,
325 // let the permute inserted and remove it later
328 // TODO Rather than using NHWC Get frontend layout of this node from IR
// Treatment 1: graph inputs get a controlflow/NHWC def factor and graph outputs get the
// same factor as a use factor.
329 auto factor = operand::PermuteFactor{controlflow_backend, Layout::NHWC};
330 for (auto index : _graph.getInputs() | ir::Remove::UNDEFINED)
332 auto &&lower_info = operands_lower_info.at(index);
// A graph input is expected to have no def factor yet (checked only when asserts are on).
333 assert(lower_info->def_factors().empty());
334 lower_info->addDefPermuteFactor(factor);
336 for (auto index : _graph.getOutputs())
338 auto &&lower_info = operands_lower_info.at(index);
339 lower_info->addUsePermuteFactor(factor);
// Treatment 2: def factors of graph inputs/outputs are derived from what is already recorded.
344 for (auto index : _graph.getInputs() | ir::Remove::UNDEFINED)
346 auto &&lower_info = operands_lower_info.at(index);
// Taken when at least one def or use factor is recorded for the input.
347 if (!(lower_info->def_factors().size() == 0 && lower_info->use_factors().size() == 0))
349 // In case of not that Graph's input is not used in any operation and not the graph's
351 // In other words, it is not unused input in Graph.
// The first use factor is reused as the def factor.
// NOTE(review): this dereferences use_factors().begin() and therefore relies on
// use_factors() being non-empty whenever the condition above holds - confirm.
352 lower_info->addDefPermuteFactor(*lower_info->use_factors().begin());
356 // In case of that an operand is Graph's input and not input or output of any operation
357 lower_info->addDefPermuteFactor(operand::PermuteFactor{
359 Layout::NHWC // TODO Get frontend layout of this node from IR
364 for (auto index : _graph.getOutputs())
366 auto &&lower_info = operands_lower_info.at(index);
// Only outputs without any def factor receive one here.
367 if (lower_info->def_factors().size() == 0)
369 // In case of that an operand is Graph's output and not input or output of any operation
370 lower_info->addDefPermuteFactor(operand::PermuteFactor{
372 Layout::NHWC // TODO Get frontend layout of this node from IR
377 // Set LowerInfo for each operand from the operand::LowerInfo holder
// Each holder entry is moved, so operands_lower_info no longer owns the objects afterwards.
378 _graph.operands().iterate([&](const OperandIndex &index, Operand &) {
379 setLowerInfo(index, std::move(operands_lower_info[index]));
// Writes a textual summary of operand lowering information through VERBOSE(Lower).
//
// For every operand that has at least one def or use permute factor the summary lists the
// shape, the defining and using operations, and the def/use backends with their layouts.
383 void LoweredGraph::dumpLowerInfo()
// Tests whether the logging context is disabled.
385 if (::onert::util::logging::ctx.enabled() == false)
// Text per operand, keyed by the operand index value; std::map keeps the keys sorted, so
// the final output is ordered by operand index.
388 std::map<uint32_t, std::string> dumps;
390 _graph.operands().iterate([&](const OperandIndex &index, Operand &object) {
391 std::stringstream sstream;
// Operands without any def or use factor produce no text.
392 if (!getLowerInfo(index)->def_factors().empty() || !getLowerInfo(index)->use_factors().empty())
// Formats a factor set as "{ <backend id>(<layout>) ... }".
394 auto factors_to_string = [](const operand::PermuteFactorSet &factors) {
396 for (auto factor : factors)
398 str += factor.backend()->config()->id();
399 str += "(" + to_string(factor.layout()) + ")";
402 return "{ " + str + "}";
// Formats a set of operation indices as "{ <index> ... }".
405 auto operation_index_to_string = [](const OperationIndexSet &operations) {
407 for (auto op : operations)
409 str += std::to_string(op.value());
412 return "{ " + str + "}";
415 const auto lower_info = getLowerInfo(index);
416 const auto &shape = object.shape();
417 std::string def_ops = operation_index_to_string(object.getDef());
418 std::string use_ops = operation_index_to_string(object.getUses());
419 std::string def_layouts = factors_to_string(lower_info->def_factors());
420 std::string use_layouts = factors_to_string(lower_info->use_factors());
421 sstream << "Operand #" << index.value() << " LowerInfo" << std::endl;
422 sstream << " - Shape : { ";
// One entry per dimension, in rank order.
423 for (auto i = 0; i < shape.rank(); ++i)
425 sstream << (shape.dim(i)) << " ";
427 sstream << "}" << std::endl;
428 sstream << " - Def Operations : " << def_ops << std::endl;
429 sstream << " - Use Operations : " << use_ops << std::endl;
430 sstream << " - Lower Info" << std::endl;
431 sstream << " - Def Backends : " << def_layouts << std::endl;
432 sstream << " - Use Backends : " << use_layouts << std::endl;
434 dumps.emplace(index.value(), sstream.str());
// Emit the collected text; entries holding an empty string are skipped.
437 for (const auto &e : dumps)
439 if (!e.second.empty())
441 VERBOSE(Lower) << e.second;
// Checks whether operation `node_index` may be merged into op sequence `op_seq_index`.
//
// op_seq_index     - candidate op sequence, looked up in `_op_seqs`.
// node_index       - operation to test, looked up in `_graph.operations()`.
// layout           - backend layout chosen for the operation; compared with the layout in
//                    the sequence's LowerInfo.
// backend_resolver - supplies the backend assigned to the operation.
//
// Three aspects are examined in turn: equality of backend id and layout, branching around
// the sequence and the node, and whether the node is directly connected to the first or
// last operation of the sequence.
// NOTE(review): the return statements are not part of the lines reproduced here;
// presumably the result is true only when all three checks pass - confirm.
446 bool LoweredGraph::mergeable(const OpSequenceIndex &op_seq_index, const OperationIndex &node_index,
447 Layout layout, const compiler::BackendResolver &backend_resolver)
449 // Are they mergeable?
450 // 1. the same backend id and layout?
451 // 2. Is op_seq or node branched?
452 // 3. if 1 is true, the op_seq and a node are connected?
453 const auto &op_seq = _op_seqs.at(op_seq_index);
454 const auto &node = _graph.operations().at(node_index);
456 // The same backend id and layout?
// The sequence's backend/layout come from its LowerInfo; the node's backend comes from the
// resolver and its layout from the `layout` argument.
458 const auto op_seq_backend_layout = getLowerInfo(op_seq_index)->layout();
459 const auto &op_seq_backend_id = getLowerInfo(op_seq_index)->backend()->config()->id();
460 const auto &node_backend_id = backend_resolver.getBackend(node_index)->config()->id();
461 VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " { " << op_seq_backend_id << "("
462 << to_string(op_seq_backend_layout) << ") } "
463 << " NODE#" << node_index.value() << " (" << node.name() << ") { "
464 << node_backend_id << "(" << to_string(layout) << ") } " << std::endl;
// Mismatch in either the backend id or the layout.
465 if (op_seq_backend_id != node_backend_id || op_seq_backend_layout != layout)
// Distinct operation indices gathered by the two branching checks below.
471 std::unordered_set<OperationIndex> branched_set;
473 // Check for branching up
// Collects the defining operations of all (deduplicated, defined) sequence inputs; more than
// one distinct definer is treated as a branch.
474 for (const auto &input : op_seq.getInputs() | Remove::DUPLICATED | ir::Remove::UNDEFINED)
476 const auto &input_obj = _graph.operands().at(input);
477 for (const auto &def : input_obj.getDef())
479 branched_set.insert(def);
480 if (branched_set.size() > 1)
// Start from an empty set again for the downward check.
486 branched_set.clear();
488 // Check for branching down
// Collects the operations using any (deduplicated) output of the node; more than one distinct
// user is treated as a branch.
489 for (const auto &output : node.getOutputs() | Remove::DUPLICATED)
491 const auto &output_obj = _graph.operands().at(output);
492 for (const auto &use : output_obj.getUses())
494 branched_set.insert(use);
495 if (branched_set.size() > 1)
504 // an input of one node is an output of the other node? or vice-versa?
506 const auto &node_inputs = node.getInputs();
507 const auto &node_outputs = node.getOutputs();
509 // op_seq's operations are in order so that we just check the first and the last
// op_seq_ops holds the first operation, plus the last one when the sequence has more than one.
510 std::vector<OperationIndex> op_seq_ops{op_seq.operations()[0]};
511 if (op_seq.operations().size() > 1)
512 op_seq_ops.emplace_back(op_seq.operations()[op_seq.operations().size() - 1]);
514 for (const auto &n_index : op_seq_ops)
516 const auto &n = _graph.operations().at(n_index);
518 // node's output == op_seq's input?
519 for (const auto input : n.getInputs() | ir::Remove::UNDEFINED)
521 if (node_outputs.contains(input))
523 VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " 's NODE#" << n_index.value()
524 << "(" << n.name() << ") is connected to NODE#" << node_index.value()
525 << "(" << node.name() << ")" << std::endl;
530 // node's input == op_seq's output?
531 for (const auto output : n.getOutputs())
533 if (node_inputs.contains(output))
535 VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " 's NODE#" << n_index.value()
536 << " (" << n.name() << ") is connected to NODE#" << node_index.value()
// Logged when neither direction matched for the operations checked above.
543 VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " is not connected to NODE#"
544 << node_index.value() << "(" << node.name() << ")" << std::endl;