a0d478603c490503a21cad71a258249ff2d057fe
[platform/core/ml/nnfw.git] / runtime / onert / core / src / backend / controlflow / kernel / WhileLayer.cc
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "WhileLayer.h"
18
19 #include <backend/ITensor.h>
20 #include "exec/ExecutorBase.h"
21 #include <misc/polymorphic_downcast.h>
22 #include "PermuteLayer.h"
23
24 namespace onert
25 {
26 namespace backend
27 {
28 namespace controlflow
29 {
30 namespace kernel
31 {
32
33 WhileLayer::WhileLayer(const std::vector<backend::ITensor *> input_tensors,
34                        const std::vector<backend::ITensor *> output_tensors,
35                        const ir::OperandIndexSequence &output_indices, const ir::Graph &graph,
36                        const ir::SubgraphIndex &cond_subg_index,
37                        const ir::SubgraphIndex &body_subg_index, exec::ExecutorMap *executor_map,
38                        const std::shared_ptr<ExternalContext> &external_context)
39     : _cond_subg_index{cond_subg_index}, _body_subg_index{body_subg_index},
40       _output_indices{output_indices}, _graph{graph}, _input_tensors{input_tensors},
41       _output_tensors{output_tensors}, _executor_map{executor_map},
42       _external_context{external_context}
43 {
44   // At this point, executor_map may not have executors of cond subg and body subg
45 }
46
47 void WhileLayer::run()
48 {
49   // Copy "_input_tensors" -> "cond subg inputs"
50   // Run cond subg
51   // Start loop while output of cond subg is ture
52   // // Copy "_input_tensors" -> "body subg inputs" in the first iteration, then copy "body subg
53   // outputs" -> "body subg inputs" in the second or more iterations
54   // // Run body subg
55   // // Copy "body subg outputs" -> "cond subg inputs"
56   // // Run cond subg
57   // If there is no loop copy "_input_tensors" -> "_dst_tensors", else copy "cond subg inputs" ->
58   // "_dst_tensors"
59   auto cond_exec = nnfw::misc::polymorphic_downcast<exec::ExecutorBase *>(
60       _executor_map->at(_cond_subg_index).get());
61   auto body_exec = nnfw::misc::polymorphic_downcast<exec::ExecutorBase *>(
62       _executor_map->at(_body_subg_index).get());
63
64   const auto &cond_graph = cond_exec->graph();
65   const auto &body_graph = body_exec->graph();
66
67   std::vector<backend::ITensor *> input_tensors;
68   std::vector<backend::ITensor *> cond_input_tensors;
69   std::vector<backend::ITensor *> body_input_tensors;
70   std::vector<backend::ITensor *> body_output_tensors;
71   std::vector<backend::ITensor *> output_tensors;
72
73   // Add only used tensors in cond subgraph
74   assert(cond_graph.getInputs().size() == _input_tensors.size());
75   assert(cond_graph.getInputs().size() == cond_exec->getInputTensors().size());
76   for (uint32_t i = 0; i < cond_graph.getInputs().size(); ++i)
77   {
78     const auto &cond_input = cond_graph.operands().at(cond_graph.getInputs().at(i));
79     if (cond_input.getUses().size() > 0)
80     {
81       input_tensors.emplace_back(_input_tensors.at(i));
82       cond_input_tensors.emplace_back(cond_exec->getInputTensors().at(i));
83     }
84   }
85   const auto permute_op_input_to_cond_input =
86       std::make_shared<PermuteLayer>(input_tensors, cond_input_tensors, _external_context);
87
88   // Add only used tensors among outputs of while operation
89   assert(_output_indices.size() == _input_tensors.size());
90   assert(_output_indices.size() == _output_tensors.size());
91   input_tensors.clear();
92   output_tensors.clear();
93   for (size_t i = 0; i < _output_indices.size(); ++i)
94   {
95     const auto &output_index = _output_indices.at(i);
96     const auto &output = _graph.operands().at(output_index);
97     if (output.getUses().size() > 0 || _graph.getOutputs().contains(output_index))
98     {
99       input_tensors.emplace_back(_input_tensors.at(i));
100       output_tensors.emplace_back(_output_tensors.at(i));
101     }
102   }
103   const auto permute_op_input_to_op_output =
104       std::make_shared<PermuteLayer>(input_tensors, output_tensors, _external_context);
105
106   // Add all tensors with unused tensors in body subgraph because unused input tensors will be
107   // copied output tensors in body subgraph
108   assert(_input_tensors.size() == body_exec->getInputTensors().size());
109   input_tensors = _input_tensors;
110   body_input_tensors = body_exec->getInputTensors();
111   const auto permute_op_input_to_body_input =
112       std::make_shared<PermuteLayer>(input_tensors, body_input_tensors, _external_context);
113
114   // Add only used tensors in cond subgraph
115   assert(cond_graph.getInputs().size() == body_exec->getOutputTensors().size());
116   assert(cond_graph.getInputs().size() == cond_exec->getInputTensors().size());
117   body_output_tensors.clear();
118   cond_input_tensors.clear();
119   for (uint32_t i = 0; i < cond_graph.getInputs().size(); ++i)
120   {
121     const auto &cond_input = cond_graph.operands().at(cond_graph.getInputs().at(i));
122     if (cond_input.getUses().size() > 0)
123     {
124       body_output_tensors.emplace_back(body_exec->getOutputTensors().at(i));
125       cond_input_tensors.emplace_back(cond_exec->getInputTensors().at(i));
126     }
127   }
128   const auto permute_body_output_to_cond_input =
129       std::make_shared<PermuteLayer>(body_output_tensors, cond_input_tensors, _external_context);
130
131   // Add only used tensors in body subgraph
132   assert(body_graph.getInputs().size() == body_exec->getOutputTensors().size());
133   assert(body_graph.getInputs().size() == body_exec->getInputTensors().size());
134   body_output_tensors.clear();
135   body_input_tensors.clear();
136   for (uint32_t i = 0; i < body_graph.getInputs().size(); ++i)
137   {
138     const auto &body_input_index = body_graph.getInputs().at(i);
139     const auto &body_input = body_graph.operands().at(body_input_index);
140     if (body_input.getUses().size() > 0 &&
141         !body_exec->graph().getOutputs().contains(body_input_index))
142     {
143       body_output_tensors.emplace_back(body_exec->getOutputTensors().at(i));
144       body_input_tensors.emplace_back(body_exec->getInputTensors().at(i));
145     }
146   }
147   const auto permute_body_output_to_body_input =
148       std::make_shared<PermuteLayer>(body_output_tensors, body_input_tensors, _external_context);
149
150   // Add only used tensors among outputs of while operation
151   assert(_output_indices.size() == body_exec->getOutputTensors().size());
152   assert(_output_indices.size() == _output_tensors.size());
153   body_output_tensors.clear();
154   output_tensors.clear();
155   for (size_t i = 0; i < _output_indices.size(); ++i)
156   {
157     const auto &output_index = _output_indices.at(i);
158     const auto &output = _graph.operands().at(output_index);
159     if (output.getUses().size() > 0 || _graph.getOutputs().contains(output_index))
160     {
161       body_output_tensors.emplace_back(body_exec->getOutputTensors().at(i));
162       output_tensors.emplace_back(_output_tensors.at(i));
163     }
164   }
165   const auto permute_body_output_to_op_output =
166       std::make_shared<PermuteLayer>(body_output_tensors, output_tensors, _external_context);
167
168   // Remove copying of unused tensor
169   permute_op_input_to_cond_input->prepare();
170   permute_op_input_to_op_output->prepare();
171   permute_op_input_to_body_input->prepare();
172   permute_body_output_to_cond_input->prepare();
173   permute_body_output_to_body_input->prepare();
174   permute_body_output_to_op_output->prepare();
175
176   VERBOSE(While) << "Call to $" << _cond_subg_index << " (cond)" << std::endl;
177   cond_exec->execute(_input_tensors, permute_op_input_to_cond_input);
178   VERBOSE(While) << "Return from $" << _cond_subg_index << std::endl;
179
180   assert(cond_exec->getOutputTensors().size() == 1);
181   auto &cond_output_tensor = cond_exec->getOutputTensors().at(0);
182   auto getResultCond = [](backend::ITensor *tensor) -> bool {
183     bool ret = false;
184     tensor->access([&](ITensor &tensor) { ret = *reinterpret_cast<bool *>(tensor.buffer()); });
185     return ret;
186   };
187
188   const auto body_execute_with_op_inputs = [&]() {
189     VERBOSE(While) << "Call to $" << _body_subg_index << " (body)" << std::endl;
190     body_exec->execute(_input_tensors, permute_op_input_to_body_input);
191     VERBOSE(While) << "Return from $" << _body_subg_index << std::endl;
192   };
193
194   const auto body_execute_with_body_outputs = [&]() {
195     VERBOSE(While) << "Call to $" << _body_subg_index << " (body)" << std::endl;
196     body_exec->execute(body_exec->getOutputTensors(), permute_body_output_to_body_input);
197     VERBOSE(While) << "Return from $" << _body_subg_index << std::endl;
198   };
199
200   std::function<void()> body_execute = body_execute_with_op_inputs;
201   const auto cond_execute = [&]() {
202     VERBOSE(While) << "Call to $" << _cond_subg_index << " (cond)" << std::endl;
203     cond_exec->execute(body_exec->getOutputTensors(), permute_body_output_to_cond_input);
204     VERBOSE(While) << "Return from $" << _cond_subg_index << std::endl;
205   };
206   auto permute_to_outputs_fn = permute_op_input_to_op_output;
207
208   // Loop while Cond subgraph's output is true
209   while (getResultCond(cond_output_tensor))
210   {
211     body_execute();
212     cond_execute();
213     body_execute = body_execute_with_body_outputs;
214     permute_to_outputs_fn = permute_body_output_to_op_output;
215   }
216   permute_to_outputs_fn->run();
217 }
218
219 } // namespace kernel
220 } // namespace controlflow
221 } // namespace backend
222 } // namespace onert