2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "compiler/LoweredGraph.h"
21 #include "util/logging.h"
22 #include "compiler/pass/ConstantInsertionPass.h"
23 #include "compiler/pass/ConstantLoweringPass.h"
24 #include "compiler/pass/PassRunner.h"
25 #include "compiler/pass/PermutationOperationPass.h"
26 #include "compiler/pass/PermutationInsertionPass.h"
27 #include "compiler/pass/PermutationEliminationPass.h"
28 #include "ir/GraphIterator.h"
29 #include "ir/verifier/Verifier.h"
30 #include "backend/Backend.h"
31 #include "backend/IConfig.h"
32 #include "compiler/BackendResolver.h"
33 #include "compiler/ManualScheduler.h"
34 #include "compiler/HEScheduler.h"
35 #include "util/TracingCtx.h"
42 LoweredGraph::LoweredGraph(const ir::Graph &graph, const CompilerOptions &options) : _graph{graph}
44 // set tracing_ctx for copied graph
45 if (options.tracing_ctx)
47 auto subgraph_index = options.tracing_ctx->getSubgraphIndex(&graph);
48 options.tracing_ctx->setSubgraphIndex(&_graph, subgraph_index.value());
51 bool linear_executor = (options.executor == "Linear");
53 // Build backend contexts
54 auto &backend_manager = BackendManager::get();
56 // Always create Controlflow backend context
57 auto cf_backend = backend_manager.getControlflow();
58 _backend_contexts.emplace(
59 cf_backend, cf_backend->newContext(_graph, _graph.getKernelBuilder(), linear_executor));
61 // Create contexts for other backends
62 for (auto backend_str : options.backend_list)
64 backend_manager.loadBackend(backend_str);
65 auto backend = backend_manager.get(backend_str);
67 // TODO As the default value of backend list contains "cpu", "acl_cl" and "acl_neon", and some
68 // are not available on x64 or some other platforms. So this may be a workaround for x64 and
69 // we should change it back(throw if backend is not loaded) later.
72 VERBOSE(LoweredGraph) << "Cannot load backend - " << backend_str << std::endl;
76 _backend_contexts.emplace(
77 backend, backend->newContext(_graph, _graph.getKernelBuilder(), linear_executor));
79 if (backend_manager.num_backends() == 0)
80 throw std::runtime_error{"No available backends loaded."};
82 // TODO Move "schedule" phase out of here
84 std::unique_ptr<BackendResolver> backend_resolver;
85 if (options.he_scheduler)
87 auto scheduler = HEScheduler(_backend_contexts, options);
88 backend_resolver = scheduler.schedule(_graph);
89 _indexed_ranks = scheduler.getIndexedRanks();
93 auto scheduler = ManualScheduler(_backend_contexts, options);
94 backend_resolver = scheduler.schedule(_graph);
98 // operand::LowerInfo holder
99 ir::OperandIndexMap<std::unique_ptr<ir::operand::LowerInfo>> operands_lower_info;
101 _graph.operands().iterate([&](const ir::OperandIndex &index, const ir::Operand &) {
102 operands_lower_info[index] = std::make_unique<ir::operand::LowerInfo>();
105 // Make op_seqs while checking whether a node can be merged into a op_seq.
106 makeOpSequences(operands_lower_info, options, *backend_resolver);
108 _op_seqs.iterate([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
109 assert(op_seq.operations().size() > 0);
110 std::reverse(std::begin(op_seq.operations()), std::end(op_seq.operations()));
113 VERBOSE(OpSequences) << "dump before permutation insertion" << std::endl;
114 dumpOpSequences(_op_seqs, _graph.operations());
118 .append(std::make_unique<pass::ConstantInsertionPass>(*this))
119 .append(std::make_unique<pass::ConstantLoweringPass>(*this))
122 // Set LowerInfo for each operand from the operand::LowerInfo holder
123 manipulateLowerInfo(operands_lower_info);
130 .append(std::make_unique<pass::PermutationOperationPass>(*this))
131 .append(std::make_unique<pass::PermutationInsertionPass>(*this))
134 // Optimization passes
135 pass::PassRunner{}.append(std::make_unique<pass::PermutationEliminationPass>(*this)).run();
137 VERBOSE(LoweredGraph) << "Dump after permutation insertion" << std::endl;
138 for (auto operand : _graph.getInputs())
139 VERBOSE(LoweredGraph) << "Graph Input : " << operand << std::endl;
140 for (auto operand : _graph.getOutputs())
141 VERBOSE(LoweredGraph) << "Graph Output : " << operand << std::endl;
142 dumpOpSequences(_op_seqs, _graph.operations());
144 // Graph verifications
146 assert(ir::verifier::InputOutputChecker().verify(_graph));
147 assert(ir::verifier::DAGChecker().verify(_graph));
148 assert(ir::verifier::EdgeConsistencyChecker().verify(_graph));
152 const ir::operation::LowerInfo *
153 LoweredGraph::getLowerInfo(const ir::OpSequenceIndex &op_seq_index) const
155 auto itr = _lower_info_map.op_seq.find(op_seq_index);
156 if (itr == _lower_info_map.op_seq.end())
158 return itr->second.get();
161 void LoweredGraph::setLowerInfo(const ir::OpSequenceIndex &op_seq_index,
162 std::unique_ptr<ir::operation::LowerInfo> &&lower_info)
164 _lower_info_map.op_seq.insert(std::make_pair(op_seq_index, std::move(lower_info)));
167 void LoweredGraph::removeLowerInfo(const ir::OpSequenceIndex &op_seq_index)
169 auto &op_seq_lower_info = _lower_info_map.op_seq;
170 assert(op_seq_lower_info.find(op_seq_index) != op_seq_lower_info.end());
171 for (auto it = op_seq_lower_info.begin(); it != op_seq_lower_info.end(); ++it)
173 if (it->first == op_seq_index)
175 op_seq_lower_info.erase(it);
181 const ir::operand::LowerInfo *LoweredGraph::getLowerInfo(const ir::OperandIndex &index) const
183 auto itr = _lower_info_map.operand.find(index);
184 if (itr == _lower_info_map.operand.end())
186 return itr->second.get();
189 ir::operand::LowerInfo *LoweredGraph::getLowerInfo(const ir::OperandIndex &index)
191 auto itr = _lower_info_map.operand.find(index);
192 if (itr == _lower_info_map.operand.end())
194 return itr->second.get();
197 void LoweredGraph::setLowerInfo(const ir::OperandIndex &index,
198 std::unique_ptr<ir::operand::LowerInfo> &&lower_info)
200 _lower_info_map.operand.insert(std::make_pair(index, std::move(lower_info)));
203 void LoweredGraph::removeLowerInfo(const ir::OperandIndex &index)
205 _lower_info_map.operand.erase(index);
208 void LoweredGraph::iterateTopolOpSeqs(
209 const std::function<void(const ir::OpSequenceIndex &, const ir::OpSequence &)> &fn) const
211 // Topological Sorting for ir::OpSequences
212 std::vector<ir::OpSequenceIndex> topol_sorted;
213 ir::PostDfsIterator<true>{}.iterateOpSeqs(
214 *this, [&](const ir::OpSequenceIndex &index, const ir::OpSequence &) {
215 topol_sorted.emplace_back(index);
217 std::reverse(topol_sorted.begin(), topol_sorted.end());
218 for (const auto op_seq_idx : topol_sorted)
220 const auto &op_seq = _op_seqs.at(op_seq_idx);
221 fn(op_seq_idx, op_seq);
225 void LoweredGraph::iterateTopolOpSeqs(
226 const std::function<void(const ir::OpSequenceIndex &, ir::OpSequence &)> &fn)
228 // Topological Sorting for ir::OpSequences
229 std::vector<ir::OpSequenceIndex> topol_sorted;
230 ir::PostDfsIterator<false>{}.iterateOpSeqs(
231 *this, [&](const ir::OpSequenceIndex &index, ir::OpSequence &) {
232 topol_sorted.emplace_back(index);
234 std::reverse(topol_sorted.begin(), topol_sorted.end());
235 for (const auto op_seq_idx : topol_sorted)
237 auto &op_seq = _op_seqs.at(op_seq_idx);
238 fn(op_seq_idx, op_seq);
242 ir::OpSequenceIndex LoweredGraph::appendFreshSingleOpSequence(const ir::OperationIndex &node_index,
243 const ir::Operation &node)
245 // Create a fresh op_seq with one operation, and append it to op_seqs
246 // Create a fresh op_seq
247 auto op_seq = std::make_unique<ir::OpSequence>(_graph.layout());
250 op_seq->appendOperation(node_index);
252 // Update input/output
253 op_seq->setOutputs(node.getOutputs());
254 op_seq->setInputs(node.getInputs());
256 return _op_seqs.emplace(std::move(op_seq));
259 void LoweredGraph::makeOpSequences(
260 ir::OperandIndexMap<std::unique_ptr<ir::operand::LowerInfo>> &operands_lower_info,
261 const CompilerOptions &options, const BackendResolver &backend_resolver)
263 // if SUBG_MAX_NODE == 0, no limit on nodes of a op_seq
264 const int op_seq_max_node = options.op_seq_max_node;
265 assert(op_seq_max_node >= 0);
267 bool is_profiling = options.he_profiling_mode;
268 ir::OpSequence *op_seq = nullptr;
269 ir::OpSequenceIndex op_seq_index;
271 // NOTE: The below method appends nodes while making one op_seq if needed. If something better
272 // ways, happy to update this code.
273 ir::PostDfsConstIterator{}.iterate(
274 _graph, [&](const ir::OperationIndex &node_index, const ir::Operation &node) {
275 // LowerInfo for in/output operands
276 auto backend = backend_resolver.getBackend(node_index);
278 // Get frontend's layout
279 auto frontend_layout = _graph.layout();
281 // The layout of each backend should be set at another place
282 // TODO Change setting layout of each backend at another place
283 auto backend_layout = backend->config()->supportLayout(node, frontend_layout);
285 for (auto operand : node.getInputs() | ir::Remove::UNDEFINED)
287 auto &&lower_info = operands_lower_info.at(operand);
288 lower_info->addUsePermuteFactor(ir::operand::PermuteFactor{backend, backend_layout});
290 for (auto operand : node.getOutputs() | ir::Remove::UNDEFINED)
292 auto &&lower_info = operands_lower_info.at(operand);
293 lower_info->addDefPermuteFactor(ir::operand::PermuteFactor{backend, backend_layout});
296 bool new_op_seq = (op_seq == nullptr ||
297 (op_seq_max_node != 0 &&
298 op_seq->operations().size() >= static_cast<size_t>(op_seq_max_node)));
300 // for profiling each op_seq must contain just one node,
301 // so that we can measure a node separately
302 if (new_op_seq || is_profiling ||
303 !mergeable(op_seq_index, node_index, backend_layout, backend_resolver))
305 auto new_op_seq_index = appendFreshSingleOpSequence(node_index, node);
307 // ir::OpSequence LowerInfo
308 setLowerInfo(new_op_seq_index,
309 std::make_unique<ir::operation::LowerInfo>(backend, backend_layout));
311 op_seq_index = new_op_seq_index;
312 op_seq = &(_op_seqs.at(new_op_seq_index));
314 VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " is created for "
315 << "NODE#" << node_index.value() << "(" << node.name() << ")" << std::endl;
319 op_seq->appendOperation(node_index);
321 auto new_inputs = node.getInputs();
322 // Add inputs except outputs of the previous node
323 for (auto ind : op_seq->getInputs())
325 if (!node.getOutputs().contains(ind))
326 new_inputs.append(ind);
328 op_seq->setInputs(new_inputs);
330 VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " merges "
331 << "NODE#" << node_index.value() << "(" << node.name() << ")" << std::endl;
336 void LoweredGraph::manipulateLowerInfo(
337 ir::OperandIndexMap<std::unique_ptr<ir::operand::LowerInfo>> &operands_lower_info)
339 const auto controlflow_backend = BackendManager::get().getControlflow();
341 // TODO Rather than using NHWC Get frontend layout of this node from IR
342 auto factor = ir::operand::PermuteFactor{controlflow_backend, ir::Layout::NHWC};
343 for (auto index : _graph.getInputs() | ir::Remove::UNDEFINED)
345 auto &&lower_info = operands_lower_info.at(index);
346 assert(lower_info->def_factors().empty());
347 lower_info->addDefPermuteFactor(factor);
349 for (auto index : _graph.getOutputs() | ir::Remove::UNDEFINED)
351 auto &&lower_info = operands_lower_info.at(index);
352 lower_info->addUsePermuteFactor(factor);
354 for (auto index : _graph.getOutputs() | ir::Remove::UNDEFINED)
356 auto &&lower_info = operands_lower_info.at(index);
357 if (lower_info->def_factors().size() == 0)
359 // In case of that an operand is Graph's output and not input or output of any operation
360 lower_info->addDefPermuteFactor(ir::operand::PermuteFactor{
362 ir::Layout::NHWC // TODO Get frontend layout of this node from IR
367 // 1. Add def of variable operand
368 // 2. Set LowerInfo for each operand from the operand::LowerInfo holder
369 _graph.operands().iterate([&](const ir::OperandIndex &index, ir::Operand &operand) {
370 // Some inputs of an operation could be non-constant, but not existed in graph inputs/outputs
371 // and not undefined operand. Those inputs must have exist as a Tensor. For example,
372 // UnidirectionalSequenceLSTM operation could have state inputs such as it.
373 if (operand.info().isVariable())
375 // The variable operand with buffer is not supported yet
376 assert(operand.data() == nullptr);
377 assert(operand.getUses().size() == 1 && !operand.getDef().valid());
378 auto &lowered_info = operands_lower_info[index];
379 assert(lowered_info->def_factors().empty());
380 lowered_info->addDefPermuteFactor(lowered_info->use_factors().getOnlyElement());
383 setLowerInfo(index, std::move(operands_lower_info[index]));
387 void LoweredGraph::dumpLowerInfo()
389 if (::onert::util::logging::ctx.enabled() == false)
392 std::map<uint32_t, std::string> dumps;
394 _graph.operands().iterate([&](const ir::OperandIndex &index, ir::Operand &object) {
395 std::stringstream sstream;
396 if (!getLowerInfo(index)->def_factors().empty() || !getLowerInfo(index)->use_factors().empty())
398 auto factors_to_string = [](const ir::operand::PermuteFactorSet &factors) {
400 for (auto factor : factors)
402 str += factor.backend()->config()->id();
403 str += "(" + to_string(factor.layout()) + ")";
406 return "{ " + str + "}";
409 auto operation_index_to_string = [](const ir::OperationIndexSet &operations) {
411 for (auto op : operations)
413 str += std::to_string(op.value());
416 return "{ " + str + "}";
419 const auto lower_info = getLowerInfo(index);
420 const auto &shape = object.shape();
421 std::string def_ops =
422 object.getDef().valid() ? std::to_string(object.getDef().value()) : "N/A";
423 std::string use_ops = operation_index_to_string(object.getUses());
424 std::string def_layouts = factors_to_string(lower_info->def_factors());
425 std::string use_layouts = factors_to_string(lower_info->use_factors());
426 sstream << "Operand #" << index.value() << " LowerInfo" << std::endl;
427 sstream << " - Shape : { ";
428 for (auto i = 0; i < shape.rank(); ++i)
430 sstream << (shape.dim(i)) << " ";
432 sstream << "}" << std::endl;
433 sstream << " - Def Operations : " << def_ops << std::endl;
434 sstream << " - Use Operations : " << use_ops << std::endl;
435 sstream << " - Data : "
436 << (object.data() ? (std::to_string(object.data()->size()) + " bytes") : "N/A")
438 sstream << " - Lower Info" << std::endl;
439 sstream << " - Def Backends : " << def_layouts << std::endl;
440 sstream << " - Use Backends : " << use_layouts << std::endl;
442 dumps.emplace(index.value(), sstream.str());
445 for (const auto &e : dumps)
447 if (!e.second.empty())
449 VERBOSE(Lower) << e.second;
454 bool LoweredGraph::mergeable(const ir::OpSequenceIndex &op_seq_index,
455 const ir::OperationIndex &node_index, ir::Layout layout,
456 const BackendResolver &backend_resolver)
458 // Are they mergeable?
459 // 1. the same backend id and layout?
460 // 2. Is op_seq or node branched?
461 // 3. if 1 is true, the op_seq and a node are connected?
462 const auto &op_seq = _op_seqs.at(op_seq_index);
463 const auto &node = _graph.operations().at(node_index);
465 // The same backend id and layout?
467 const auto op_seq_backend_layout = getLowerInfo(op_seq_index)->layout();
468 const auto &op_seq_backend_id = getLowerInfo(op_seq_index)->backend()->config()->id();
469 const auto &node_backend_id = backend_resolver.getBackend(node_index)->config()->id();
470 VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " { " << op_seq_backend_id << "("
471 << to_string(op_seq_backend_layout) << ") } "
472 << " NODE#" << node_index.value() << " (" << node.name() << ") { "
473 << node_backend_id << "(" << to_string(layout) << ") } " << std::endl;
474 if (op_seq_backend_id != node_backend_id || op_seq_backend_layout != layout)
480 std::unordered_set<ir::OperationIndex> branched_set;
482 // Check for branching up
483 for (const auto &input : op_seq.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
485 const auto &input_obj = _graph.operands().at(input);
486 auto def = input_obj.getDef();
489 branched_set.insert(def);
490 if (branched_set.size() > 1)
496 branched_set.clear();
498 // Check for branching down
499 for (const auto &output : node.getOutputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
501 // TODO Fix this workaround for the case of model outputs that are used by another operation
502 // This is needed since the branching is decided by operation, but for model outputs,
503 // there is controlflow backen(use backend) but no actual use operation exists
504 if (_graph.getOutputs().contains(output))
507 const auto &output_obj = _graph.operands().at(output);
508 for (const auto &use : output_obj.getUses())
510 branched_set.insert(use);
511 if (branched_set.size() > 1)
520 // an input of one node is an output of the other node? or vice-versa?
522 const auto &node_inputs = node.getInputs();
523 const auto &node_outputs = node.getOutputs();
525 // op_seq's operations are in order so that we just check the first and the last
526 std::vector<ir::OperationIndex> op_seq_ops{op_seq.operations()[0]};
527 if (op_seq.operations().size() > 1)
528 op_seq_ops.emplace_back(op_seq.operations()[op_seq.operations().size() - 1]);
530 for (const auto &n_index : op_seq_ops)
532 const auto &n = _graph.operations().at(n_index);
534 // node's output == op_seq's input?
535 for (const auto input : n.getInputs() | ir::Remove::UNDEFINED)
537 if (node_outputs.contains(input))
539 VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " 's NODE#" << n_index.value()
540 << "(" << n.name() << ") is connected to NODE#" << node_index.value()
541 << "(" << node.name() << ")" << std::endl;
546 // node's input == op_seq's output?
547 for (const auto output : n.getOutputs() | ir::Remove::UNDEFINED)
549 if (node_inputs.contains(output))
551 VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " 's NODE#" << n_index.value()
552 << " (" << n.name() << ") is connected to NODE#" << node_index.value()
559 VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " is not connected to NODE#"
560 << node_index.value() << "(" << node.name() << ")" << std::endl;
566 } // namespace compiler