Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / While.cpp
1 /*
2  * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
#include "Builders.h"
#include "kernels/Utils.h"

#include <algorithm>
#include <cstring>
#include <vector>
22
23 namespace luci_interpreter
24 {
25
26 void configure_kernel_CircleWhile(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
27 {
28   auto *main_runtime_graph = runtime_graph;
29
30   auto *runtime_module = runtime_graph->getRuntimeModule();
31
32   const auto *options = cur_op->builtin_options_as_WhileOptions();
33   const auto body_subgraph_index = options->body_subgraph_index();
34   const auto cond_subgraph_index = options->cond_subgraph_index();
35
36   auto *cond_runtime_graph = runtime_module->getRuntimeGraphAt(cond_subgraph_index);
37   auto *body_runtime_graph = runtime_module->getRuntimeGraphAt(body_subgraph_index);
38
39   body_runtime_graph->selectOwnSubgraph();
40   const auto body_input_size = body_runtime_graph->getNumOfInputTensors();
41   const auto body_output_size = body_runtime_graph->getNumOfOutputTensors();
42   LUCI_INTERPRETER_CHECK(body_input_size == cur_op->inputs()->size());
43   LUCI_INTERPRETER_CHECK(body_output_size == cur_op->outputs()->size());
44   LUCI_INTERPRETER_CHECK(body_output_size == cur_op->inputs()->size());
45   body_runtime_graph->invalidate();
46   body_runtime_graph->configure(false);
47
48   cond_runtime_graph->selectOwnSubgraph();
49   const auto cond_input_size = cond_runtime_graph->getNumOfInputTensors();
50   const auto cond_output_size = cond_runtime_graph->getNumOfOutputTensors();
51   LUCI_INTERPRETER_CHECK(cond_input_size == cur_op->inputs()->size());
52   LUCI_INTERPRETER_CHECK(cond_output_size == 1);
53   const circle::Tensor *cond_output_tensor = cond_runtime_graph->getOutputTensorByIndex(0);
54   LUCI_INTERPRETER_CHECK(Tensor::element_type(cond_output_tensor) == DataType::BOOL);
55   cond_runtime_graph->invalidate();
56   cond_runtime_graph->configure(false);
57
58   main_runtime_graph->selectOwnSubgraph();
59 }
60
61 void execute_kernel_CircleWhile(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
62 {
63   auto *main_runtime_graph = runtime_graph;
64   auto *runtime_module = runtime_graph->getRuntimeModule();
65
66   const auto input_size = cur_op->inputs()->size();
67
68   std::vector<uint8_t *> operation_inputs_data(input_size);
69   std::vector<uint8_t *> operation_outputs_data;
70
71   std::vector<int32_t> input_sizes(input_size);
72
73   bool is_inplace = runtime_graph->is_inplace_op(cur_op);
74
75   for (int32_t i = 0; i < input_size; ++i)
76   {
77     const auto op_input_index = cur_op->inputs()->operator[](i);
78     const auto op_output_index = cur_op->outputs()->operator[](i);
79     assert(op_input_index != -1);
80     assert(op_output_index != -1);
81     const auto input = main_runtime_graph->getCircleTensorByIndex(op_input_index);
82     const auto output = main_runtime_graph->getCircleTensorByIndex(op_output_index);
83
84     input_sizes[i] = Tensor::num_elements(input) * size(Tensor::element_type(input));
85
86     auto *input_data = main_runtime_graph->getDataByTensor(input);
87
88     uint8_t *tensor_data = nullptr;
89     if (is_inplace)
90     {
91       if (input_data == nullptr)
92       {
93         tensor_data = new uint8_t[input_sizes[i]];
94         input_data = main_runtime_graph->getConstDataByTensor(input);
95         assert(input_data != nullptr);
96         std::memcpy(tensor_data, input_data, input_sizes[i]);
97       }
98       else
99       {
100         tensor_data = input_data;
101       }
102     }
103     else
104     {
105       if (input_data == nullptr)
106         input_data = main_runtime_graph->getConstDataByTensor(input);
107       assert(input_data != nullptr);
108       tensor_data = main_runtime_graph->getDataByTensor(output);
109       assert(tensor_data != nullptr);
110       std::memcpy(tensor_data, input_data, input_sizes[i]);
111     }
112     assert(tensor_data != nullptr);
113
114     operation_inputs_data[i] = tensor_data;
115   }
116
117   const auto *options = cur_op->builtin_options_as_WhileOptions();
118   const auto body_subgraph_index = options->body_subgraph_index();
119   const auto cond_subgraph_index = options->cond_subgraph_index();
120
121   auto *cond_runtime_graph = runtime_module->getRuntimeGraphAt(cond_subgraph_index);
122   auto *body_runtime_graph = runtime_module->getRuntimeGraphAt(body_subgraph_index);
123
124   do
125   {
126     cond_runtime_graph->selectOwnSubgraph();
127
128     for (int32_t i = 0; i < input_size; ++i)
129       cond_runtime_graph->configureGraphInput(i, operation_inputs_data[i]);
130
131     cond_runtime_graph->execute();
132
133     bool cond_value = (cond_runtime_graph->getOutputDataByIndex(0))[0];
134     if (!cond_value)
135       break;
136
137     body_runtime_graph->selectOwnSubgraph();
138     for (int32_t i = 0; i < input_size; ++i)
139       body_runtime_graph->configureGraphInput(i, operation_inputs_data[i]);
140
141     body_runtime_graph->execute();
142
143     for (int32_t i = 0; i < input_size; ++i)
144     {
145       auto cur_output_body_data = body_runtime_graph->getOutputDataByIndex(i);
146       if (cur_output_body_data == nullptr)
147         continue;
148       std::memcpy(operation_inputs_data[i], cur_output_body_data, input_sizes[i]);
149     }
150   } while (true);
151
152   cond_runtime_graph->resetOutputTensorsData();
153   cond_runtime_graph->clearTensors();
154
155   body_runtime_graph->selectOwnSubgraph();
156   body_runtime_graph->resetOutputTensorsData();
157   body_runtime_graph->clearTensors();
158
159   main_runtime_graph->selectOwnSubgraph();
160
161   if (is_inplace)
162   {
163     for (int32_t i = 0; i < input_size; ++i)
164     {
165       const auto op_input_index = cur_op->inputs()->operator[](i);
166       const auto op_output_index = cur_op->outputs()->operator[](i);
167       assert(op_input_index != -1);
168       assert(op_output_index != -1);
169       const auto input = main_runtime_graph->getCircleTensorByIndex(op_input_index);
170       const auto output = main_runtime_graph->getCircleTensorByIndex(op_output_index);
171
172       if (main_runtime_graph->getDataByTensor(input))
173       {
174         main_runtime_graph->makeInplaceOperation(input, output);
175       }
176       else
177       {
178         main_runtime_graph->setDataToTensor(output, operation_inputs_data[i]);
179       }
180     }
181   }
182 }
183
184 } // namespace luci_interpreter