Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / ir / LoweredGraph.cc
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "ir/LoweredGraph.h"
18
19 #include <assert.h>
20 #include <sstream>
21 #include "util/logging.h"
22 #include "pass/ConstantInsertionPass.h"
23 #include "pass/ConstantLoweringPass.h"
24 #include "pass/PermutationOperationPass.h"
25 #include "pass/PermutationInsertionPass.h"
26 #include "ir/GraphIterator.h"
27 #include "verifier/Verifier.h"
28 #include "backend/Backend.h"
29 #include "backend/IConfig.h"
30 #include "compiler/BackendResolver.h"
31 #include "compiler/ManualScheduler.h"
32 #include "compiler/HEScheduler.h"
33
34 namespace onert
35 {
36 namespace ir
37 {
38
39 LoweredGraph::LoweredGraph(const Graph &graph, const compiler::CompilerOptions &options)
40     : _graph{graph}
41 {
42   bool linear_executor = (options.executor == "Linear");
43
44   // Build backend contexts
45   auto &backend_manager = compiler::BackendManager::get();
46
47   // Always create Controlflow backend context
48   auto cf_backend = backend_manager.getControlflow();
49   _backend_contexts.emplace(
50       cf_backend, cf_backend->newContext(_graph, _graph.getKernelBuilder(), linear_executor));
51
52   // Create contexts for other backends
53   for (auto backend_str : options.backend_list)
54   {
55     backend_manager.loadBackend(backend_str);
56     auto backend = backend_manager.get(backend_str);
57
58     // TODO As the default value of backend list contains "cpu", "acl_cl" and "acl_neon", and some
59     // are not available on x64 or some other platforms. So this may be a workaround for x64 and
60     // we should change it back(throw if backend is not loaded) later.
61     if (!backend)
62     {
63       VERBOSE(LoweredGraph) << "Cannot load backend - " << backend_str;
64       continue;
65     }
66
67     _backend_contexts.emplace(
68         backend, backend->newContext(_graph, _graph.getKernelBuilder(), linear_executor));
69   }
70   if (backend_manager.num_backends() == 0)
71     throw std::runtime_error{"No available backends loaded."};
72
73   // TODO Move "schedule" phase out of here
74   // Schedule
75   std::unique_ptr<compiler::BackendResolver> backend_resolver;
76   if (options.he_scheduler)
77   {
78     auto scheduler = compiler::HEScheduler(_backend_contexts, options);
79     backend_resolver = scheduler.schedule(_graph);
80     _indexed_ranks = scheduler.getIndexedRanks();
81   }
82   else
83   {
84     auto scheduler = compiler::ManualScheduler(_backend_contexts, options);
85     backend_resolver = scheduler.schedule(_graph);
86   }
87
88   {
89     // operand::LowerInfo holder
90     OperandIndexMap<std::unique_ptr<operand::LowerInfo>> operands_lower_info;
91
92     _graph.operands().iterate([&](const OperandIndex &index, const Operand &) {
93       operands_lower_info[index] = std::make_unique<operand::LowerInfo>();
94     });
95
96     // Make op_seqs while checking whether a node can be merged into a op_seq.
97     makeOpSequences(operands_lower_info, options, *backend_resolver);
98
99     _op_seqs.iterate([&](const OpSequenceIndex &, OpSequence &op_seq) {
100       assert(op_seq.operations().size() > 0);
101       std::reverse(std::begin(op_seq.operations()), std::end(op_seq.operations()));
102     });
103
104     _op_seqs.dump("merged and sorted operations without permutation", _graph.operations());
105
106     pass::ConstantInsertionPass ci_pass(*this);
107     ci_pass.run();
108
109     pass::ConstantLoweringPass cl_pass(*this);
110     cl_pass.run();
111
112     // Set LowerInfo for each operand from the operand::LowerInfo holder
113     manipulateLowerInfo(operands_lower_info, options.is_primary_subgraph);
114
115     dumpLowerInfo();
116   }
117
118   // Run Permutation Passes
119   {
120     pass::PermutationOperationPass po_pass(*this);
121     po_pass.run();
122
123     pass::PermutationInsertionPass pi_pass(*this);
124     pi_pass.run();
125     // Implemented code no longer works.
126     // pass::PermutationEliminationPass pe_pass(*this);
127     // pe_pass.run();
128
129     _op_seqs.dump("merged and sorted operations with permutation", _graph.operations());
130   }
131
132   // Graph verifications
133   {
134     assert(verifier::DAGChecker().verify(_graph));
135     assert(verifier::EdgeConsistencyChecker().verify(_graph));
136   }
137 }
138
139 const operation::LowerInfo *LoweredGraph::getLowerInfo(const OpSequenceIndex &op_seq_index) const
140 {
141   auto itr = _lower_info_map.op_seq.find(op_seq_index);
142   if (itr == _lower_info_map.op_seq.end())
143     return nullptr;
144   return itr->second.get();
145 }
146
147 void LoweredGraph::setLowerInfo(const OpSequenceIndex &op_seq_index,
148                                 std::unique_ptr<operation::LowerInfo> &&lower_info)
149 {
150   _lower_info_map.op_seq.insert(std::make_pair(op_seq_index, std::move(lower_info)));
151 }
152
153 void LoweredGraph::removeLowerInfo(const OpSequenceIndex &op_seq_index)
154 {
155   auto &op_seq_lower_info = _lower_info_map.op_seq;
156   assert(op_seq_lower_info.find(op_seq_index) != op_seq_lower_info.end());
157   for (auto it = op_seq_lower_info.begin(); it != op_seq_lower_info.end(); ++it)
158   {
159     if (it->first == op_seq_index)
160     {
161       op_seq_lower_info.erase(it);
162       break;
163     }
164   }
165 }
166
167 const operand::LowerInfo *LoweredGraph::getLowerInfo(const OperandIndex &index) const
168 {
169   auto itr = _lower_info_map.operand.find(index);
170   if (itr == _lower_info_map.operand.end())
171     return nullptr;
172   return itr->second.get();
173 }
174
175 operand::LowerInfo *LoweredGraph::getLowerInfo(const OperandIndex &index)
176 {
177   auto itr = _lower_info_map.operand.find(index);
178   if (itr == _lower_info_map.operand.end())
179     return nullptr;
180   return itr->second.get();
181 }
182
183 void LoweredGraph::setLowerInfo(const OperandIndex &index,
184                                 std::unique_ptr<operand::LowerInfo> &&lower_info)
185 {
186   _lower_info_map.operand.insert(std::make_pair(index, std::move(lower_info)));
187 }
188
189 void LoweredGraph::removeLowerInfo(const OperandIndex &index)
190 {
191   _lower_info_map.operand.erase(index);
192 }
193
194 void LoweredGraph::iterateTopolOpSeqs(
195     const std::function<void(const OpSequenceIndex &, const OpSequence &)> &fn) const
196 {
197   // Topological Sorting for OpSequences
198   std::vector<OpSequenceIndex> topol_sorted;
199   PostDfsIterator<true>{}.iterateOpSeqs(
200       *this,
201       [&](const OpSequenceIndex &index, const OpSequence &) { topol_sorted.emplace_back(index); });
202   std::reverse(topol_sorted.begin(), topol_sorted.end());
203   for (const auto op_seq_idx : topol_sorted)
204   {
205     const auto &op_seq = _op_seqs.at(op_seq_idx);
206     fn(op_seq_idx, op_seq);
207   }
208 }
209
210 void LoweredGraph::iterateTopolOpSeqs(
211     const std::function<void(const OpSequenceIndex &, OpSequence &)> &fn)
212 {
213   // Topological Sorting for OpSequences
214   std::vector<OpSequenceIndex> topol_sorted;
215   PostDfsIterator<false>{}.iterateOpSeqs(
216       *this, [&](const OpSequenceIndex &index, OpSequence &) { topol_sorted.emplace_back(index); });
217   std::reverse(topol_sorted.begin(), topol_sorted.end());
218   for (const auto op_seq_idx : topol_sorted)
219   {
220     auto &op_seq = _op_seqs.at(op_seq_idx);
221     fn(op_seq_idx, op_seq);
222   }
223 }
224
225 OpSequenceIndex LoweredGraph::appendFreshSingleOpSequence(const OperationIndex &node_index,
226                                                           const Operation &node)
227 {
228   // Create a fresh op_seq with one operation, and append it to op_seqs
229   // Create a fresh op_seq
230   auto op_seq = std::make_unique<OpSequence>(_graph.layout());
231
232   // Add an operation
233   op_seq->appendOperation(node_index);
234
235   // Update input/output
236   op_seq->setOutputs(node.getOutputs());
237   op_seq->setInputs(node.getInputs());
238
239   return _op_seqs.emplace(std::move(op_seq));
240 }
241
242 void LoweredGraph::makeOpSequences(
243     OperandIndexMap<std::unique_ptr<operand::LowerInfo>> &operands_lower_info,
244     const compiler::CompilerOptions &options, const compiler::BackendResolver &backend_resolver)
245 {
246   // if SUBG_MAX_NODE == 0, no limit on nodes of a op_seq
247   const int op_seq_max_node = options.op_seq_max_node;
248   assert(op_seq_max_node >= 0);
249
250   bool is_profiling = options.he_profiling_mode;
251   OpSequence *op_seq = nullptr;
252   OpSequenceIndex op_seq_index;
253
254   // NOTE: The below method appends nodes while making one op_seq if needed. If something better
255   // ways, happy to update this code.
256   PostDfsConstIterator{}.iterate(
257       _graph, [&](const OperationIndex &node_index, const Operation &node) {
258         // LowerInfo for in/output operands
259         auto backend = backend_resolver.getBackend(node_index);
260
261         // Get frontend's layout
262         auto frontend_layout = _graph.layout();
263
264         // The layout of each backend should be set at another place
265         // TODO Change setting layout of each backend at another place
266         auto backend_layout = backend->config()->supportLayout(node, frontend_layout);
267
268         for (auto operand : node.getInputs() | ir::Remove::UNDEFINED)
269         {
270           auto &&lower_info = operands_lower_info.at(operand);
271           lower_info->addUsePermuteFactor(operand::PermuteFactor{backend, backend_layout});
272         }
273         for (auto operand : node.getOutputs())
274         {
275           auto &&lower_info = operands_lower_info.at(operand);
276           lower_info->addDefPermuteFactor(operand::PermuteFactor{backend, backend_layout});
277         }
278
279         bool new_op_seq = (op_seq == nullptr ||
280                            (op_seq_max_node != 0 &&
281                             op_seq->operations().size() >= static_cast<size_t>(op_seq_max_node)));
282
283         // for profiling each op_seq must contain just one node,
284         // so that we can measure a node separately
285         if (new_op_seq || is_profiling ||
286             !mergeable(op_seq_index, node_index, backend_layout, backend_resolver))
287         {
288           auto new_op_seq_index = appendFreshSingleOpSequence(node_index, node);
289
290           // OpSequence LowerInfo
291           setLowerInfo(new_op_seq_index,
292                        std::make_unique<operation::LowerInfo>(backend, backend_layout));
293
294           op_seq_index = new_op_seq_index;
295           op_seq = &(_op_seqs.at(new_op_seq_index));
296
297           VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " is created for "
298                          << "NODE#" << node_index.value() << "(" << node.name() << ")" << std::endl;
299         }
300         else
301         {
302           op_seq->appendOperation(node_index);
303           // Set inputs
304           auto new_inputs = node.getInputs();
305           // Add inputs except outputs of the previous node
306           for (auto ind : op_seq->getInputs())
307           {
308             if (!node.getOutputs().contains(ind))
309               new_inputs.append(ind);
310           }
311           op_seq->setInputs(new_inputs);
312
313           VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " merges "
314                          << "NODE#" << node_index.value() << "(" << node.name() << ")" << std::endl;
315         }
316       });
317 }
318
319 void LoweredGraph::manipulateLowerInfo(
320     OperandIndexMap<std::unique_ptr<operand::LowerInfo>> &operands_lower_info, bool is_primary)
321 {
322   const auto controlflow_backend = compiler::BackendManager::get().getControlflow();
323
324   // TODO Rather than handling primary graph specially,
325   //      let the permute inserted and remove it later
326   if (is_primary)
327   {
328     // TODO Rather than using NHWC Get frontend layout of this node from IR
329     auto factor = operand::PermuteFactor{controlflow_backend, Layout::NHWC};
330     for (auto index : _graph.getInputs() | ir::Remove::UNDEFINED)
331     {
332       auto &&lower_info = operands_lower_info.at(index);
333       assert(lower_info->def_factors().empty());
334       lower_info->addDefPermuteFactor(factor);
335     }
336     for (auto index : _graph.getOutputs())
337     {
338       auto &&lower_info = operands_lower_info.at(index);
339       lower_info->addUsePermuteFactor(factor);
340     }
341   }
342   else
343   {
344     for (auto index : _graph.getInputs() | ir::Remove::UNDEFINED)
345     {
346       auto &&lower_info = operands_lower_info.at(index);
347       if (!(lower_info->def_factors().size() == 0 && lower_info->use_factors().size() == 0))
348       {
349         // In case of not that Graph's input is not used in any operation and not the graph's
350         // output.
351         // In other words, it is not unused input in Graph.
352         lower_info->addDefPermuteFactor(*lower_info->use_factors().begin());
353       }
354       else
355       {
356         // In case of that an operand is Graph's input and not input or output of any operation
357         lower_info->addDefPermuteFactor(operand::PermuteFactor{
358             controlflow_backend,
359             Layout::NHWC // TODO Get frontend layout of this node from IR
360         });
361       }
362     }
363   }
364   for (auto index : _graph.getOutputs())
365   {
366     auto &&lower_info = operands_lower_info.at(index);
367     if (lower_info->def_factors().size() == 0)
368     {
369       // In case of that an operand is Graph's output and not input or output of any operation
370       lower_info->addDefPermuteFactor(operand::PermuteFactor{
371           controlflow_backend,
372           Layout::NHWC // TODO Get frontend layout of this node from IR
373       });
374     }
375   }
376
377   // Set LowerInfo for each operand from the operand::LowerInfo holder
378   _graph.operands().iterate([&](const OperandIndex &index, Operand &) {
379     setLowerInfo(index, std::move(operands_lower_info[index]));
380   });
381 }
382
383 void LoweredGraph::dumpLowerInfo()
384 {
385   if (::onert::util::logging::ctx.enabled() == false)
386     return;
387
388   std::map<uint32_t, std::string> dumps;
389
390   _graph.operands().iterate([&](const OperandIndex &index, Operand &object) {
391     std::stringstream sstream;
392     if (!getLowerInfo(index)->def_factors().empty() || !getLowerInfo(index)->use_factors().empty())
393     {
394       auto factors_to_string = [](const operand::PermuteFactorSet &factors) {
395         std::string str;
396         for (auto factor : factors)
397         {
398           str += factor.backend()->config()->id();
399           str += "(" + to_string(factor.layout()) + ")";
400           str += " ";
401         }
402         return "{ " + str + "}";
403       };
404
405       auto operation_index_to_string = [](const OperationIndexSet &operations) {
406         std::string str;
407         for (auto op : operations)
408         {
409           str += std::to_string(op.value());
410           str += " ";
411         }
412         return "{ " + str + "}";
413       };
414
415       const auto lower_info = getLowerInfo(index);
416       const auto &shape = object.shape();
417       std::string def_ops = operation_index_to_string(object.getDef());
418       std::string use_ops = operation_index_to_string(object.getUses());
419       std::string def_layouts = factors_to_string(lower_info->def_factors());
420       std::string use_layouts = factors_to_string(lower_info->use_factors());
421       sstream << "Operand #" << index.value() << " LowerInfo" << std::endl;
422       sstream << "  - Shape           : { ";
423       for (auto i = 0; i < shape.rank(); ++i)
424       {
425         sstream << (shape.dim(i)) << " ";
426       }
427       sstream << "}" << std::endl;
428       sstream << "  - Def Operations  : " << def_ops << std::endl;
429       sstream << "  - Use Operations  : " << use_ops << std::endl;
430       sstream << "  - Lower Info" << std::endl;
431       sstream << "    - Def Backends    : " << def_layouts << std::endl;
432       sstream << "    - Use Backends    : " << use_layouts << std::endl;
433     }
434     dumps.emplace(index.value(), sstream.str());
435   });
436
437   for (const auto &e : dumps)
438   {
439     if (!e.second.empty())
440     {
441       VERBOSE(Lower) << e.second;
442     }
443   }
444 }
445
446 bool LoweredGraph::mergeable(const OpSequenceIndex &op_seq_index, const OperationIndex &node_index,
447                              Layout layout, const compiler::BackendResolver &backend_resolver)
448 {
449   // Are they mergeable?
450   // 1. the same backend id and layout?
451   // 2. Is op_seq or node branched?
452   // 3. if 1 is true, the op_seq and a node are connected?
453   const auto &op_seq = _op_seqs.at(op_seq_index);
454   const auto &node = _graph.operations().at(node_index);
455
456   // The same backend id and layout?
457   {
458     const auto op_seq_backend_layout = getLowerInfo(op_seq_index)->layout();
459     const auto &op_seq_backend_id = getLowerInfo(op_seq_index)->backend()->config()->id();
460     const auto &node_backend_id = backend_resolver.getBackend(node_index)->config()->id();
461     VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " { " << op_seq_backend_id << "("
462                    << to_string(op_seq_backend_layout) << ") } "
463                    << " NODE#" << node_index.value() << " (" << node.name() << ") { "
464                    << node_backend_id << "(" << to_string(layout) << ") } " << std::endl;
465     if (op_seq_backend_id != node_backend_id || op_seq_backend_layout != layout)
466       return false;
467   }
468
469   // Branched?
470   {
471     std::unordered_set<OperationIndex> branched_set;
472
473     // Check for branching up
474     for (const auto &input : op_seq.getInputs() | Remove::DUPLICATED | ir::Remove::UNDEFINED)
475     {
476       const auto &input_obj = _graph.operands().at(input);
477       for (const auto &def : input_obj.getDef())
478       {
479         branched_set.insert(def);
480         if (branched_set.size() > 1)
481         {
482           return false;
483         }
484       }
485     }
486     branched_set.clear();
487
488     // Check for branching down
489     for (const auto &output : node.getOutputs() | Remove::DUPLICATED)
490     {
491       const auto &output_obj = _graph.operands().at(output);
492       for (const auto &use : output_obj.getUses())
493       {
494         branched_set.insert(use);
495         if (branched_set.size() > 1)
496         {
497           return false;
498         }
499       }
500     }
501   }
502
503   // Connected?
504   // an input of one node is an output of the other node? or vice-versa?
505   {
506     const auto &node_inputs = node.getInputs();
507     const auto &node_outputs = node.getOutputs();
508
509     // op_seq's operations are in order so that we just check the first and the last
510     std::vector<OperationIndex> op_seq_ops{op_seq.operations()[0]};
511     if (op_seq.operations().size() > 1)
512       op_seq_ops.emplace_back(op_seq.operations()[op_seq.operations().size() - 1]);
513
514     for (const auto &n_index : op_seq_ops)
515     {
516       const auto &n = _graph.operations().at(n_index);
517
518       // node's output == op_seq's input?
519       for (const auto input : n.getInputs() | ir::Remove::UNDEFINED)
520       {
521         if (node_outputs.contains(input))
522         {
523           VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " 's NODE#" << n_index.value()
524                          << "(" << n.name() << ") is connected to NODE#" << node_index.value()
525                          << "(" << node.name() << ")" << std::endl;
526           return true;
527         }
528       }
529
530       // node's input == op_seq's output?
531       for (const auto output : n.getOutputs())
532       {
533         if (node_inputs.contains(output))
534         {
535           VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " 's NODE#" << n_index.value()
536                          << " (" << n.name() << ") is connected to NODE#" << node_index.value()
537                          << std::endl;
538           return true;
539         }
540       }
541     }
542
543     VERBOSE(Lower) << "OpSequence#" << op_seq_index.value() << " is not connected to NODE#"
544                    << node_index.value() << "(" << node.name() << ")" << std::endl;
545   }
546
547   return false;
548 }
549
550 } // namespace ir
551 } // namespace onert