673d7d3e863972e0455dd55e7e44004fac0751fe
[platform/core/ml/nnfw.git] / runtime / onert / core / src / compiler / LoweredGraph.cc
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "compiler/LoweredGraph.h"
18
19 #include <assert.h>
20 #include <sstream>
21 #include "util/logging.h"
22 #include "compiler/pass/ConstantInsertionPass.h"
23 #include "compiler/pass/ConstantLoweringPass.h"
24 #include "compiler/pass/PassRunner.h"
25 #include "compiler/pass/PermutationOperationPass.h"
26 #include "compiler/pass/PermutationInsertionPass.h"
27 #include "compiler/pass/PermutationEliminationPass.h"
28 #include "ir/GraphIterator.h"
29 #include "ir/verifier/Verifier.h"
30 #include "backend/Backend.h"
31 #include "backend/IConfig.h"
32 #include "compiler/BackendResolver.h"
33 #include "compiler/ManualScheduler.h"
34 #include "compiler/HEScheduler.h"
35
36 namespace onert
37 {
38 namespace compiler
39 {
40
41 LoweredGraph::LoweredGraph(const ir::Graph &graph, const CompilerOptions &options) : _graph{graph}
42 {
43   bool linear_executor = (options.executor == "Linear");
44
45   // Build backend contexts
46   auto &backend_manager = BackendManager::get();
47
48   // Always create Controlflow backend context
49   auto cf_backend = backend_manager.getControlflow();
50   _backend_contexts.emplace(
51       cf_backend, cf_backend->newContext(_graph, _graph.getKernelBuilder(), linear_executor));
52
53   // Create contexts for other backends
54   for (auto backend_str : options.backend_list)
55   {
56     backend_manager.loadBackend(backend_str);
57     auto backend = backend_manager.get(backend_str);
58
59     // TODO As the default value of backend list contains "cpu", "acl_cl" and "acl_neon", and some
60     // are not available on x64 or some other platforms. So this may be a workaround for x64 and
61     // we should change it back(throw if backend is not loaded) later.
62     if (!backend)
63     {
64       VERBOSE(LoweredGraph) << "Cannot load backend - " << backend_str << std::endl;
65       continue;
66     }
67
68     _backend_contexts.emplace(
69         backend, backend->newContext(_graph, _graph.getKernelBuilder(), linear_executor));
70   }
71   if (backend_manager.num_backends() == 0)
72     throw std::runtime_error{"No available backends loaded."};
73
74   // TODO Move "schedule" phase out of here
75   // Schedule
76   std::unique_ptr<BackendResolver> backend_resolver;
77   if (options.he_scheduler)
78   {
79     auto scheduler = HEScheduler(_backend_contexts, options);
80     backend_resolver = scheduler.schedule(_graph);
81     _indexed_ranks = scheduler.getIndexedRanks();
82   }
83   else
84   {
85     auto scheduler = ManualScheduler(_backend_contexts, options);
86     backend_resolver = scheduler.schedule(_graph);
87   }
88
89   {
90     // operand::LowerInfo holder
91     ir::OperandIndexMap<std::unique_ptr<ir::operand::LowerInfo>> operands_lower_info;
92
93     _graph.operands().iterate([&](const ir::OperandIndex &index, const ir::Operand &) {
94       operands_lower_info[index] = std::make_unique<ir::operand::LowerInfo>();
95     });
96
97     // Make op_seqs while checking whether a node can be merged into a op_seq.
98     makeOpSequences(operands_lower_info, options, *backend_resolver);
99
100     _op_seqs.iterate([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
101       assert(op_seq.operations().size() > 0);
102       std::reverse(std::begin(op_seq.operations()), std::end(op_seq.operations()));
103     });
104
105     VERBOSE(OpSequences) << "dump before permutation insertion" << std::endl;
106     dumpOpSequences(_op_seqs, _graph.operations());
107
108     // Mandatory passes
109     pass::PassRunner{}
110         .append(std::make_unique<pass::ConstantInsertionPass>(*this))
111         .append(std::make_unique<pass::ConstantLoweringPass>(*this))
112         .run();
113
114     // Set LowerInfo for each operand from the operand::LowerInfo holder
115     manipulateLowerInfo(operands_lower_info, options.is_primary_subgraph);
116
117     dumpLowerInfo();
118   }
119
120   // Mandatory passes
121   pass::PassRunner{}
122       .append(std::make_unique<pass::PermutationOperationPass>(*this))
123       .append(std::make_unique<pass::PermutationInsertionPass>(*this))
124       .run();
125
126   // Optimization passes
127   pass::PassRunner{}.append(std::make_unique<pass::PermutationEliminationPass>(*this)).run();
128
129   VERBOSE(OpSequences) << "Dump after permutation insertion" << std::endl;
130   dumpOpSequences(_op_seqs, _graph.operations());
131
132   // Graph verifications
133   {
134     assert(ir::verifier::InputOutputChecker().verify(_graph));
135     assert(ir::verifier::DAGChecker().verify(_graph));
136     assert(ir::verifier::EdgeConsistencyChecker().verify(_graph));
137   }
138 }
139
140 const ir::operation::LowerInfo *
141 LoweredGraph::getLowerInfo(const ir::OpSequenceIndex &op_seq_index) const
142 {
143   auto itr = _lower_info_map.op_seq.find(op_seq_index);
144   if (itr == _lower_info_map.op_seq.end())
145     return nullptr;
146   return itr->second.get();
147 }
148
149 void LoweredGraph::setLowerInfo(const ir::OpSequenceIndex &op_seq_index,
150                                 std::unique_ptr<ir::operation::LowerInfo> &&lower_info)
151 {
152   _lower_info_map.op_seq.insert(std::make_pair(op_seq_index, std::move(lower_info)));
153 }
154
155 void LoweredGraph::removeLowerInfo(const ir::OpSequenceIndex &op_seq_index)
156 {
157   auto &op_seq_lower_info = _lower_info_map.op_seq;
158   assert(op_seq_lower_info.find(op_seq_index) != op_seq_lower_info.end());
159   for (auto it = op_seq_lower_info.begin(); it != op_seq_lower_info.end(); ++it)
160   {
161     if (it->first == op_seq_index)
162     {
163       op_seq_lower_info.erase(it);
164       break;
165     }
166   }
167 }
168
169 const ir::operand::LowerInfo *LoweredGraph::getLowerInfo(const ir::OperandIndex &index) const
170 {
171   auto itr = _lower_info_map.operand.find(index);
172   if (itr == _lower_info_map.operand.end())
173     return nullptr;
174   return itr->second.get();
175 }
176
177 ir::operand::LowerInfo *LoweredGraph::getLowerInfo(const ir::OperandIndex &index)
178 {
179   auto itr = _lower_info_map.operand.find(index);
180   if (itr == _lower_info_map.operand.end())
181     return nullptr;
182   return itr->second.get();
183 }
184
185 void LoweredGraph::setLowerInfo(const ir::OperandIndex &index,
186                                 std::unique_ptr<ir::operand::LowerInfo> &&lower_info)
187 {
188   _lower_info_map.operand.insert(std::make_pair(index, std::move(lower_info)));
189 }
190
191 void LoweredGraph::removeLowerInfo(const ir::OperandIndex &index)
192 {
193   _lower_info_map.operand.erase(index);
194 }
195
196 void LoweredGraph::iterateTopolOpSeqs(
197     const std::function<void(const ir::OpSequenceIndex &, const ir::OpSequence &)> &fn) const
198 {
199   // Topological Sorting for ir::OpSequences
200   std::vector<ir::OpSequenceIndex> topol_sorted;
201   ir::PostDfsIterator<true>{}.iterateOpSeqs(
202       *this, [&](const ir::OpSequenceIndex &index, const ir::OpSequence &) {
203         topol_sorted.emplace_back(index);
204       });
205   std::reverse(topol_sorted.begin(), topol_sorted.end());
206   for (const auto op_seq_idx : topol_sorted)
207   {
208     const auto &op_seq = _op_seqs.at(op_seq_idx);
209     fn(op_seq_idx, op_seq);
210   }
211 }
212
213 void LoweredGraph::iterateTopolOpSeqs(
214     const std::function<void(const ir::OpSequenceIndex &, ir::OpSequence &)> &fn)
215 {
216   // Topological Sorting for ir::OpSequences
217   std::vector<ir::OpSequenceIndex> topol_sorted;
218   ir::PostDfsIterator<false>{}.iterateOpSeqs(
219       *this, [&](const ir::OpSequenceIndex &index, ir::OpSequence &) {
220         topol_sorted.emplace_back(index);
221       });
222   std::reverse(topol_sorted.begin(), topol_sorted.end());
223   for (const auto op_seq_idx : topol_sorted)
224   {
225     auto &op_seq = _op_seqs.at(op_seq_idx);
226     fn(op_seq_idx, op_seq);
227   }
228 }
229
230 ir::OpSequenceIndex LoweredGraph::appendFreshSingleOpSequence(const ir::OperationIndex &node_index,
231                                                               const ir::Operation &node)
232 {
233   // Create a fresh op_seq with one operation, and append it to op_seqs
234   // Create a fresh op_seq
235   auto op_seq = std::make_unique<ir::OpSequence>(_graph.layout());
236
237   // Add an operation
238   op_seq->appendOperation(node_index);
239
240   // Update input/output
241   op_seq->setOutputs(node.getOutputs());
242   op_seq->setInputs(node.getInputs());
243
244   return _op_seqs.emplace(std::move(op_seq));
245 }
246
247 void LoweredGraph::makeOpSequences(
248     ir::OperandIndexMap<std::unique_ptr<ir::operand::LowerInfo>> &operands_lower_info,
249     const CompilerOptions &options, const BackendResolver &backend_resolver)
250 {
251   // if SUBG_MAX_NODE == 0, no limit on nodes of a op_seq
252   const int op_seq_max_node = options.op_seq_max_node;
253   assert(op_seq_max_node >= 0);
254
255   bool is_profiling = options.he_profiling_mode;
256   ir::OpSequence *op_seq = nullptr;
257   ir::OpSequenceIndex op_seq_index;
258
259   // NOTE: The below method appends nodes while making one op_seq if needed. If something better
260   // ways, happy to update this code.
261   ir::PostDfsConstIterator{}.iterate(
262       _graph, [&](const ir::OperationIndex &node_index, const ir::Operation &node) {
263         // LowerInfo for in/output operands
264         auto backend = backend_resolver.getBackend(node_index);
265
266         // Get frontend's layout
267         auto frontend_layout = _graph.layout();
268
269         // The layout of each backend should be set at another place
270         // TODO Change setting layout of each backend at another place
271         auto backend_layout = backend->config()->supportLayout(node, frontend_layout);
272
273         for (auto operand : node.getInputs() | ir::Remove::UNDEFINED)
274         {
275           auto &&lower_info = operands_lower_info.at(operand);
276           lower_info->addUsePermuteFactor(ir::operand::PermuteFactor{backend, backend_layout});
277         }
278         for (auto operand : node.getOutputs() | ir::Remove::UNDEFINED)
279         {
280           auto &&lower_info = operands_lower_info.at(operand);
281           lower_info->addDefPermuteFactor(ir::operand::PermuteFactor{backend, backend_layout});
282         }
283
284         bool new_op_seq = (op_seq == nullptr ||
285                            (op_seq_max_node != 0 &&
286                             op_seq->operations().size() >= static_cast<size_t>(op_seq_max_node)));
287
288         // for profiling each op_seq must contain just one node,
289         // so that we can measure a node separately
290         if (new_op_seq || is_profiling ||
291             !mergeable(op_seq_index, node_index, backend_layout, backend_resolver))
292         {
293           auto new_op_seq_index = appendFreshSingleOpSequence(node_index, node);
294
295           // ir::OpSequence LowerInfo
296           setLowerInfo(new_op_seq_index,
297                        std::make_unique<ir::operation::LowerInfo>(backend, backend_layout));
298
299           op_seq_index = new_op_seq_index;
300           op_seq = &(_op_seqs.at(new_op_seq_index));
301
302           VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " is created for "
303                          << "NODE#" << node_index.value() << "(" << node.name() << ")" << std::endl;
304         }
305         else
306         {
307           op_seq->appendOperation(node_index);
308           // Set inputs
309           auto new_inputs = node.getInputs();
310           // Add inputs except outputs of the previous node
311           for (auto ind : op_seq->getInputs())
312           {
313             if (!node.getOutputs().contains(ind))
314               new_inputs.append(ind);
315           }
316           op_seq->setInputs(new_inputs);
317
318           VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " merges "
319                          << "NODE#" << node_index.value() << "(" << node.name() << ")" << std::endl;
320         }
321       });
322 }
323
324 void LoweredGraph::manipulateLowerInfo(
325     ir::OperandIndexMap<std::unique_ptr<ir::operand::LowerInfo>> &operands_lower_info,
326     bool is_primary)
327 {
328   const auto controlflow_backend = BackendManager::get().getControlflow();
329
330   // TODO Rather than handling primary graph specially,
331   //      let the permute inserted and remove it later
332   if (is_primary)
333   {
334     // TODO Rather than using NHWC Get frontend layout of this node from IR
335     auto factor = ir::operand::PermuteFactor{controlflow_backend, ir::Layout::NHWC};
336     for (auto index : _graph.getInputs() | ir::Remove::UNDEFINED)
337     {
338       auto &&lower_info = operands_lower_info.at(index);
339       assert(lower_info->def_factors().empty());
340       lower_info->addDefPermuteFactor(factor);
341     }
342     for (auto index : _graph.getOutputs() | ir::Remove::UNDEFINED)
343     {
344       auto &&lower_info = operands_lower_info.at(index);
345       lower_info->addUsePermuteFactor(factor);
346     }
347   }
348   else
349   {
350     for (auto index : _graph.getInputs() | ir::Remove::UNDEFINED)
351     {
352       auto &&lower_info = operands_lower_info.at(index);
353       if (!(lower_info->def_factors().size() == 0 && lower_info->use_factors().size() == 0))
354       {
355         // In case of not that Graph's input is not used in any operation and not the graph's
356         // output.
357         // In other words, it is not unused input in Graph.
358         lower_info->addDefPermuteFactor(*lower_info->use_factors().begin());
359       }
360       else
361       {
362         // In case of that an operand is Graph's input and not input or output of any operation
363         lower_info->addDefPermuteFactor(ir::operand::PermuteFactor{
364             controlflow_backend,
365             ir::Layout::NHWC // TODO Get frontend layout of this node from IR
366         });
367       }
368     }
369   }
370   for (auto index : _graph.getOutputs() | ir::Remove::UNDEFINED)
371   {
372     auto &&lower_info = operands_lower_info.at(index);
373     if (lower_info->def_factors().size() == 0)
374     {
375       // In case of that an operand is Graph's output and not input or output of any operation
376       lower_info->addDefPermuteFactor(ir::operand::PermuteFactor{
377           controlflow_backend,
378           ir::Layout::NHWC // TODO Get frontend layout of this node from IR
379       });
380     }
381   }
382
383   // 1. Add def of variable operand
384   // 2. Set LowerInfo for each operand from the operand::LowerInfo holder
385   _graph.operands().iterate([&](const ir::OperandIndex &index, ir::Operand &operand) {
386     // Some inputs of an operation could be non-constant, but not existed in graph inputs/outputs
387     // and not undefined operand. Those inputs must have exist as a Tensor. For example,
388     // UnidirectionalSequenceLSTM operation could have state inputs such as it.
389     if (operand.info().isVariable())
390     {
391       // The variable operand with buffer is not supported yet
392       assert(operand.data() == nullptr);
393       assert(operand.getUses().size() == 1 && !operand.getDef().valid());
394       auto &lowered_info = operands_lower_info[index];
395       assert(lowered_info->def_factors().empty());
396       lowered_info->addDefPermuteFactor(lowered_info->use_factors().getOnlyElement());
397     }
398
399     setLowerInfo(index, std::move(operands_lower_info[index]));
400   });
401 }
402
403 void LoweredGraph::dumpLowerInfo()
404 {
405   if (::onert::util::logging::ctx.enabled() == false)
406     return;
407
408   std::map<uint32_t, std::string> dumps;
409
410   _graph.operands().iterate([&](const ir::OperandIndex &index, ir::Operand &object) {
411     std::stringstream sstream;
412     if (!getLowerInfo(index)->def_factors().empty() || !getLowerInfo(index)->use_factors().empty())
413     {
414       auto factors_to_string = [](const ir::operand::PermuteFactorSet &factors) {
415         std::string str;
416         for (auto factor : factors)
417         {
418           str += factor.backend()->config()->id();
419           str += "(" + to_string(factor.layout()) + ")";
420           str += " ";
421         }
422         return "{ " + str + "}";
423       };
424
425       auto operation_index_to_string = [](const ir::OperationIndexSet &operations) {
426         std::string str;
427         for (auto op : operations)
428         {
429           str += std::to_string(op.value());
430           str += " ";
431         }
432         return "{ " + str + "}";
433       };
434
435       const auto lower_info = getLowerInfo(index);
436       const auto &shape = object.shape();
437       std::string def_ops =
438           object.getDef().valid() ? std::to_string(object.getDef().value()) : "N/A";
439       std::string use_ops = operation_index_to_string(object.getUses());
440       std::string def_layouts = factors_to_string(lower_info->def_factors());
441       std::string use_layouts = factors_to_string(lower_info->use_factors());
442       sstream << "Operand #" << index.value() << " LowerInfo" << std::endl;
443       sstream << "  - Shape           : { ";
444       for (auto i = 0; i < shape.rank(); ++i)
445       {
446         sstream << (shape.dim(i)) << " ";
447       }
448       sstream << "}" << std::endl;
449       sstream << "  - Def ir::Operations  : " << def_ops << std::endl;
450       sstream << "  - Use ir::Operations  : " << use_ops << std::endl;
451       sstream << "  - Lower Info" << std::endl;
452       sstream << "    - Def Backends    : " << def_layouts << std::endl;
453       sstream << "    - Use Backends    : " << use_layouts << std::endl;
454     }
455     dumps.emplace(index.value(), sstream.str());
456   });
457
458   for (const auto &e : dumps)
459   {
460     if (!e.second.empty())
461     {
462       VERBOSE(Lower) << e.second;
463     }
464   }
465 }
466
467 bool LoweredGraph::mergeable(const ir::OpSequenceIndex &op_seq_index,
468                              const ir::OperationIndex &node_index, ir::Layout layout,
469                              const BackendResolver &backend_resolver)
470 {
471   // Are they mergeable?
472   // 1. the same backend id and layout?
473   // 2. Is op_seq or node branched?
474   // 3. if 1 is true, the op_seq and a node are connected?
475   const auto &op_seq = _op_seqs.at(op_seq_index);
476   const auto &node = _graph.operations().at(node_index);
477
478   // The same backend id and layout?
479   {
480     const auto op_seq_backend_layout = getLowerInfo(op_seq_index)->layout();
481     const auto &op_seq_backend_id = getLowerInfo(op_seq_index)->backend()->config()->id();
482     const auto &node_backend_id = backend_resolver.getBackend(node_index)->config()->id();
483     VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " { " << op_seq_backend_id << "("
484                    << to_string(op_seq_backend_layout) << ") } "
485                    << " NODE#" << node_index.value() << " (" << node.name() << ") { "
486                    << node_backend_id << "(" << to_string(layout) << ") } " << std::endl;
487     if (op_seq_backend_id != node_backend_id || op_seq_backend_layout != layout)
488       return false;
489   }
490
491   // Branched?
492   {
493     std::unordered_set<ir::OperationIndex> branched_set;
494
495     // Check for branching up
496     for (const auto &input : op_seq.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
497     {
498       const auto &input_obj = _graph.operands().at(input);
499       auto def = input_obj.getDef();
500       if (def.valid())
501       {
502         branched_set.insert(def);
503         if (branched_set.size() > 1)
504         {
505           return false;
506         }
507       }
508     }
509     branched_set.clear();
510
511     // Check for branching down
512     for (const auto &output : node.getOutputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
513     {
514       // TODO Fix this workaround for the case of model outputs that are used by another operation
515       //      This is needed since the branching is decided by operation, but for model outputs,
516       //      there is controlflow backen(use backend) but no actual use operation exists
517       if (_graph.getOutputs().contains(output))
518         return false;
519
520       const auto &output_obj = _graph.operands().at(output);
521       for (const auto &use : output_obj.getUses())
522       {
523         branched_set.insert(use);
524         if (branched_set.size() > 1)
525         {
526           return false;
527         }
528       }
529     }
530   }
531
532   // Connected?
533   // an input of one node is an output of the other node? or vice-versa?
534   {
535     const auto &node_inputs = node.getInputs();
536     const auto &node_outputs = node.getOutputs();
537
538     // op_seq's operations are in order so that we just check the first and the last
539     std::vector<ir::OperationIndex> op_seq_ops{op_seq.operations()[0]};
540     if (op_seq.operations().size() > 1)
541       op_seq_ops.emplace_back(op_seq.operations()[op_seq.operations().size() - 1]);
542
543     for (const auto &n_index : op_seq_ops)
544     {
545       const auto &n = _graph.operations().at(n_index);
546
547       // node's output == op_seq's input?
548       for (const auto input : n.getInputs() | ir::Remove::UNDEFINED)
549       {
550         if (node_outputs.contains(input))
551         {
552           VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " 's NODE#" << n_index.value()
553                          << "(" << n.name() << ") is connected to NODE#" << node_index.value()
554                          << "(" << node.name() << ")" << std::endl;
555           return true;
556         }
557       }
558
559       // node's input == op_seq's output?
560       for (const auto output : n.getOutputs() | ir::Remove::UNDEFINED)
561       {
562         if (node_inputs.contains(output))
563         {
564           VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " 's NODE#" << n_index.value()
565                          << " (" << n.name() << ") is connected to NODE#" << node_index.value()
566                          << std::endl;
567           return true;
568         }
569       }
570     }
571
572     VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " is not connected to NODE#"
573                    << node_index.value() << "(" << node.name() << ")" << std::endl;
574   }
575
576   return false;
577 }
578
579 } // namespace compiler
580 } // namespace onert