Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / backend / controlflow / KernelGenerator.cc
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "KernelGenerator.h"
18
19 #include <backend/BackendContext.h>
20 #include <util/Utils.h>
21 #include "kernel/IfLayer.h"
22 #include "kernel/WhileLayer.h"
23 #include "kernel/PermuteLayer.h"
24 #include "exec/ExecutorBase.h"
25 #include "exec/FunctionSequence.h"
26
27 namespace onert
28 {
29 namespace backend
30 {
31 namespace controlflow
32 {
33
34 KernelGenerator::KernelGenerator(const ir::Graph &graph, DynamicTensorManager *dyn_tensor_manager,
35                                  const std::shared_ptr<TensorRegistry> &tensor_reg,
36                                  const std::shared_ptr<ExternalContext> &external_context)
37     : _graph{graph}, _dyn_tensor_manager{dyn_tensor_manager}, _tensor_reg{tensor_reg},
38       _tensor_registries{}, _executor_map{nullptr}, _external_context{external_context}
39 {
40   UNUSED_RELEASE(_graph);
41   UNUSED_RELEASE(_tensor_registries);
42   UNUSED_RELEASE(_executor_map);
43 }
44
45 void KernelGenerator::visit(const ir::OpSequence &op_seq)
46 {
47   assert(!_return_fn_seq);
48   assert(_dyn_tensor_manager);
49   assert(_tensor_reg);
50
51   auto dyn_shape_inferer =
52       std::make_unique<exec::DynamicShapeInferer>(_graph.operands(), _tensor_reg);
53
54   _return_fn_seq = std::make_unique<exec::FunctionSequence>();
55
56   // Prepare to handle dynamic tensors later
57   auto dyn_ctx = std::make_shared<exec::FunctionSequence::DynamicTensorCtx>();
58   {
59     dyn_ctx->op_seq = &op_seq;
60     dyn_ctx->operations = &_graph.operations();
61     dyn_ctx->dynamic_shape_inferer = std::move(dyn_shape_inferer);
62     dyn_ctx->dynamic_tensor_manager = _dyn_tensor_manager;
63
64     _return_fn_seq->dynamic_tensor_ctx(dyn_ctx);
65   }
66
67   for (const auto &op_idx : op_seq.operations())
68   {
69     const auto &node = _graph.operations().at(op_idx);
70     node.accept(*this);
71     _return_fn_seq->append(releaseFunction());
72   }
73 }
74
75 void KernelGenerator::visit(const ir::operation::If &node)
76 {
77   const auto then_subg_index = node.param().then_subg_index;
78   const auto else_subg_index = node.param().else_subg_index;
79
80   std::vector<backend::IPortableTensor *> input_tensors;
81   for (const auto input_index : node.getInputs())
82   {
83     auto input_tensor = getPortableTensor(input_index);
84     input_tensors.emplace_back(input_tensor);
85   }
86
87   std::vector<backend::IPortableTensor *> output_tensors;
88   for (const auto output_index : node.getOutputs())
89   {
90     auto output_tensor = getPortableTensor(output_index);
91     output_tensors.emplace_back(output_tensor);
92   }
93
94   // IfLayer just set ExecutorMap instead of then and else executor to avoid complexity of
95   // creating executor recusively
96   const auto cond_tensor = input_tensors.front();
97   input_tensors.erase(input_tensors.begin());
98   auto fn = std::make_unique<::onert::backend::controlflow::kernel::IfLayer>(
99       cond_tensor, input_tensors, output_tensors, then_subg_index, else_subg_index, _executor_map,
100       _external_context);
101
102   _return_fn = std::move(fn);
103 }
104
105 void KernelGenerator::visit(const ir::operation::Permute &node)
106 {
107   const auto output_index{node.getOutputs().at(0)};
108   const auto input_index{node.getInputs().at(0)};
109
110   // Add PermuteLayer
111   std::vector<ITensor *> output_tensors{getTensor(output_index)};
112   std::vector<ITensor *> input_tensors{getTensor(input_index)};
113
114   auto fn =
115       std::make_unique<kernel::PermuteLayer>(input_tensors, output_tensors, _external_context);
116   _return_fn = std::move(fn);
117 }
118
119 void KernelGenerator::visit(const ir::operation::While &node)
120 {
121   const auto cond_subg_index = node.param().cond_subg_index;
122   const auto body_subg_index = node.param().body_subg_index;
123
124   // This op does not support input as a constant, because controlflow backend does not have
125   // TensorBuilder
126   std::vector<backend::IPortableTensor *> input_tensors;
127   for (const auto input_index : node.getInputs())
128   {
129     auto input_tensor = getPortableTensor(input_index);
130     input_tensors.emplace_back(input_tensor);
131   }
132
133   std::vector<backend::IPortableTensor *> output_tensors;
134   for (const auto output_index : node.getOutputs())
135   {
136     auto output_tensor = getPortableTensor(output_index);
137     output_tensors.emplace_back(output_tensor);
138   }
139
140   // WhileLayer just set ExecutorMap instead of cond and body executor to avoid complexity of
141   // creating executor recusively
142   auto fn = std::make_unique<::onert::backend::controlflow::kernel::WhileLayer>(
143       input_tensors, output_tensors, cond_subg_index, body_subg_index, _executor_map,
144       _dyn_tensor_manager->dynamic_mem_mgr().get(), _external_context);
145
146   _return_fn = std::move(fn);
147 }
148
149 backend::ITensor *KernelGenerator::getTensor(const ir::OperandIndex &index)
150 {
151   // get Tensor from all tensor registries (for Permute op)
152   auto ret = _tensor_registries.getITensor(index);
153   assert(ret != nullptr);
154   return ret;
155 }
156
157 backend::IPortableTensor *KernelGenerator::getPortableTensor(const ir::OperandIndex &index)
158 {
159   auto ret = _tensor_reg->getPortableTensor(index);
160   assert(ret != nullptr);
161   return ret;
162 }
163
164 } // namespace controlflow
165 } // namespace backend
166 } // namespace onert