8aedfbdf08300418f825194a60be43ecc17495f9
[platform/core/ml/nnfw.git] / runtime / onert / core / src / ir / LoweredGraph.cc
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "ir/LoweredGraph.h"
18
19 #include <assert.h>
20 #include <sstream>
21 #include "util/logging.h"
22 #include "pass/ConstantInsertionPass.h"
23 #include "pass/ConstantLoweringPass.h"
24 #include "pass/PermutationOperationPass.h"
25 #include "pass/PermutationInsertionPass.h"
26 #include "pass/PermutationEliminationPass.h"
27 #include "ir/GraphIterator.h"
28 #include "verifier/Verifier.h"
29 #include "backend/Backend.h"
30 #include "backend/IConfig.h"
31 #include "compiler/BackendResolver.h"
32 #include "compiler/ManualScheduler.h"
33 #include "compiler/HEScheduler.h"
34
35 namespace onert
36 {
37 namespace ir
38 {
39
40 LoweredGraph::LoweredGraph(const Graph &graph, const compiler::CompilerOptions &options)
41     : _graph{graph}
42 {
43   bool linear_executor = (options.executor == "Linear");
44
45   // Build backend contexts
46   auto &backend_manager = compiler::BackendManager::get();
47
48   // Always create Controlflow backend context
49   auto cf_backend = backend_manager.getControlflow();
50   _backend_contexts.emplace(
51       cf_backend, cf_backend->newContext(_graph, _graph.getKernelBuilder(), linear_executor));
52
53   // Create contexts for other backends
54   for (auto backend_str : options.backend_list)
55   {
56     backend_manager.loadBackend(backend_str);
57     auto backend = backend_manager.get(backend_str);
58
59     // TODO As the default value of backend list contains "cpu", "acl_cl" and "acl_neon", and some
60     // are not available on x64 or some other platforms. So this may be a workaround for x64 and
61     // we should change it back(throw if backend is not loaded) later.
62     if (!backend)
63     {
64       VERBOSE(LoweredGraph) << "Cannot load backend - " << backend_str;
65       continue;
66     }
67
68     _backend_contexts.emplace(
69         backend, backend->newContext(_graph, _graph.getKernelBuilder(), linear_executor));
70   }
71   if (backend_manager.num_backends() == 0)
72     throw std::runtime_error{"No available backends loaded."};
73
74   // TODO Move "schedule" phase out of here
75   // Schedule
76   std::unique_ptr<compiler::BackendResolver> backend_resolver;
77   if (options.he_scheduler)
78   {
79     auto scheduler = compiler::HEScheduler(_backend_contexts, options);
80     backend_resolver = scheduler.schedule(_graph);
81     _indexed_ranks = scheduler.getIndexedRanks();
82   }
83   else
84   {
85     auto scheduler = compiler::ManualScheduler(_backend_contexts, options);
86     backend_resolver = scheduler.schedule(_graph);
87   }
88
89   {
90     // operand::LowerInfo holder
91     OperandIndexMap<std::unique_ptr<operand::LowerInfo>> operands_lower_info;
92
93     _graph.operands().iterate([&](const OperandIndex &index, const Operand &) {
94       operands_lower_info[index] = std::make_unique<operand::LowerInfo>();
95     });
96
97     // Make op_seqs while checking whether a node can be merged into a op_seq.
98     makeOpSequences(operands_lower_info, options, *backend_resolver);
99
100     _op_seqs.iterate([&](const OpSequenceIndex &, OpSequence &op_seq) {
101       assert(op_seq.operations().size() > 0);
102       std::reverse(std::begin(op_seq.operations()), std::end(op_seq.operations()));
103     });
104
105     _op_seqs.dump("merged and sorted operations without permutation", _graph.operations());
106
107     pass::ConstantInsertionPass ci_pass(*this);
108     ci_pass.run();
109
110     pass::ConstantLoweringPass cl_pass(*this);
111     cl_pass.run();
112
113     // Set LowerInfo for each operand from the operand::LowerInfo holder
114     manipulateLowerInfo(operands_lower_info, options.is_primary_subgraph);
115
116     dumpLowerInfo();
117   }
118
119   // Run Permutation Passes
120   {
121     pass::PermutationOperationPass po_pass(*this);
122     po_pass.run();
123
124     pass::PermutationInsertionPass pi_pass(*this);
125     pi_pass.run();
126
127     pass::PermutationEliminationPass pe_pass(*this);
128     pe_pass.run();
129
130     _op_seqs.dump("merged and sorted operations with permutation", _graph.operations());
131   }
132
133   // Graph verifications
134   {
135     assert(verifier::DAGChecker().verify(_graph));
136     assert(verifier::EdgeConsistencyChecker().verify(_graph));
137   }
138 }
139
140 const operation::LowerInfo *LoweredGraph::getLowerInfo(const OpSequenceIndex &op_seq_index) const
141 {
142   auto itr = _lower_info_map.op_seq.find(op_seq_index);
143   if (itr == _lower_info_map.op_seq.end())
144     return nullptr;
145   return itr->second.get();
146 }
147
148 void LoweredGraph::setLowerInfo(const OpSequenceIndex &op_seq_index,
149                                 std::unique_ptr<operation::LowerInfo> &&lower_info)
150 {
151   _lower_info_map.op_seq.insert(std::make_pair(op_seq_index, std::move(lower_info)));
152 }
153
154 void LoweredGraph::removeLowerInfo(const OpSequenceIndex &op_seq_index)
155 {
156   auto &op_seq_lower_info = _lower_info_map.op_seq;
157   assert(op_seq_lower_info.find(op_seq_index) != op_seq_lower_info.end());
158   for (auto it = op_seq_lower_info.begin(); it != op_seq_lower_info.end(); ++it)
159   {
160     if (it->first == op_seq_index)
161     {
162       op_seq_lower_info.erase(it);
163       break;
164     }
165   }
166 }
167
168 const operand::LowerInfo *LoweredGraph::getLowerInfo(const OperandIndex &index) const
169 {
170   auto itr = _lower_info_map.operand.find(index);
171   if (itr == _lower_info_map.operand.end())
172     return nullptr;
173   return itr->second.get();
174 }
175
176 operand::LowerInfo *LoweredGraph::getLowerInfo(const OperandIndex &index)
177 {
178   auto itr = _lower_info_map.operand.find(index);
179   if (itr == _lower_info_map.operand.end())
180     return nullptr;
181   return itr->second.get();
182 }
183
184 void LoweredGraph::setLowerInfo(const OperandIndex &index,
185                                 std::unique_ptr<operand::LowerInfo> &&lower_info)
186 {
187   _lower_info_map.operand.insert(std::make_pair(index, std::move(lower_info)));
188 }
189
190 void LoweredGraph::removeLowerInfo(const OperandIndex &index)
191 {
192   _lower_info_map.operand.erase(index);
193 }
194
195 void LoweredGraph::iterateTopolOpSeqs(
196     const std::function<void(const OpSequenceIndex &, const OpSequence &)> &fn) const
197 {
198   // Topological Sorting for OpSequences
199   std::vector<OpSequenceIndex> topol_sorted;
200   PostDfsIterator<true>{}.iterateOpSeqs(
201       *this,
202       [&](const OpSequenceIndex &index, const OpSequence &) { topol_sorted.emplace_back(index); });
203   std::reverse(topol_sorted.begin(), topol_sorted.end());
204   for (const auto op_seq_idx : topol_sorted)
205   {
206     const auto &op_seq = _op_seqs.at(op_seq_idx);
207     fn(op_seq_idx, op_seq);
208   }
209 }
210
211 void LoweredGraph::iterateTopolOpSeqs(
212     const std::function<void(const OpSequenceIndex &, OpSequence &)> &fn)
213 {
214   // Topological Sorting for OpSequences
215   std::vector<OpSequenceIndex> topol_sorted;
216   PostDfsIterator<false>{}.iterateOpSeqs(
217       *this, [&](const OpSequenceIndex &index, OpSequence &) { topol_sorted.emplace_back(index); });
218   std::reverse(topol_sorted.begin(), topol_sorted.end());
219   for (const auto op_seq_idx : topol_sorted)
220   {
221     auto &op_seq = _op_seqs.at(op_seq_idx);
222     fn(op_seq_idx, op_seq);
223   }
224 }
225
226 OpSequenceIndex LoweredGraph::appendFreshSingleOpSequence(const OperationIndex &node_index,
227                                                           const Operation &node)
228 {
229   // Create a fresh op_seq with one operation, and append it to op_seqs
230   // Create a fresh op_seq
231   auto op_seq = std::make_unique<OpSequence>(_graph.layout());
232
233   // Add an operation
234   op_seq->appendOperation(node_index);
235
236   // Update input/output
237   op_seq->setOutputs(node.getOutputs());
238   op_seq->setInputs(node.getInputs());
239
240   return _op_seqs.emplace(std::move(op_seq));
241 }
242
243 void LoweredGraph::makeOpSequences(
244     OperandIndexMap<std::unique_ptr<operand::LowerInfo>> &operands_lower_info,
245     const compiler::CompilerOptions &options, const compiler::BackendResolver &backend_resolver)
246 {
247   // if SUBG_MAX_NODE == 0, no limit on nodes of a op_seq
248   const int op_seq_max_node = options.op_seq_max_node;
249   assert(op_seq_max_node >= 0);
250
251   bool is_profiling = options.he_profiling_mode;
252   OpSequence *op_seq = nullptr;
253   OpSequenceIndex op_seq_index;
254
255   // NOTE: The below method appends nodes while making one op_seq if needed. If something better
256   // ways, happy to update this code.
257   PostDfsConstIterator{}.iterate(
258       _graph, [&](const OperationIndex &node_index, const Operation &node) {
259         // LowerInfo for in/output operands
260         auto backend = backend_resolver.getBackend(node_index);
261
262         // Get frontend's layout
263         auto frontend_layout = _graph.layout();
264
265         // The layout of each backend should be set at another place
266         // TODO Change setting layout of each backend at another place
267         auto backend_layout = backend->config()->supportLayout(node, frontend_layout);
268
269         for (auto operand : node.getInputs() | ir::Remove::UNDEFINED)
270         {
271           auto &&lower_info = operands_lower_info.at(operand);
272           lower_info->addUsePermuteFactor(operand::PermuteFactor{backend, backend_layout});
273         }
274         for (auto operand : node.getOutputs())
275         {
276           auto &&lower_info = operands_lower_info.at(operand);
277           lower_info->addDefPermuteFactor(operand::PermuteFactor{backend, backend_layout});
278         }
279
280         bool new_op_seq = (op_seq == nullptr ||
281                            (op_seq_max_node != 0 &&
282                             op_seq->operations().size() >= static_cast<size_t>(op_seq_max_node)));
283
284         // for profiling each op_seq must contain just one node,
285         // so that we can measure a node separately
286         if (new_op_seq || is_profiling ||
287             !mergeable(op_seq_index, node_index, backend_layout, backend_resolver))
288         {
289           auto new_op_seq_index = appendFreshSingleOpSequence(node_index, node);
290
291           // OpSequence LowerInfo
292           setLowerInfo(new_op_seq_index,
293                        std::make_unique<operation::LowerInfo>(backend, backend_layout));
294
295           op_seq_index = new_op_seq_index;
296           op_seq = &(_op_seqs.at(new_op_seq_index));
297
298           VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " is created for "
299                          << "NODE#" << node_index.value() << "(" << node.name() << ")" << std::endl;
300         }
301         else
302         {
303           op_seq->appendOperation(node_index);
304           // Set inputs
305           auto new_inputs = node.getInputs();
306           // Add inputs except outputs of the previous node
307           for (auto ind : op_seq->getInputs())
308           {
309             if (!node.getOutputs().contains(ind))
310               new_inputs.append(ind);
311           }
312           op_seq->setInputs(new_inputs);
313
314           VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " merges "
315                          << "NODE#" << node_index.value() << "(" << node.name() << ")" << std::endl;
316         }
317       });
318 }
319
320 void LoweredGraph::manipulateLowerInfo(
321     OperandIndexMap<std::unique_ptr<operand::LowerInfo>> &operands_lower_info, bool is_primary)
322 {
323   const auto controlflow_backend = compiler::BackendManager::get().getControlflow();
324
325   // TODO Rather than handling primary graph specially,
326   //      let the permute inserted and remove it later
327   if (is_primary)
328   {
329     // TODO Rather than using NHWC Get frontend layout of this node from IR
330     auto factor = operand::PermuteFactor{controlflow_backend, Layout::NHWC};
331     for (auto index : _graph.getInputs() | ir::Remove::UNDEFINED)
332     {
333       auto &&lower_info = operands_lower_info.at(index);
334       assert(lower_info->def_factors().empty());
335       lower_info->addDefPermuteFactor(factor);
336     }
337     for (auto index : _graph.getOutputs())
338     {
339       auto &&lower_info = operands_lower_info.at(index);
340       lower_info->addUsePermuteFactor(factor);
341     }
342   }
343   else
344   {
345     for (auto index : _graph.getInputs() | ir::Remove::UNDEFINED)
346     {
347       auto &&lower_info = operands_lower_info.at(index);
348       if (!(lower_info->def_factors().size() == 0 && lower_info->use_factors().size() == 0))
349       {
350         // In case of not that Graph's input is not used in any operation and not the graph's
351         // output.
352         // In other words, it is not unused input in Graph.
353         lower_info->addDefPermuteFactor(*lower_info->use_factors().begin());
354       }
355       else
356       {
357         // In case of that an operand is Graph's input and not input or output of any operation
358         lower_info->addDefPermuteFactor(operand::PermuteFactor{
359             controlflow_backend,
360             Layout::NHWC // TODO Get frontend layout of this node from IR
361         });
362       }
363     }
364   }
365   for (auto index : _graph.getOutputs())
366   {
367     auto &&lower_info = operands_lower_info.at(index);
368     if (lower_info->def_factors().size() == 0)
369     {
370       // In case of that an operand is Graph's output and not input or output of any operation
371       lower_info->addDefPermuteFactor(operand::PermuteFactor{
372           controlflow_backend,
373           Layout::NHWC // TODO Get frontend layout of this node from IR
374       });
375     }
376   }
377
378   // Set LowerInfo for each operand from the operand::LowerInfo holder
379   _graph.operands().iterate([&](const OperandIndex &index, Operand &) {
380     setLowerInfo(index, std::move(operands_lower_info[index]));
381   });
382 }
383
384 void LoweredGraph::dumpLowerInfo()
385 {
386   if (::onert::util::logging::ctx.enabled() == false)
387     return;
388
389   std::map<uint32_t, std::string> dumps;
390
391   _graph.operands().iterate([&](const OperandIndex &index, Operand &object) {
392     std::stringstream sstream;
393     if (!getLowerInfo(index)->def_factors().empty() || !getLowerInfo(index)->use_factors().empty())
394     {
395       auto factors_to_string = [](const operand::PermuteFactorSet &factors) {
396         std::string str;
397         for (auto factor : factors)
398         {
399           str += factor.backend()->config()->id();
400           str += "(" + to_string(factor.layout()) + ")";
401           str += " ";
402         }
403         return "{ " + str + "}";
404       };
405
406       auto operation_index_to_string = [](const OperationIndexSet &operations) {
407         std::string str;
408         for (auto op : operations)
409         {
410           str += std::to_string(op.value());
411           str += " ";
412         }
413         return "{ " + str + "}";
414       };
415
416       const auto lower_info = getLowerInfo(index);
417       const auto &shape = object.shape();
418       std::string def_ops =
419           object.getDef().valid() ? std::to_string(object.getDef().value()) : "N/A";
420       std::string use_ops = operation_index_to_string(object.getUses());
421       std::string def_layouts = factors_to_string(lower_info->def_factors());
422       std::string use_layouts = factors_to_string(lower_info->use_factors());
423       sstream << "Operand #" << index.value() << " LowerInfo" << std::endl;
424       sstream << "  - Shape           : { ";
425       for (auto i = 0; i < shape.rank(); ++i)
426       {
427         sstream << (shape.dim(i)) << " ";
428       }
429       sstream << "}" << std::endl;
430       sstream << "  - Def Operations  : " << def_ops << std::endl;
431       sstream << "  - Use Operations  : " << use_ops << std::endl;
432       sstream << "  - Lower Info" << std::endl;
433       sstream << "    - Def Backends    : " << def_layouts << std::endl;
434       sstream << "    - Use Backends    : " << use_layouts << std::endl;
435     }
436     dumps.emplace(index.value(), sstream.str());
437   });
438
439   for (const auto &e : dumps)
440   {
441     if (!e.second.empty())
442     {
443       VERBOSE(Lower) << e.second;
444     }
445   }
446 }
447
448 bool LoweredGraph::mergeable(const OpSequenceIndex &op_seq_index, const OperationIndex &node_index,
449                              Layout layout, const compiler::BackendResolver &backend_resolver)
450 {
451   // Are they mergeable?
452   // 1. the same backend id and layout?
453   // 2. Is op_seq or node branched?
454   // 3. if 1 is true, the op_seq and a node are connected?
455   const auto &op_seq = _op_seqs.at(op_seq_index);
456   const auto &node = _graph.operations().at(node_index);
457
458   // The same backend id and layout?
459   {
460     const auto op_seq_backend_layout = getLowerInfo(op_seq_index)->layout();
461     const auto &op_seq_backend_id = getLowerInfo(op_seq_index)->backend()->config()->id();
462     const auto &node_backend_id = backend_resolver.getBackend(node_index)->config()->id();
463     VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " { " << op_seq_backend_id << "("
464                    << to_string(op_seq_backend_layout) << ") } "
465                    << " NODE#" << node_index.value() << " (" << node.name() << ") { "
466                    << node_backend_id << "(" << to_string(layout) << ") } " << std::endl;
467     if (op_seq_backend_id != node_backend_id || op_seq_backend_layout != layout)
468       return false;
469   }
470
471   // Branched?
472   {
473     std::unordered_set<OperationIndex> branched_set;
474
475     // Check for branching up
476     for (const auto &input : op_seq.getInputs() | Remove::DUPLICATED | ir::Remove::UNDEFINED)
477     {
478       const auto &input_obj = _graph.operands().at(input);
479       auto def = input_obj.getDef();
480       if (def.valid())
481       {
482         branched_set.insert(def);
483         if (branched_set.size() > 1)
484         {
485           return false;
486         }
487       }
488     }
489     branched_set.clear();
490
491     // Check for branching down
492     for (const auto &output : node.getOutputs() | Remove::DUPLICATED)
493     {
494       // TODO Fix this workaround for the case of model outputs that are used by another operation
495       //      This is needed since the branching is decided by operation, but for model outputs,
496       //      there is controlflow backen(use backend) but no actual use operation exists
497       if (_graph.getOutputs().contains(output))
498         return false;
499
500       const auto &output_obj = _graph.operands().at(output);
501       for (const auto &use : output_obj.getUses())
502       {
503         branched_set.insert(use);
504         if (branched_set.size() > 1)
505         {
506           return false;
507         }
508       }
509     }
510   }
511
512   // Connected?
513   // an input of one node is an output of the other node? or vice-versa?
514   {
515     const auto &node_inputs = node.getInputs();
516     const auto &node_outputs = node.getOutputs();
517
518     // op_seq's operations are in order so that we just check the first and the last
519     std::vector<OperationIndex> op_seq_ops{op_seq.operations()[0]};
520     if (op_seq.operations().size() > 1)
521       op_seq_ops.emplace_back(op_seq.operations()[op_seq.operations().size() - 1]);
522
523     for (const auto &n_index : op_seq_ops)
524     {
525       const auto &n = _graph.operations().at(n_index);
526
527       // node's output == op_seq's input?
528       for (const auto input : n.getInputs() | ir::Remove::UNDEFINED)
529       {
530         if (node_outputs.contains(input))
531         {
532           VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " 's NODE#" << n_index.value()
533                          << "(" << n.name() << ") is connected to NODE#" << node_index.value()
534                          << "(" << node.name() << ")" << std::endl;
535           return true;
536         }
537       }
538
539       // node's input == op_seq's output?
540       for (const auto output : n.getOutputs())
541       {
542         if (node_inputs.contains(output))
543         {
544           VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " 's NODE#" << n_index.value()
545                          << " (" << n.name() << ") is connected to NODE#" << node_index.value()
546                          << std::endl;
547           return true;
548         }
549       }
550     }
551
552     VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " is not connected to NODE#"
553                    << node_index.value() << "(" << node.name() << ")" << std::endl;
554   }
555
556   return false;
557 }
558
559 } // namespace ir
560 } // namespace onert